[XLA] Fix elemental_ir_emitter access control:

* All data members are private now, and accessed through methods by derived classes.
* MakeElementGenerator isn't virtual anymore. Instead, the customization point is moved to EmitConvolution.

PiperOrigin-RevId: 338747947
Change-Id: I3a8382eed86f1104b46226afd50a158dddacbf3c
This commit is contained in:
Tim Shen 2020-10-23 14:41:47 -07:00 committed by TensorFlower Gardener
parent be3be65d1a
commit a0e750360a
6 changed files with 80 additions and 71 deletions

View File

@ -41,8 +41,8 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
lhs = FPCast(lhs, b_->getFloatTy());
rhs = FPCast(rhs, b_->getFloatTy());
lhs = FPCast(lhs, b()->getFloatTy());
rhs = FPCast(rhs, b()->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "atan2f";
@ -55,7 +55,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
}
// Create a function declaration.
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
module_
module()
->getOrInsertFunction(function_name, lhs->getType(), lhs->getType(),
rhs->getType())
.getCallee());
@ -65,7 +65,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
// Create an instruction to call the function.
llvm::Value* result = Call(function, {lhs, rhs});
if (cast_result_to_fp16) {
result = FPCast(result, b_->getHalfTy());
result = FPCast(result, b()->getHalfTy());
}
return result;
}
@ -77,7 +77,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
switch (prim_type) {
case F16:
cast_result_to_fp16 = true;
value = FPCast(value, b_->getFloatTy());
value = FPCast(value, b()->getFloatTy());
TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
@ -90,7 +90,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
}
// Create a function declaration.
llvm::Function* function = llvm::dyn_cast<llvm::Function>(
module_
module()
->getOrInsertFunction(function_name, value->getType(),
value->getType())
.getCallee());
@ -100,26 +100,20 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
// Create an instruction to call the function.
llvm::Value* result = Call(function, value);
if (cast_result_to_fp16) {
result = FPCast(result, b_->getHalfTy());
result = FPCast(result, b()->getHalfTy());
}
return result;
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitConvolution(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kConvolution:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return ir_emitter_->EmitElementalConvolution(
Cast<HloConvolutionInstruction>(hlo),
operand_to_generator.at(hlo->operand(0)),
operand_to_generator.at(hlo->operand(1)), index);
};
default:
return ElementalIrEmitter::MakeElementGenerator(hlo,
operand_to_generator);
}
const HloToElementGeneratorMap& operand_to_generator,
const llvm_ir::IrArray::Index& index) {
return ir_emitter_->EmitElementalConvolution(
Cast<HloConvolutionInstruction>(hlo),
operand_to_generator.at(hlo->operand(0)),
operand_to_generator.at(hlo->operand(1)), index);
}
} // namespace cpu
} // namespace xla

View File

@ -31,18 +31,19 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
public:
CpuElementalIrEmitter(const HloModuleConfig& module_config,
IrEmitter* ir_emitter, llvm::Module* module)
: ElementalIrEmitter(module_config, module, ir_emitter->b()),
: ElementalIrEmitter(module, ir_emitter->b()),
hlo_module_config_(module_config),
ir_emitter_(ir_emitter) {}
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) override;
protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;
StatusOr<llvm::Value*> EmitConvolution(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
const llvm_ir::IrArray::Index& index) override;
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
@ -54,6 +55,7 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
}
const HloModuleConfig& hlo_module_config_;
IrEmitter* ir_emitter_;
};

View File

@ -2486,6 +2486,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
return EmitElementalReduce(reduce_instr, std::move(input_generators),
std::move(initial_value_generators), index);
};
case HloOpcode::kConvolution:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return EmitConvolution(hlo, operand_to_generator, index);
};
default:
return [hlo](const IrArray::Index& index) {
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
@ -2730,6 +2734,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
const HloInstruction* hlo,
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
const llvm_ir::IrArray::Index& index) {
return Unimplemented("Elemental convolution is not implemented");
}
// Evaluate polynomial using Horner's method.
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
@ -39,22 +38,14 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
using HloToElementGeneratorMap =
std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
llvm::Module* module, llvm::IRBuilder<>* b)
: b_(b), module_(module), hlo_module_config_(hlo_module_config) {}
ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b)
: b_(b), module_(module) {}
virtual ~ElementalIrEmitter() = default;
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
llvm::Value* operand_value);
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value);
// Returns a function to generate an element of the output of `hlo`, given a
// map of functions to generate elements of its operands.
virtual llvm_ir::ElementGenerator MakeElementGenerator(
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator);
@ -66,6 +57,21 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm::Module* module() { return module_; }
protected:
virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value);
virtual llvm::Value* EmitExtractReal(llvm::Value* value);
virtual llvm::Value* EmitExtractImag(llvm::Value* value);
private:
virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
llvm::Value* operand_value);
virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value);
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
llvm::Value* operand_value);
@ -92,10 +98,6 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm::Value* rhs_value,
bool is_signed);
virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value);
virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value);
@ -175,9 +177,6 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
PrimitiveType prim_type,
llvm::Value* operand_value);
virtual llvm::Value* EmitExtractReal(llvm::Value* value);
virtual llvm::Value* EmitExtractImag(llvm::Value* value);
// Composes a complex struct. imag may be nullptr for simple cast operations.
llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
llvm::Value* imag);
@ -245,17 +244,11 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
const llvm_ir::IrArray::Index& index);
virtual bool fast_min_max() = 0;
virtual StatusOr<llvm::Value*> EmitConvolution(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator,
const llvm_ir::IrArray::Index& index);
llvm::IRBuilder<>* const b_;
llvm::Module* module_;
// The HloModuleConfig which gathers all settings and values which affect the
// compiled executable outside of the HLO code itself.
const HloModuleConfig& hlo_module_config_;
private:
// Computes the complex power function, returns (a + i*b)^(c + i*d).
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
llvm::Value* a, llvm::Value* b,
@ -264,6 +257,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
// Evaluates a polynomial using Horner's method.
StatusOr<llvm::Value*> EvaluatePolynomial(
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
virtual bool fast_min_max() = 0;
llvm::IRBuilder<>* const b_;
llvm::Module* module_;
};
} // namespace xla

View File

@ -72,7 +72,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
GpuElementalIrEmitter::GpuElementalIrEmitter(
const HloModuleConfig& hlo_module_config, llvm::Module* module,
llvm::IRBuilder<>* b, NestedComputer compute_nested)
: ElementalIrEmitter(hlo_module_config, module, b),
: ElementalIrEmitter(module, b),
hlo_module_config_(hlo_module_config),
compute_nested_(std::move(compute_nested)) {}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
@ -91,7 +92,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
for (int64 i = 0; i < operands.size(); ++i) {
if (input_types[i] == F16) {
converted_operands[i] =
FPCast(converted_operands[i], b_->getFloatTy());
FPCast(converted_operands[i], b()->getFloatTy());
converted_input_types[i] = F32;
}
}
@ -106,12 +107,12 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
PrimitiveType_Name(output_type));
}
const string& munged_callee =
ObtainDeviceFunctionName(funcid, output_type, b_);
ObtainDeviceFunctionName(funcid, output_type, b());
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
converted_input_types, output_type)
.ValueOrDie();
if (cast_result_to_fp16) {
result = FPCast(result, b_->getHalfTy());
result = FPCast(result, b()->getHalfTy());
}
return result;
}
@ -153,7 +154,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
return EmitDeviceFunctionCall(
callee_name, operands, input_types, output_type,
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b_);
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
@ -168,7 +169,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
return llvm_ir::EmitCallToIntrinsic(
opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum
: llvm::Intrinsic::minnum,
{lhs_value, rhs_value}, {lhs_value->getType()}, b_);
{lhs_value, rhs_value}, {lhs_value->getType()}, b());
}
switch (op->opcode()) {
@ -275,19 +276,19 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
// This routine isn't numerically precise, but it's good enough for ML.
// Upcast F16 to F32 if necessary.
llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
llvm::Type* type = prim_type == F16 ? b()->getFloatTy() : value->getType();
llvm::Value* input = FPCast(value, type);
// If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0.
constexpr double kMaxValue = 20.0;
auto max_value = llvm::ConstantFP::get(type, kMaxValue);
llvm::Value* abs_value =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_);
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b());
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b(), input);
auto one = llvm::ConstantFP::get(type, 1.0);
auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
{one, input}, {type}, b_);
{one, input}, {type}, b());
return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign),
value->getType(), "tanh");
}
@ -301,14 +302,14 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexAbs(
llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
llvm::Value* block_id = IntCast(
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_),
b_->getIntNTy(128), /*isSigned=*/true, "block.id");
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b()),
b()->getIntNTy(128), /*isSigned=*/true, "block.id");
llvm::Value* thread_id_in_block = IntCast(
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_),
b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b()),
b()->getIntNTy(128), /*isSigned=*/true, "thread.id");
llvm::Value* threads_per_block = IntCast(
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b_),
b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b()),
b()->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}

View File

@ -126,6 +126,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
const string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
const HloModuleConfig& hlo_module_config_;
NestedComputer compute_nested_;
};