[XLA] Fix elemental_ir_emitter access control:
* All data members are private now, and accessed through methods by derived classes. * MakeElementGenerator isn't virtual anymore. Instead, the customization point is moved to EmitConvolution. PiperOrigin-RevId: 338747947 Change-Id: I3a8382eed86f1104b46226afd50a158dddacbf3c
This commit is contained in:
		
							parent
							
								
									be3be65d1a
								
							
						
					
					
						commit
						a0e750360a
					
				| @ -41,8 +41,8 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, | ||||
|   switch (prim_type) { | ||||
|     case F16: | ||||
|       cast_result_to_fp16 = true; | ||||
|       lhs = FPCast(lhs, b_->getFloatTy()); | ||||
|       rhs = FPCast(rhs, b_->getFloatTy()); | ||||
|       lhs = FPCast(lhs, b()->getFloatTy()); | ||||
|       rhs = FPCast(rhs, b()->getFloatTy()); | ||||
|       TF_FALLTHROUGH_INTENDED; | ||||
|     case F32: | ||||
|       function_name = "atan2f"; | ||||
| @ -55,7 +55,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, | ||||
|   } | ||||
|   // Create a function declaration.
 | ||||
|   llvm::Function* function = llvm::dyn_cast<llvm::Function>( | ||||
|       module_ | ||||
|       module() | ||||
|           ->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(), | ||||
|                                 rhs->getType()) | ||||
|           .getCallee()); | ||||
| @ -65,7 +65,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, | ||||
|   // Create an instruction to call the function.
 | ||||
|   llvm::Value* result = Call(function, {lhs, rhs}); | ||||
|   if (cast_result_to_fp16) { | ||||
|     result = FPCast(result, b_->getHalfTy()); | ||||
|     result = FPCast(result, b()->getHalfTy()); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| @ -77,7 +77,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, | ||||
|   switch (prim_type) { | ||||
|     case F16: | ||||
|       cast_result_to_fp16 = true; | ||||
|       value = FPCast(value, b_->getFloatTy()); | ||||
|       value = FPCast(value, b()->getFloatTy()); | ||||
|       TF_FALLTHROUGH_INTENDED; | ||||
|     case F32: | ||||
|       function_name = "tanhf"; | ||||
| @ -90,7 +90,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, | ||||
|   } | ||||
|   // Create a function declaration.
 | ||||
|   llvm::Function* function = llvm::dyn_cast<llvm::Function>( | ||||
|       module_ | ||||
|       module() | ||||
|           ->getOrInsertFunction(function_name, value->getType(), | ||||
|                                 value->getType()) | ||||
|           .getCallee()); | ||||
| @ -100,26 +100,20 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, | ||||
|   // Create an instruction to call the function.
 | ||||
|   llvm::Value* result = Call(function, value); | ||||
|   if (cast_result_to_fp16) { | ||||
|     result = FPCast(result, b_->getHalfTy()); | ||||
|     result = FPCast(result, b()->getHalfTy()); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( | ||||
| StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution( | ||||
|     const HloInstruction* hlo, | ||||
|     const HloToElementGeneratorMap& operand_to_generator) { | ||||
|   switch (hlo->opcode()) { | ||||
|     case HloOpcode::kConvolution: | ||||
|       return [this, hlo, &operand_to_generator](const IrArray::Index& index) { | ||||
|     const HloToElementGeneratorMap& operand_to_generator, | ||||
|     const llvm_ir::IrArray::Index& index) { | ||||
|   return ir_emitter_->EmitElementalConvolution( | ||||
|       Cast<HloConvolutionInstruction>(hlo), | ||||
|       operand_to_generator.at(hlo->operand(0)), | ||||
|       operand_to_generator.at(hlo->operand(1)), index); | ||||
|       }; | ||||
|     default: | ||||
|       return ElementalIrEmitter::MakeElementGenerator(hlo, | ||||
|                                                       operand_to_generator); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| }  // namespace cpu
 | ||||
| }  // namespace xla
 | ||||
|  | ||||
| @ -31,18 +31,19 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { | ||||
|  public: | ||||
|   CpuElementalIrEmitter(const HloModuleConfig& module_config, | ||||
|                         IrEmitter* ir_emitter, llvm::Module* module) | ||||
|       : ElementalIrEmitter(module_config, module, ir_emitter->b()), | ||||
|       : ElementalIrEmitter(module, ir_emitter->b()), | ||||
|         hlo_module_config_(module_config), | ||||
|         ir_emitter_(ir_emitter) {} | ||||
| 
 | ||||
|   llvm_ir::ElementGenerator MakeElementGenerator( | ||||
|       const HloInstruction* hlo, | ||||
|       const HloToElementGeneratorMap& operand_to_generator) override; | ||||
| 
 | ||||
|  protected: | ||||
|   StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, | ||||
|                                    llvm::Value* rhs) override; | ||||
|   StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type, | ||||
|                                   llvm::Value* value) override; | ||||
|   StatusOr<llvm::Value*> EmitConvolution( | ||||
|       const HloInstruction* hlo, | ||||
|       const HloToElementGeneratorMap& operand_to_generator, | ||||
|       const llvm_ir::IrArray::Index& index) override; | ||||
| 
 | ||||
|   StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall( | ||||
|       const HloComputation& callee, absl::Span<llvm::Value* const> parameters, | ||||
| @ -54,6 +55,7 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { | ||||
|     return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max(); | ||||
|   } | ||||
| 
 | ||||
|   const HloModuleConfig& hlo_module_config_; | ||||
|   IrEmitter* ir_emitter_; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -2486,6 +2486,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( | ||||
|         return EmitElementalReduce(reduce_instr, std::move(input_generators), | ||||
|                                    std::move(initial_value_generators), index); | ||||
|       }; | ||||
|     case HloOpcode::kConvolution: | ||||
|       return [this, hlo, &operand_to_generator](const IrArray::Index& index) { | ||||
|         return EmitConvolution(hlo, operand_to_generator, index); | ||||
|       }; | ||||
|     default: | ||||
|       return [hlo](const IrArray::Index& index) { | ||||
|         return Unimplemented("Unhandled opcode for elemental IR emission: %s", | ||||
| @ -2730,6 +2734,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce( | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution( | ||||
|     const HloInstruction* hlo, | ||||
|     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, | ||||
|     const llvm_ir::IrArray::Index& index) { | ||||
|   return Unimplemented("Elemental convolution is not implemented"); | ||||
| } | ||||
| 
 | ||||
| // Evaluate polynomial using Horner's method.
 | ||||
| StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial( | ||||
|     llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) { | ||||
|  | ||||
| @ -26,7 +26,6 @@ limitations under the License. | ||||
| #include "llvm/IR/Value.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_module_config.h" | ||||
| #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" | ||||
| #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" | ||||
| #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" | ||||
| @ -39,22 +38,14 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { | ||||
|   using HloToElementGeneratorMap = | ||||
|       std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>; | ||||
| 
 | ||||
|   ElementalIrEmitter(const HloModuleConfig& hlo_module_config, | ||||
|                      llvm::Module* module, llvm::IRBuilder<>* b) | ||||
|       : b_(b), module_(module), hlo_module_config_(hlo_module_config) {} | ||||
|   ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b) | ||||
|       : b_(b), module_(module) {} | ||||
| 
 | ||||
|   virtual ~ElementalIrEmitter() = default; | ||||
| 
 | ||||
|   virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, | ||||
|                                              llvm::Value* operand_value); | ||||
| 
 | ||||
|   virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op, | ||||
|                                               llvm::Value* lhs_value, | ||||
|                                               llvm::Value* rhs_value); | ||||
| 
 | ||||
|   // Returns a function to generate an element of the output of `hlo`, given a
 | ||||
|   // map of functions to generate elements of its operands.
 | ||||
|   virtual llvm_ir::ElementGenerator MakeElementGenerator( | ||||
|   llvm_ir::ElementGenerator MakeElementGenerator( | ||||
|       const HloInstruction* hlo, | ||||
|       const HloToElementGeneratorMap& operand_to_generator); | ||||
| 
 | ||||
| @ -66,6 +57,21 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { | ||||
|   llvm::Module* module() { return module_; } | ||||
| 
 | ||||
|  protected: | ||||
|   virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op, | ||||
|                                                    llvm::Value* lhs_value, | ||||
|                                                    llvm::Value* rhs_value); | ||||
| 
 | ||||
|   virtual llvm::Value* EmitExtractReal(llvm::Value* value); | ||||
|   virtual llvm::Value* EmitExtractImag(llvm::Value* value); | ||||
| 
 | ||||
|  private: | ||||
|   virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, | ||||
|                                              llvm::Value* operand_value); | ||||
| 
 | ||||
|   virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op, | ||||
|                                               llvm::Value* lhs_value, | ||||
|                                               llvm::Value* rhs_value); | ||||
| 
 | ||||
|   virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op, | ||||
|                                                     llvm::Value* operand_value); | ||||
| 
 | ||||
| @ -92,10 +98,6 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { | ||||
|                                                      llvm::Value* rhs_value, | ||||
|                                                      bool is_signed); | ||||
| 
 | ||||
|   virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op, | ||||
|                                                    llvm::Value* lhs_value, | ||||
|                                                    llvm::Value* rhs_value); | ||||
| 
 | ||||
|   virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op, | ||||
|                                                      llvm::Value* lhs_value, | ||||
|                                                      llvm::Value* rhs_value); | ||||
| @ -175,9 +177,6 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { | ||||
|                                                   PrimitiveType prim_type, | ||||
|                                                   llvm::Value* operand_value); | ||||
| 
 | ||||
|   virtual llvm::Value* EmitExtractReal(llvm::Value* value); | ||||
|   virtual llvm::Value* EmitExtractImag(llvm::Value* value); | ||||
| 
 | ||||
|   // Composes a complex struct. imag may be nullptr for simple cast operations.
 | ||||
|   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, | ||||
|                                   llvm::Value* imag); | ||||
| @ -245,17 +244,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { | ||||
|       std::vector<llvm_ir::ElementGenerator> initial_value_generators, | ||||
|       const llvm_ir::IrArray::Index& index); | ||||
| 
 | ||||
|   virtual bool fast_min_max() = 0; | ||||
|   virtual StatusOr<llvm::Value*> EmitConvolution( | ||||
|       const HloInstruction* hlo, | ||||
|       const HloToElementGeneratorMap& operand_to_generator, | ||||
|       const llvm_ir::IrArray::Index& index); | ||||
| 
 | ||||
|   llvm::IRBuilder<>* const b_; | ||||
| 
 | ||||
|   llvm::Module* module_; | ||||
| 
 | ||||
|   // The HloModuleConfig which gathers all settings and values which affect the
 | ||||
|   // compiled executable outside of the HLO code itself.
 | ||||
|   const HloModuleConfig& hlo_module_config_; | ||||
| 
 | ||||
|  private: | ||||
|   // Computes the complex power function, returns (a + i*b)^(c + i*d).
 | ||||
|   StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op, | ||||
|                                           llvm::Value* a, llvm::Value* b, | ||||
| @ -264,6 +257,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { | ||||
|   // Evaluates a polynomial using Horner's method.
 | ||||
|   StatusOr<llvm::Value*> EvaluatePolynomial( | ||||
|       llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients); | ||||
| 
 | ||||
|   virtual bool fast_min_max() = 0; | ||||
| 
 | ||||
|   llvm::IRBuilder<>* const b_; | ||||
| 
 | ||||
|   llvm::Module* module_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace xla
 | ||||
|  | ||||
| @ -72,7 +72,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { | ||||
| GpuElementalIrEmitter::GpuElementalIrEmitter( | ||||
|     const HloModuleConfig& hlo_module_config, llvm::Module* module, | ||||
|     llvm::IRBuilder<>* b, NestedComputer compute_nested) | ||||
|     : ElementalIrEmitter(hlo_module_config, module, b), | ||||
|     : ElementalIrEmitter(module, b), | ||||
|       hlo_module_config_(hlo_module_config), | ||||
|       compute_nested_(std::move(compute_nested)) {} | ||||
| 
 | ||||
| StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall( | ||||
| @ -91,7 +92,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall( | ||||
|       for (int64 i = 0; i < operands.size(); ++i) { | ||||
|         if (input_types[i] == F16) { | ||||
|           converted_operands[i] = | ||||
|               FPCast(converted_operands[i], b_->getFloatTy()); | ||||
|               FPCast(converted_operands[i], b()->getFloatTy()); | ||||
|           converted_input_types[i] = F32; | ||||
|         } | ||||
|       } | ||||
| @ -106,12 +107,12 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall( | ||||
|                            PrimitiveType_Name(output_type)); | ||||
|   } | ||||
|   const string& munged_callee = | ||||
|       ObtainDeviceFunctionName(funcid, output_type, b_); | ||||
|       ObtainDeviceFunctionName(funcid, output_type, b()); | ||||
|   llvm::Value* result = EmitMathCall(munged_callee, converted_operands, | ||||
|                                      converted_input_types, output_type) | ||||
|                             .ValueOrDie(); | ||||
|   if (cast_result_to_fp16) { | ||||
|     result = FPCast(result, b_->getHalfTy()); | ||||
|     result = FPCast(result, b()->getHalfTy()); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| @ -153,7 +154,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall( | ||||
| 
 | ||||
|   return EmitDeviceFunctionCall( | ||||
|       callee_name, operands, input_types, output_type, | ||||
|       {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b_); | ||||
|       {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b()); | ||||
| } | ||||
| 
 | ||||
| StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp( | ||||
| @ -168,7 +169,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp( | ||||
|     return llvm_ir::EmitCallToIntrinsic( | ||||
|         opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum | ||||
|                                       : llvm::Intrinsic::minnum, | ||||
|         {lhs_value, rhs_value}, {lhs_value->getType()}, b_); | ||||
|         {lhs_value, rhs_value}, {lhs_value->getType()}, b()); | ||||
|   } | ||||
| 
 | ||||
|   switch (op->opcode()) { | ||||
| @ -275,19 +276,19 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, | ||||
|   // This routine isn't numerically precise, but it's good enough for ML.
 | ||||
| 
 | ||||
|   // Upcast F16 to F32 if necessary.
 | ||||
|   llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); | ||||
|   llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType(); | ||||
|   llvm::Value* input = FPCast(value, type); | ||||
| 
 | ||||
|   // If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
 | ||||
|   constexpr double kMaxValue = 20.0; | ||||
|   auto max_value = llvm::ConstantFP::get(type, kMaxValue); | ||||
|   llvm::Value* abs_value = | ||||
|       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_); | ||||
|       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b()); | ||||
| 
 | ||||
|   llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); | ||||
|   llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input); | ||||
|   auto one = llvm::ConstantFP::get(type, 1.0); | ||||
|   auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, | ||||
|                                                     {one, input}, {type}, b_); | ||||
|                                                     {one, input}, {type}, b()); | ||||
|   return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign), | ||||
|                 value->getType(), "tanh"); | ||||
| } | ||||
| @ -301,14 +302,14 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs( | ||||
| 
 | ||||
| llvm::Value* GpuElementalIrEmitter::EmitThreadId() { | ||||
|   llvm::Value* block_id = IntCast( | ||||
|       EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_), | ||||
|       b_->getIntNTy(128), /*isSigned=*/true, "block.id"); | ||||
|       EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()), | ||||
|       b()->getIntNTy(128), /*isSigned=*/true, "block.id"); | ||||
|   llvm::Value* thread_id_in_block = IntCast( | ||||
|       EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_), | ||||
|       b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); | ||||
|       EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()), | ||||
|       b()->getIntNTy(128), /*isSigned=*/true, "thread.id"); | ||||
|   llvm::Value* threads_per_block = IntCast( | ||||
|       EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b_), | ||||
|       b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); | ||||
|       EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()), | ||||
|       b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); | ||||
|   return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -126,6 +126,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { | ||||
|       const string& callee_name, absl::Span<llvm::Value* const> operands, | ||||
|       absl::Span<const PrimitiveType> input_types, PrimitiveType output_type); | ||||
| 
 | ||||
|   const HloModuleConfig& hlo_module_config_; | ||||
| 
 | ||||
|   NestedComputer compute_nested_; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user