From a0e750360af13c6511d8ad95df4610d87e976773 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Fri, 23 Oct 2020 14:41:47 -0700 Subject: [PATCH] [XLA] Fix elemental_ir_emitter access control: * All data members are private now, and accessed through methods by derived classes. * MakeElementGenerator isn't virtual anymore. Instead, the customization point is moved to EmitConvolution. PiperOrigin-RevId: 338747947 Change-Id: I3a8382eed86f1104b46226afd50a158dddacbf3c --- .../xla/service/cpu/elemental_ir_emitter.cc | 36 +++++------- .../xla/service/cpu/elemental_ir_emitter.h | 12 ++-- .../xla/service/elemental_ir_emitter.cc | 11 ++++ .../xla/service/elemental_ir_emitter.h | 57 +++++++++---------- .../xla/service/gpu/elemental_ir_emitter.cc | 33 +++++------ .../xla/service/gpu/elemental_ir_emitter.h | 2 + 6 files changed, 80 insertions(+), 71 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index 05364a4492b..b15aa3689b7 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -41,8 +41,8 @@ StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, switch (prim_type) { case F16: cast_result_to_fp16 = true; - lhs = FPCast(lhs, b_->getFloatTy()); - rhs = FPCast(rhs, b_->getFloatTy()); + lhs = FPCast(lhs, b()->getFloatTy()); + rhs = FPCast(rhs, b()->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; @@ -55,7 +55,7 @@ StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, } // Create a function declaration. llvm::Function* function = llvm::dyn_cast( - module_ + module() ->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(), rhs->getType()) .getCallee()); @@ -65,7 +65,7 @@ StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, // Create an instruction to call the function. llvm::Value* result = Call(function, {lhs, rhs}); if (cast_result_to_fp16) { - result = FPCast(result, b_->getHalfTy()); + result = FPCast(result, b()->getHalfTy()); } return result; } @@ -77,7 +77,7 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, switch (prim_type) { case F16: cast_result_to_fp16 = true; - value = FPCast(value, b_->getFloatTy()); + value = FPCast(value, b()->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "tanhf"; @@ -90,7 +90,7 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, } // Create a function declaration. llvm::Function* function = llvm::dyn_cast( - module_ + module() ->getOrInsertFunction(function_name, value->getType(), value->getType()) .getCallee()); @@ -100,26 +100,20 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, // Create an instruction to call the function. llvm::Value* result = Call(function, value); if (cast_result_to_fp16) { - result = FPCast(result, b_->getHalfTy()); + result = FPCast(result, b()->getHalfTy()); } return result; } -llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( +StatusOr CpuElementalIrEmitter::EmitConvolution( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) { - switch (hlo->opcode()) { - case HloOpcode::kConvolution: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { - return ir_emitter_->EmitElementalConvolution( - Cast(hlo), - operand_to_generator.at(hlo->operand(0)), - operand_to_generator.at(hlo->operand(1)), index); - }; - default: - return ElementalIrEmitter::MakeElementGenerator(hlo, - operand_to_generator); - } + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) { + return ir_emitter_->EmitElementalConvolution( + Cast(hlo), + operand_to_generator.at(hlo->operand(0)), + operand_to_generator.at(hlo->operand(1)), index); } + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 4c3167e16d9..fbf582d3a8b 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -31,18 +31,19 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { public: CpuElementalIrEmitter(const HloModuleConfig& module_config, IrEmitter* ir_emitter, llvm::Module* module) - : ElementalIrEmitter(module_config, module, ir_emitter->b()), + : ElementalIrEmitter(module, ir_emitter->b()), + hlo_module_config_(module_config), ir_emitter_(ir_emitter) {} - llvm_ir::ElementGenerator MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) override; - protected: StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) override; StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr EmitConvolution( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) override; StatusOr> EmitThreadLocalCall( const HloComputation& callee, absl::Span parameters, @@ -54,6 +55,7 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max(); } + const HloModuleConfig& hlo_module_config_; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 3a449b7c2db..d3e00d04dfd 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -2486,6 +2486,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitElementalReduce(reduce_instr, std::move(input_generators), std::move(initial_value_generators), index); }; + case HloOpcode::kConvolution: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return EmitConvolution(hlo, operand_to_generator, index); + }; default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", @@ -2730,6 +2734,13 @@ StatusOr ElementalIrEmitter::EmitElementalReduce( } } +StatusOr ElementalIrEmitter::EmitConvolution( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) { + return Unimplemented("Elemental convolution is not implemented"); +} + // Evaluate polynomial using Horner's method. StatusOr ElementalIrEmitter::EvaluatePolynomial( llvm::Type* type, llvm::Value* x, absl::Span coefficients) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 365e3f56b85..56833159647 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -26,7 +26,6 @@ limitations under the License. #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" @@ -39,22 +38,14 @@ class ElementalIrEmitter : public IrBuilderMixin { using HloToElementGeneratorMap = std::unordered_map; - ElementalIrEmitter(const HloModuleConfig& hlo_module_config, - llvm::Module* module, llvm::IRBuilder<>* b) - : b_(b), module_(module), hlo_module_config_(hlo_module_config) {} + ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b) + : b_(b), module_(module) {} virtual ~ElementalIrEmitter() = default; - virtual StatusOr EmitUnaryOp(const HloInstruction* op, - llvm::Value* operand_value); - - virtual StatusOr EmitBinaryOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); - // Returns a function to generate an element of the output of `hlo`, given a // map of functions to generate elements of its operands. - virtual llvm_ir::ElementGenerator MakeElementGenerator( + llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator); @@ -66,6 +57,21 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Module* module() { return module_; } protected: + virtual StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); + + virtual llvm::Value* EmitExtractReal(llvm::Value* value); + virtual llvm::Value* EmitExtractImag(llvm::Value* value); + + private: + virtual StatusOr EmitUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); + + virtual StatusOr EmitBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); + virtual StatusOr EmitIntegerUnaryOp(const HloInstruction* op, llvm::Value* operand_value); @@ -92,10 +98,6 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Value* rhs_value, bool is_signed); - virtual StatusOr EmitFloatBinaryOp(const HloInstruction* op, - llvm::Value* lhs_value, - llvm::Value* rhs_value); - virtual StatusOr EmitComplexBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value); @@ -175,9 +177,6 @@ class ElementalIrEmitter : public IrBuilderMixin { PrimitiveType prim_type, llvm::Value* operand_value); - virtual llvm::Value* EmitExtractReal(llvm::Value* value); - virtual llvm::Value* EmitExtractImag(llvm::Value* value); - // Composes a complex struct. imag may be nullptr for simple cast operations. llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, llvm::Value* imag); @@ -245,17 +244,11 @@ class ElementalIrEmitter : public IrBuilderMixin { std::vector initial_value_generators, const llvm_ir::IrArray::Index& index); - virtual bool fast_min_max() = 0; + virtual StatusOr EmitConvolution( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index); - llvm::IRBuilder<>* const b_; - - llvm::Module* module_; - - // The HloModuleConfig which gathers all settings and values which affect the - // compiled executable outside of the HLO code itself. - const HloModuleConfig& hlo_module_config_; - - private: // Computes the complex power function, returns (a + i*b)^(c + i*d). StatusOr EmitComplexPower(const HloInstruction* op, llvm::Value* a, llvm::Value* b, @@ -264,6 +257,12 @@ class ElementalIrEmitter : public IrBuilderMixin { // Evaluates a polynomial using Horner's method. StatusOr EvaluatePolynomial( llvm::Type* type, llvm::Value* x, absl::Span coefficients); + + virtual bool fast_min_max() = 0; + + llvm::IRBuilder<>* const b_; + + llvm::Module* module_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 3f000a2491d..e72c12813b7 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -72,7 +72,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, llvm::IRBuilder<>* b, NestedComputer compute_nested) - : ElementalIrEmitter(hlo_module_config, module, b), + : ElementalIrEmitter(module, b), + hlo_module_config_(hlo_module_config), compute_nested_(std::move(compute_nested)) {} StatusOr GpuElementalIrEmitter::EmitDeviceMathCall( @@ -91,7 +92,7 @@ StatusOr GpuElementalIrEmitter::EmitDeviceMathCall( for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { converted_operands[i] = - FPCast(converted_operands[i], b_->getFloatTy()); + FPCast(converted_operands[i], b()->getFloatTy()); converted_input_types[i] = F32; } } @@ -106,12 +107,12 @@ StatusOr GpuElementalIrEmitter::EmitDeviceMathCall( PrimitiveType_Name(output_type)); } const string& munged_callee = - ObtainDeviceFunctionName(funcid, output_type, b_); + ObtainDeviceFunctionName(funcid, output_type, b()); llvm::Value* result = EmitMathCall(munged_callee, converted_operands, converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = FPCast(result, b_->getHalfTy()); + result = FPCast(result, b()->getHalfTy()); } return result; } @@ -153,7 +154,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( return EmitDeviceFunctionCall( callee_name, operands, input_types, output_type, - {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b_); + {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b()); } StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( @@ -168,7 +169,7 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( return llvm_ir::EmitCallToIntrinsic( opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum : llvm::Intrinsic::minnum, - {lhs_value, rhs_value}, {lhs_value->getType()}, b_); + {lhs_value, rhs_value}, {lhs_value->getType()}, b()); } switch (op->opcode()) { @@ -275,19 +276,19 @@ StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, // This routine isn't numerically precise, but it's good enough for ML. // Upcast F16 to F32 if necessary. - llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); + llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType(); llvm::Value* input = FPCast(value, type); // If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0. constexpr double kMaxValue = 20.0; auto max_value = llvm::ConstantFP::get(type, kMaxValue); llvm::Value* abs_value = - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_); + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b()); - llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input); auto one = llvm::ConstantFP::get(type, 1.0); auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, - {one, input}, {type}, b_); + {one, input}, {type}, b()); return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign), value->getType(), "tanh"); } @@ -301,14 +302,14 @@ StatusOr GpuElementalIrEmitter::EmitComplexAbs( llvm::Value* GpuElementalIrEmitter::EmitThreadId() { llvm::Value* block_id = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "block.id"); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()), + b()->getIntNTy(128), /*isSigned=*/true, "block.id"); llvm::Value* thread_id_in_block = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()), + b()->getIntNTy(128), /*isSigned=*/true, "thread.id"); llvm::Value* threads_per_block = IntCast( - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()), + b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 766a4c84df5..0303ea47e8d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -126,6 +126,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const string& callee_name, absl::Span operands, absl::Span input_types, PrimitiveType output_type); + const HloModuleConfig& hlo_module_config_; + NestedComputer compute_nested_; };