[XLA] Fix elemental_ir_emitter access control:
* All data members are private now, and accessed through methods by derived classes. * MakeElementGenerator isn't virtual anymore. Instead, the customization point is moved to EmitConvolution. PiperOrigin-RevId: 338747947 Change-Id: I3a8382eed86f1104b46226afd50a158dddacbf3c
This commit is contained in:
parent
be3be65d1a
commit
a0e750360a
@ -41,8 +41,8 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
|
|||||||
switch (prim_type) {
|
switch (prim_type) {
|
||||||
case F16:
|
case F16:
|
||||||
cast_result_to_fp16 = true;
|
cast_result_to_fp16 = true;
|
||||||
lhs = FPCast(lhs, b_->getFloatTy());
|
lhs = FPCast(lhs, b()->getFloatTy());
|
||||||
rhs = FPCast(rhs, b_->getFloatTy());
|
rhs = FPCast(rhs, b()->getFloatTy());
|
||||||
TF_FALLTHROUGH_INTENDED;
|
TF_FALLTHROUGH_INTENDED;
|
||||||
case F32:
|
case F32:
|
||||||
function_name = "atan2f";
|
function_name = "atan2f";
|
||||||
@ -55,7 +55,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
|
|||||||
}
|
}
|
||||||
// Create a function declaration.
|
// Create a function declaration.
|
||||||
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
|
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
|
||||||
module_
|
module()
|
||||||
->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(),
|
->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(),
|
||||||
rhs->getType())
|
rhs->getType())
|
||||||
.getCallee());
|
.getCallee());
|
||||||
@ -65,7 +65,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
|
|||||||
// Create an instruction to call the function.
|
// Create an instruction to call the function.
|
||||||
llvm::Value* result = Call(function, {lhs, rhs});
|
llvm::Value* result = Call(function, {lhs, rhs});
|
||||||
if (cast_result_to_fp16) {
|
if (cast_result_to_fp16) {
|
||||||
result = FPCast(result, b_->getHalfTy());
|
result = FPCast(result, b()->getHalfTy());
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -77,7 +77,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
|||||||
switch (prim_type) {
|
switch (prim_type) {
|
||||||
case F16:
|
case F16:
|
||||||
cast_result_to_fp16 = true;
|
cast_result_to_fp16 = true;
|
||||||
value = FPCast(value, b_->getFloatTy());
|
value = FPCast(value, b()->getFloatTy());
|
||||||
TF_FALLTHROUGH_INTENDED;
|
TF_FALLTHROUGH_INTENDED;
|
||||||
case F32:
|
case F32:
|
||||||
function_name = "tanhf";
|
function_name = "tanhf";
|
||||||
@ -90,7 +90,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
|||||||
}
|
}
|
||||||
// Create a function declaration.
|
// Create a function declaration.
|
||||||
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
|
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
|
||||||
module_
|
module()
|
||||||
->getOrInsertFunction(function_name, value->getType(),
|
->getOrInsertFunction(function_name, value->getType(),
|
||||||
value->getType())
|
value->getType())
|
||||||
.getCallee());
|
.getCallee());
|
||||||
@ -100,26 +100,20 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
|||||||
// Create an instruction to call the function.
|
// Create an instruction to call the function.
|
||||||
llvm::Value* result = Call(function, value);
|
llvm::Value* result = Call(function, value);
|
||||||
if (cast_result_to_fp16) {
|
if (cast_result_to_fp16) {
|
||||||
result = FPCast(result, b_->getHalfTy());
|
result = FPCast(result, b()->getHalfTy());
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
|
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
|
||||||
const HloInstruction* hlo,
|
const HloInstruction* hlo,
|
||||||
const HloToElementGeneratorMap& operand_to_generator) {
|
const HloToElementGeneratorMap& operand_to_generator,
|
||||||
switch (hlo->opcode()) {
|
const llvm_ir::IrArray::Index& index) {
|
||||||
case HloOpcode::kConvolution:
|
return ir_emitter_->EmitElementalConvolution(
|
||||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
Cast<HloConvolutionInstruction>(hlo),
|
||||||
return ir_emitter_->EmitElementalConvolution(
|
operand_to_generator.at(hlo->operand(0)),
|
||||||
Cast<HloConvolutionInstruction>(hlo),
|
operand_to_generator.at(hlo->operand(1)), index);
|
||||||
operand_to_generator.at(hlo->operand(0)),
|
|
||||||
operand_to_generator.at(hlo->operand(1)), index);
|
|
||||||
};
|
|
||||||
default:
|
|
||||||
return ElementalIrEmitter::MakeElementGenerator(hlo,
|
|
||||||
operand_to_generator);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cpu
|
} // namespace cpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
@ -31,18 +31,19 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
|
|||||||
public:
|
public:
|
||||||
CpuElementalIrEmitter(const HloModuleConfig& module_config,
|
CpuElementalIrEmitter(const HloModuleConfig& module_config,
|
||||||
IrEmitter* ir_emitter, llvm::Module* module)
|
IrEmitter* ir_emitter, llvm::Module* module)
|
||||||
: ElementalIrEmitter(module_config, module, ir_emitter->b()),
|
: ElementalIrEmitter(module, ir_emitter->b()),
|
||||||
|
hlo_module_config_(module_config),
|
||||||
ir_emitter_(ir_emitter) {}
|
ir_emitter_(ir_emitter) {}
|
||||||
|
|
||||||
llvm_ir::ElementGenerator MakeElementGenerator(
|
|
||||||
const HloInstruction* hlo,
|
|
||||||
const HloToElementGeneratorMap& operand_to_generator) override;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
|
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
|
||||||
llvm::Value* rhs) override;
|
llvm::Value* rhs) override;
|
||||||
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
|
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
|
||||||
llvm::Value* value) override;
|
llvm::Value* value) override;
|
||||||
|
StatusOr<llvm::Value*> EmitConvolution(
|
||||||
|
const HloInstruction* hlo,
|
||||||
|
const HloToElementGeneratorMap& operand_to_generator,
|
||||||
|
const llvm_ir::IrArray::Index& index) override;
|
||||||
|
|
||||||
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
|
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
|
||||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||||
@ -54,6 +55,7 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
|
|||||||
return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
|
return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const HloModuleConfig& hlo_module_config_;
|
||||||
IrEmitter* ir_emitter_;
|
IrEmitter* ir_emitter_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -2486,6 +2486,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
return EmitElementalReduce(reduce_instr, std::move(input_generators),
|
return EmitElementalReduce(reduce_instr, std::move(input_generators),
|
||||||
std::move(initial_value_generators), index);
|
std::move(initial_value_generators), index);
|
||||||
};
|
};
|
||||||
|
case HloOpcode::kConvolution:
|
||||||
|
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||||
|
return EmitConvolution(hlo, operand_to_generator, index);
|
||||||
|
};
|
||||||
default:
|
default:
|
||||||
return [hlo](const IrArray::Index& index) {
|
return [hlo](const IrArray::Index& index) {
|
||||||
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
|
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
|
||||||
@ -2730,6 +2734,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
|
||||||
|
const HloInstruction* hlo,
|
||||||
|
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
|
||||||
|
const llvm_ir::IrArray::Index& index) {
|
||||||
|
return Unimplemented("Elemental convolution is not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
// Evaluate polynomial using Horner's method.
|
// Evaluate polynomial using Horner's method.
|
||||||
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
|
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
|
||||||
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
|
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
|
||||||
|
|||||||
@ -26,7 +26,6 @@ limitations under the License.
|
|||||||
#include "llvm/IR/Value.h"
|
#include "llvm/IR/Value.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
|
||||||
@ -39,22 +38,14 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
|||||||
using HloToElementGeneratorMap =
|
using HloToElementGeneratorMap =
|
||||||
std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
|
std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
|
||||||
|
|
||||||
ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
|
ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b)
|
||||||
llvm::Module* module, llvm::IRBuilder<>* b)
|
: b_(b), module_(module) {}
|
||||||
: b_(b), module_(module), hlo_module_config_(hlo_module_config) {}
|
|
||||||
|
|
||||||
virtual ~ElementalIrEmitter() = default;
|
virtual ~ElementalIrEmitter() = default;
|
||||||
|
|
||||||
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
|
|
||||||
llvm::Value* operand_value);
|
|
||||||
|
|
||||||
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
|
|
||||||
llvm::Value* lhs_value,
|
|
||||||
llvm::Value* rhs_value);
|
|
||||||
|
|
||||||
// Returns a function to generate an element of the output of `hlo`, given a
|
// Returns a function to generate an element of the output of `hlo`, given a
|
||||||
// map of functions to generate elements of its operands.
|
// map of functions to generate elements of its operands.
|
||||||
virtual llvm_ir::ElementGenerator MakeElementGenerator(
|
llvm_ir::ElementGenerator MakeElementGenerator(
|
||||||
const HloInstruction* hlo,
|
const HloInstruction* hlo,
|
||||||
const HloToElementGeneratorMap& operand_to_generator);
|
const HloToElementGeneratorMap& operand_to_generator);
|
||||||
|
|
||||||
@ -66,6 +57,21 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
|||||||
llvm::Module* module() { return module_; }
|
llvm::Module* module() { return module_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
|
||||||
|
llvm::Value* lhs_value,
|
||||||
|
llvm::Value* rhs_value);
|
||||||
|
|
||||||
|
virtual llvm::Value* EmitExtractReal(llvm::Value* value);
|
||||||
|
virtual llvm::Value* EmitExtractImag(llvm::Value* value);
|
||||||
|
|
||||||
|
private:
|
||||||
|
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
|
||||||
|
llvm::Value* operand_value);
|
||||||
|
|
||||||
|
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
|
||||||
|
llvm::Value* lhs_value,
|
||||||
|
llvm::Value* rhs_value);
|
||||||
|
|
||||||
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
|
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
|
||||||
llvm::Value* operand_value);
|
llvm::Value* operand_value);
|
||||||
|
|
||||||
@ -92,10 +98,6 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
|||||||
llvm::Value* rhs_value,
|
llvm::Value* rhs_value,
|
||||||
bool is_signed);
|
bool is_signed);
|
||||||
|
|
||||||
virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
|
|
||||||
llvm::Value* lhs_value,
|
|
||||||
llvm::Value* rhs_value);
|
|
||||||
|
|
||||||
virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
|
virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
|
||||||
llvm::Value* lhs_value,
|
llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value);
|
llvm::Value* rhs_value);
|
||||||
@ -175,9 +177,6 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
|||||||
PrimitiveType prim_type,
|
PrimitiveType prim_type,
|
||||||
llvm::Value* operand_value);
|
llvm::Value* operand_value);
|
||||||
|
|
||||||
virtual llvm::Value* EmitExtractReal(llvm::Value* value);
|
|
||||||
virtual llvm::Value* EmitExtractImag(llvm::Value* value);
|
|
||||||
|
|
||||||
// Composes a complex struct. imag may be nullptr for simple cast operations.
|
// Composes a complex struct. imag may be nullptr for simple cast operations.
|
||||||
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
|
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
|
||||||
llvm::Value* imag);
|
llvm::Value* imag);
|
||||||
@ -245,17 +244,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
|||||||
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
|
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
|
||||||
const llvm_ir::IrArray::Index& index);
|
const llvm_ir::IrArray::Index& index);
|
||||||
|
|
||||||
virtual bool fast_min_max() = 0;
|
virtual StatusOr<llvm::Value*> EmitConvolution(
|
||||||
|
const HloInstruction* hlo,
|
||||||
|
const HloToElementGeneratorMap& operand_to_generator,
|
||||||
|
const llvm_ir::IrArray::Index& index);
|
||||||
|
|
||||||
llvm::IRBuilder<>* const b_;
|
|
||||||
|
|
||||||
llvm::Module* module_;
|
|
||||||
|
|
||||||
// The HloModuleConfig which gathers all settings and values which affect the
|
|
||||||
// compiled executable outside of the HLO code itself.
|
|
||||||
const HloModuleConfig& hlo_module_config_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Computes the complex power function, returns (a + i*b)^(c + i*d).
|
// Computes the complex power function, returns (a + i*b)^(c + i*d).
|
||||||
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
|
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
|
||||||
llvm::Value* a, llvm::Value* b,
|
llvm::Value* a, llvm::Value* b,
|
||||||
@ -264,6 +257,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
|||||||
// Evaluates a polynomial using Horner's method.
|
// Evaluates a polynomial using Horner's method.
|
||||||
StatusOr<llvm::Value*> EvaluatePolynomial(
|
StatusOr<llvm::Value*> EvaluatePolynomial(
|
||||||
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
|
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
|
||||||
|
|
||||||
|
virtual bool fast_min_max() = 0;
|
||||||
|
|
||||||
|
llvm::IRBuilder<>* const b_;
|
||||||
|
|
||||||
|
llvm::Module* module_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
@ -72,7 +72,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
|
|||||||
GpuElementalIrEmitter::GpuElementalIrEmitter(
|
GpuElementalIrEmitter::GpuElementalIrEmitter(
|
||||||
const HloModuleConfig& hlo_module_config, llvm::Module* module,
|
const HloModuleConfig& hlo_module_config, llvm::Module* module,
|
||||||
llvm::IRBuilder<>* b, NestedComputer compute_nested)
|
llvm::IRBuilder<>* b, NestedComputer compute_nested)
|
||||||
: ElementalIrEmitter(hlo_module_config, module, b),
|
: ElementalIrEmitter(module, b),
|
||||||
|
hlo_module_config_(hlo_module_config),
|
||||||
compute_nested_(std::move(compute_nested)) {}
|
compute_nested_(std::move(compute_nested)) {}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
||||||
@ -91,7 +92,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
|||||||
for (int64 i = 0; i < operands.size(); ++i) {
|
for (int64 i = 0; i < operands.size(); ++i) {
|
||||||
if (input_types[i] == F16) {
|
if (input_types[i] == F16) {
|
||||||
converted_operands[i] =
|
converted_operands[i] =
|
||||||
FPCast(converted_operands[i], b_->getFloatTy());
|
FPCast(converted_operands[i], b()->getFloatTy());
|
||||||
converted_input_types[i] = F32;
|
converted_input_types[i] = F32;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -106,12 +107,12 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
|||||||
PrimitiveType_Name(output_type));
|
PrimitiveType_Name(output_type));
|
||||||
}
|
}
|
||||||
const string& munged_callee =
|
const string& munged_callee =
|
||||||
ObtainDeviceFunctionName(funcid, output_type, b_);
|
ObtainDeviceFunctionName(funcid, output_type, b());
|
||||||
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
|
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
|
||||||
converted_input_types, output_type)
|
converted_input_types, output_type)
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
if (cast_result_to_fp16) {
|
if (cast_result_to_fp16) {
|
||||||
result = FPCast(result, b_->getHalfTy());
|
result = FPCast(result, b()->getHalfTy());
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -153,7 +154,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
|
|||||||
|
|
||||||
return EmitDeviceFunctionCall(
|
return EmitDeviceFunctionCall(
|
||||||
callee_name, operands, input_types, output_type,
|
callee_name, operands, input_types, output_type,
|
||||||
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b_);
|
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
|
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
|
||||||
@ -168,7 +169,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
|
|||||||
return llvm_ir::EmitCallToIntrinsic(
|
return llvm_ir::EmitCallToIntrinsic(
|
||||||
opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
|
opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
|
||||||
: llvm::Intrinsic::minnum,
|
: llvm::Intrinsic::minnum,
|
||||||
{lhs_value, rhs_value}, {lhs_value->getType()}, b_);
|
{lhs_value, rhs_value}, {lhs_value->getType()}, b());
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (op->opcode()) {
|
switch (op->opcode()) {
|
||||||
@ -275,19 +276,19 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
|||||||
// This routine isn't numerically precise, but it's good enough for ML.
|
// This routine isn't numerically precise, but it's good enough for ML.
|
||||||
|
|
||||||
// Upcast F16 to F32 if necessary.
|
// Upcast F16 to F32 if necessary.
|
||||||
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
|
llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
|
||||||
llvm::Value* input = FPCast(value, type);
|
llvm::Value* input = FPCast(value, type);
|
||||||
|
|
||||||
// If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
|
// If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
|
||||||
constexpr double kMaxValue = 20.0;
|
constexpr double kMaxValue = 20.0;
|
||||||
auto max_value = llvm::ConstantFP::get(type, kMaxValue);
|
auto max_value = llvm::ConstantFP::get(type, kMaxValue);
|
||||||
llvm::Value* abs_value =
|
llvm::Value* abs_value =
|
||||||
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_);
|
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());
|
||||||
|
|
||||||
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
|
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
|
||||||
auto one = llvm::ConstantFP::get(type, 1.0);
|
auto one = llvm::ConstantFP::get(type, 1.0);
|
||||||
auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
|
auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
|
||||||
{one, input}, {type}, b_);
|
{one, input}, {type}, b());
|
||||||
return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
|
return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
|
||||||
value->getType(), "tanh");
|
value->getType(), "tanh");
|
||||||
}
|
}
|
||||||
@ -301,14 +302,14 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
|
|||||||
|
|
||||||
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
|
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
|
||||||
llvm::Value* block_id = IntCast(
|
llvm::Value* block_id = IntCast(
|
||||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_),
|
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
|
||||||
b_->getIntNTy(128), /*isSigned=*/true, "block.id");
|
b()->getIntNTy(128), /*isSigned=*/true, "block.id");
|
||||||
llvm::Value* thread_id_in_block = IntCast(
|
llvm::Value* thread_id_in_block = IntCast(
|
||||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_),
|
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
|
||||||
b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
|
b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
|
||||||
llvm::Value* threads_per_block = IntCast(
|
llvm::Value* threads_per_block = IntCast(
|
||||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b_),
|
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
|
||||||
b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
|
b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
|
||||||
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
|
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -126,6 +126,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
|||||||
const string& callee_name, absl::Span<llvm::Value* const> operands,
|
const string& callee_name, absl::Span<llvm::Value* const> operands,
|
||||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
|
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
|
||||||
|
|
||||||
|
const HloModuleConfig& hlo_module_config_;
|
||||||
|
|
||||||
NestedComputer compute_nested_;
|
NestedComputer compute_nested_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user