[XLA] Fix elemental_ir_emitter access control:
* All data members are private now, and are accessed through methods by derived classes.
* MakeElementGenerator isn't virtual anymore. Instead, the customization point is moved to EmitConvolution.

PiperOrigin-RevId: 338747947
Change-Id: I3a8382eed86f1104b46226afd50a158dddacbf3c
This commit is contained in:
parent
be3be65d1a
commit
a0e750360a
@ -41,8 +41,8 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
|
||||
switch (prim_type) {
|
||||
case F16:
|
||||
cast_result_to_fp16 = true;
|
||||
lhs = FPCast(lhs, b_->getFloatTy());
|
||||
rhs = FPCast(rhs, b_->getFloatTy());
|
||||
lhs = FPCast(lhs, b()->getFloatTy());
|
||||
rhs = FPCast(rhs, b()->getFloatTy());
|
||||
TF_FALLTHROUGH_INTENDED;
|
||||
case F32:
|
||||
function_name = "atan2f";
|
||||
@ -55,7 +55,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
|
||||
}
|
||||
// Create a function declaration.
|
||||
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
|
||||
module_
|
||||
module()
|
||||
->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(),
|
||||
rhs->getType())
|
||||
.getCallee());
|
||||
@ -65,7 +65,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
|
||||
// Create an instruction to call the function.
|
||||
llvm::Value* result = Call(function, {lhs, rhs});
|
||||
if (cast_result_to_fp16) {
|
||||
result = FPCast(result, b_->getHalfTy());
|
||||
result = FPCast(result, b()->getHalfTy());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -77,7 +77,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
||||
switch (prim_type) {
|
||||
case F16:
|
||||
cast_result_to_fp16 = true;
|
||||
value = FPCast(value, b_->getFloatTy());
|
||||
value = FPCast(value, b()->getFloatTy());
|
||||
TF_FALLTHROUGH_INTENDED;
|
||||
case F32:
|
||||
function_name = "tanhf";
|
||||
@ -90,7 +90,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
||||
}
|
||||
// Create a function declaration.
|
||||
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
|
||||
module_
|
||||
module()
|
||||
->getOrInsertFunction(function_name, value->getType(),
|
||||
value->getType())
|
||||
.getCallee());
|
||||
@ -100,26 +100,20 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
||||
// Create an instruction to call the function.
|
||||
llvm::Value* result = Call(function, value);
|
||||
if (cast_result_to_fp16) {
|
||||
result = FPCast(result, b_->getHalfTy());
|
||||
result = FPCast(result, b()->getHalfTy());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
|
||||
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
|
||||
const HloInstruction* hlo,
|
||||
const HloToElementGeneratorMap& operand_to_generator) {
|
||||
switch (hlo->opcode()) {
|
||||
case HloOpcode::kConvolution:
|
||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||
return ir_emitter_->EmitElementalConvolution(
|
||||
Cast<HloConvolutionInstruction>(hlo),
|
||||
operand_to_generator.at(hlo->operand(0)),
|
||||
operand_to_generator.at(hlo->operand(1)), index);
|
||||
};
|
||||
default:
|
||||
return ElementalIrEmitter::MakeElementGenerator(hlo,
|
||||
operand_to_generator);
|
||||
}
|
||||
const HloToElementGeneratorMap& operand_to_generator,
|
||||
const llvm_ir::IrArray::Index& index) {
|
||||
return ir_emitter_->EmitElementalConvolution(
|
||||
Cast<HloConvolutionInstruction>(hlo),
|
||||
operand_to_generator.at(hlo->operand(0)),
|
||||
operand_to_generator.at(hlo->operand(1)), index);
|
||||
}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
||||
@ -31,18 +31,19 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
public:
|
||||
CpuElementalIrEmitter(const HloModuleConfig& module_config,
|
||||
IrEmitter* ir_emitter, llvm::Module* module)
|
||||
: ElementalIrEmitter(module_config, module, ir_emitter->b()),
|
||||
: ElementalIrEmitter(module, ir_emitter->b()),
|
||||
hlo_module_config_(module_config),
|
||||
ir_emitter_(ir_emitter) {}
|
||||
|
||||
llvm_ir::ElementGenerator MakeElementGenerator(
|
||||
const HloInstruction* hlo,
|
||||
const HloToElementGeneratorMap& operand_to_generator) override;
|
||||
|
||||
protected:
|
||||
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
|
||||
llvm::Value* rhs) override;
|
||||
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
|
||||
llvm::Value* value) override;
|
||||
StatusOr<llvm::Value*> EmitConvolution(
|
||||
const HloInstruction* hlo,
|
||||
const HloToElementGeneratorMap& operand_to_generator,
|
||||
const llvm_ir::IrArray::Index& index) override;
|
||||
|
||||
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
|
||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||
@ -54,6 +55,7 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
|
||||
}
|
||||
|
||||
const HloModuleConfig& hlo_module_config_;
|
||||
IrEmitter* ir_emitter_;
|
||||
};
|
||||
|
||||
|
||||
@ -2486,6 +2486,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
return EmitElementalReduce(reduce_instr, std::move(input_generators),
|
||||
std::move(initial_value_generators), index);
|
||||
};
|
||||
case HloOpcode::kConvolution:
|
||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||
return EmitConvolution(hlo, operand_to_generator, index);
|
||||
};
|
||||
default:
|
||||
return [hlo](const IrArray::Index& index) {
|
||||
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
|
||||
@ -2730,6 +2734,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
|
||||
const HloInstruction* hlo,
|
||||
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
|
||||
const llvm_ir::IrArray::Index& index) {
|
||||
return Unimplemented("Elemental convolution is not implemented");
|
||||
}
|
||||
|
||||
// Evaluate polynomial using Horner's method.
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
|
||||
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
|
||||
|
||||
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
|
||||
@ -39,22 +38,14 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
using HloToElementGeneratorMap =
|
||||
std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
|
||||
|
||||
ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
|
||||
llvm::Module* module, llvm::IRBuilder<>* b)
|
||||
: b_(b), module_(module), hlo_module_config_(hlo_module_config) {}
|
||||
ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b)
|
||||
: b_(b), module_(module) {}
|
||||
|
||||
virtual ~ElementalIrEmitter() = default;
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
|
||||
llvm::Value* operand_value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
|
||||
llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value);
|
||||
|
||||
// Returns a function to generate an element of the output of `hlo`, given a
|
||||
// map of functions to generate elements of its operands.
|
||||
virtual llvm_ir::ElementGenerator MakeElementGenerator(
|
||||
llvm_ir::ElementGenerator MakeElementGenerator(
|
||||
const HloInstruction* hlo,
|
||||
const HloToElementGeneratorMap& operand_to_generator);
|
||||
|
||||
@ -66,6 +57,21 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
llvm::Module* module() { return module_; }
|
||||
|
||||
protected:
|
||||
virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
|
||||
llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value);
|
||||
|
||||
virtual llvm::Value* EmitExtractReal(llvm::Value* value);
|
||||
virtual llvm::Value* EmitExtractImag(llvm::Value* value);
|
||||
|
||||
private:
|
||||
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
|
||||
llvm::Value* operand_value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
|
||||
llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
|
||||
llvm::Value* operand_value);
|
||||
|
||||
@ -92,10 +98,6 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
llvm::Value* rhs_value,
|
||||
bool is_signed);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
|
||||
llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value);
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
|
||||
llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value);
|
||||
@ -175,9 +177,6 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
PrimitiveType prim_type,
|
||||
llvm::Value* operand_value);
|
||||
|
||||
virtual llvm::Value* EmitExtractReal(llvm::Value* value);
|
||||
virtual llvm::Value* EmitExtractImag(llvm::Value* value);
|
||||
|
||||
// Composes a complex struct. imag may be nullptr for simple cast operations.
|
||||
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
|
||||
llvm::Value* imag);
|
||||
@ -245,17 +244,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
|
||||
const llvm_ir::IrArray::Index& index);
|
||||
|
||||
virtual bool fast_min_max() = 0;
|
||||
virtual StatusOr<llvm::Value*> EmitConvolution(
|
||||
const HloInstruction* hlo,
|
||||
const HloToElementGeneratorMap& operand_to_generator,
|
||||
const llvm_ir::IrArray::Index& index);
|
||||
|
||||
llvm::IRBuilder<>* const b_;
|
||||
|
||||
llvm::Module* module_;
|
||||
|
||||
// The HloModuleConfig which gathers all settings and values which affect the
|
||||
// compiled executable outside of the HLO code itself.
|
||||
const HloModuleConfig& hlo_module_config_;
|
||||
|
||||
private:
|
||||
// Computes the complex power function, returns (a + i*b)^(c + i*d).
|
||||
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
|
||||
llvm::Value* a, llvm::Value* b,
|
||||
@ -264,6 +257,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
// Evaluates a polynomial using Horner's method.
|
||||
StatusOr<llvm::Value*> EvaluatePolynomial(
|
||||
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
|
||||
|
||||
virtual bool fast_min_max() = 0;
|
||||
|
||||
llvm::IRBuilder<>* const b_;
|
||||
|
||||
llvm::Module* module_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
@ -72,7 +72,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
|
||||
GpuElementalIrEmitter::GpuElementalIrEmitter(
|
||||
const HloModuleConfig& hlo_module_config, llvm::Module* module,
|
||||
llvm::IRBuilder<>* b, NestedComputer compute_nested)
|
||||
: ElementalIrEmitter(hlo_module_config, module, b),
|
||||
: ElementalIrEmitter(module, b),
|
||||
hlo_module_config_(hlo_module_config),
|
||||
compute_nested_(std::move(compute_nested)) {}
|
||||
|
||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
||||
@ -91,7 +92,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
||||
for (int64 i = 0; i < operands.size(); ++i) {
|
||||
if (input_types[i] == F16) {
|
||||
converted_operands[i] =
|
||||
FPCast(converted_operands[i], b_->getFloatTy());
|
||||
FPCast(converted_operands[i], b()->getFloatTy());
|
||||
converted_input_types[i] = F32;
|
||||
}
|
||||
}
|
||||
@ -106,12 +107,12 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
|
||||
PrimitiveType_Name(output_type));
|
||||
}
|
||||
const string& munged_callee =
|
||||
ObtainDeviceFunctionName(funcid, output_type, b_);
|
||||
ObtainDeviceFunctionName(funcid, output_type, b());
|
||||
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
|
||||
converted_input_types, output_type)
|
||||
.ValueOrDie();
|
||||
if (cast_result_to_fp16) {
|
||||
result = FPCast(result, b_->getHalfTy());
|
||||
result = FPCast(result, b()->getHalfTy());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -153,7 +154,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
|
||||
|
||||
return EmitDeviceFunctionCall(
|
||||
callee_name, operands, input_types, output_type,
|
||||
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b_);
|
||||
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
|
||||
@ -168,7 +169,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
|
||||
return llvm_ir::EmitCallToIntrinsic(
|
||||
opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
|
||||
: llvm::Intrinsic::minnum,
|
||||
{lhs_value, rhs_value}, {lhs_value->getType()}, b_);
|
||||
{lhs_value, rhs_value}, {lhs_value->getType()}, b());
|
||||
}
|
||||
|
||||
switch (op->opcode()) {
|
||||
@ -275,19 +276,19 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
|
||||
// This routine isn't numerically precise, but it's good enough for ML.
|
||||
|
||||
// Upcast F16 to F32 if necessary.
|
||||
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
|
||||
llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
|
||||
llvm::Value* input = FPCast(value, type);
|
||||
|
||||
// If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
|
||||
constexpr double kMaxValue = 20.0;
|
||||
auto max_value = llvm::ConstantFP::get(type, kMaxValue);
|
||||
llvm::Value* abs_value =
|
||||
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_);
|
||||
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());
|
||||
|
||||
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
|
||||
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
|
||||
auto one = llvm::ConstantFP::get(type, 1.0);
|
||||
auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
|
||||
{one, input}, {type}, b_);
|
||||
{one, input}, {type}, b());
|
||||
return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
|
||||
value->getType(), "tanh");
|
||||
}
|
||||
@ -301,14 +302,14 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
|
||||
|
||||
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
|
||||
llvm::Value* block_id = IntCast(
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_),
|
||||
b_->getIntNTy(128), /*isSigned=*/true, "block.id");
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
|
||||
b()->getIntNTy(128), /*isSigned=*/true, "block.id");
|
||||
llvm::Value* thread_id_in_block = IntCast(
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_),
|
||||
b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
|
||||
b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
|
||||
llvm::Value* threads_per_block = IntCast(
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b_),
|
||||
b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
|
||||
b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
|
||||
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
|
||||
}
|
||||
|
||||
|
||||
@ -126,6 +126,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
const string& callee_name, absl::Span<llvm::Value* const> operands,
|
||||
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
|
||||
|
||||
const HloModuleConfig& hlo_module_config_;
|
||||
|
||||
NestedComputer compute_nested_;
|
||||
};
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user