[XLA:CPU/GPU] Merge the emission of elemental kMap

There's not a lot of duplication here, but no need to have it twice.

PiperOrigin-RevId: 310910166
Change-Id: I6dfff87d56f4cc1788344300e826975cc38fe452
This commit is contained in:
Benjamin Kramer 2020-05-11 07:47:39 -07:00 committed by TensorFlower Gardener
parent d32ec0bf0b
commit d5c1743dde
8 changed files with 32 additions and 52 deletions

View File

@ -3846,6 +3846,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:core",
"@llvm-project//llvm:transform_utils",
],

View File

@ -109,18 +109,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kMap:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
std::vector<llvm::Value*> operands;
for (int i = 0; i < hlo->operand_count(); i++) {
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(i))(index));
operands.push_back(operand_value);
}
return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
operands, llvm_ir::IrName(hlo));
};
case HloOpcode::kConvolution:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return ir_emitter_->EmitElementalConvolution(

View File

@ -695,13 +695,6 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
llvm::Value* IrEmitter::EmitElementalMap(
const HloMapInstruction& map_instr,
absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(),
elemental_operands, name);
}
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// Pseudo code for reduce window:
//

View File

@ -115,11 +115,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
// Emit code to map one element according to `map_instr`.
llvm::Value* EmitElementalMap(
const HloMapInstruction& map_instr,
absl::Span<llvm::Value* const> elemental_operands,
absl::string_view name);
// Emit code to emit the element at `index` for a convolution instruction.
StatusOr<llvm::Value*> EmitElementalConvolution(
const HloConvolutionInstruction* convolution,

View File

@ -2422,6 +2422,21 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
-> StatusOr<llvm::Value*> {
return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
};
case HloOpcode::kMap:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
std::vector<llvm::Value*> operands;
for (int i = 0; i < hlo->operand_count(); i++) {
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(i))(index));
operands.push_back(operand_value);
}
std::vector<llvm_ir::ElementGenerator> input_generators;
for (const HloInstruction* instr : hlo->operands()) {
input_generators.push_back(operand_to_generator.at(instr));
}
return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
};
case HloOpcode::kReduceWindow:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return EmitElementalReduceWindow(
@ -2473,6 +2488,17 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
return complex;
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
const HloMapInstruction* map_instr,
absl::Span<llvm::Value* const> elemental_operands) {
TF_ASSIGN_OR_RETURN(
std::vector<llvm::Value*> values,
EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
llvm_ir::IrName(map_instr)));
CHECK_EQ(values.size(), 1);
return values[0];
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
const HloReduceWindowInstruction* reduce_window,
const llvm_ir::ElementGenerator& input_generator,

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
@ -228,6 +229,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view name) = 0;
StatusOr<llvm::Value*> EmitElementalMap(
const HloMapInstruction* map_instr,
absl::Span<llvm::Value* const> elemental_operands);
StatusOr<llvm::Value*> EmitElementalReduceWindow(
const HloReduceWindowInstruction* reduce_window,
const llvm_ir::ElementGenerator& input_generator,

View File

@ -305,29 +305,5 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}
llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
case HloOpcode::kMap:
return [=, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
TF_RET_CHECK(!hlo->operands().empty())
<< "Zero operand map not implemented in GPU backend.";
TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0);
std::vector<llvm::Value*> operand_elements;
for (HloInstruction* operand : hlo->operands()) {
TF_ASSIGN_OR_RETURN(llvm::Value * value,
operand_to_generator.at(operand)(index));
operand_elements.push_back(value);
}
return compute_nested_(*hlo->to_apply(), operand_elements);
};
default:
return ElementalIrEmitter::MakeElementGenerator(hlo,
operand_to_generator);
}
}
} // namespace gpu
} // namespace xla

View File

@ -47,10 +47,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm::Module* module, llvm::IRBuilder<>* b,
NestedComputer compute_nested);
llvm_ir::ElementGenerator MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) override;
protected:
StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,