From d5c1743ddef5ce653645dc85fd7437c044df9e7a Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 11 May 2020 07:47:39 -0700 Subject: [PATCH] [XLA:CPU/GPU] Merge the emission of elemental kMap There's not a lot of duplication here, but no need to have it twice. PiperOrigin-RevId: 310910166 Change-Id: I6dfff87d56f4cc1788344300e826975cc38fe452 --- tensorflow/compiler/xla/service/BUILD | 1 + .../xla/service/cpu/elemental_ir_emitter.cc | 12 --------- .../compiler/xla/service/cpu/ir_emitter.cc | 7 ----- .../compiler/xla/service/cpu/ir_emitter.h | 5 ---- .../xla/service/elemental_ir_emitter.cc | 26 +++++++++++++++++++ .../xla/service/elemental_ir_emitter.h | 5 ++++ .../xla/service/gpu/elemental_ir_emitter.cc | 24 ----------------- .../xla/service/gpu/elemental_ir_emitter.h | 4 --- 8 files changed, 32 insertions(+), 52 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 499c4e25828..3349528ebc2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3846,6 +3846,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:core", "@llvm-project//llvm:transform_utils", ], diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index ccd17bb791d..05364a4492b 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -109,18 +109,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { - case HloOpcode::kMap: - return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - std::vector operands; - for (int i = 0; i < 
hlo->operand_count(); i++) { - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(i))(index)); - operands.push_back(operand_value); - } - return ir_emitter_->EmitElementalMap(*Cast(hlo), - operands, llvm_ir::IrName(hlo)); - }; case HloOpcode::kConvolution: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { return ir_emitter_->EmitElementalConvolution( diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 78d859cb34a..2b715bfa17a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -695,13 +695,6 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -llvm::Value* IrEmitter::EmitElementalMap( - const HloMapInstruction& map_instr, - absl::Span elemental_operands, absl::string_view name) { - return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(), - elemental_operands, name); -} - Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // Pseudo code for reduce window: // diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index c5e05db40bd..24524c67b11 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -115,11 +115,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); - // Emit code to map one element according to `map_instr`. - llvm::Value* EmitElementalMap( - const HloMapInstruction& map_instr, - absl::Span elemental_operands, - absl::string_view name); // Emit code to emit the element at `index` for a convolution instruction. 
StatusOr EmitElementalConvolution( const HloConvolutionInstruction* convolution, diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index b4ea18a8a1e..8cb660de46c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -2422,6 +2422,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( -> StatusOr { return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; + case HloOpcode::kMap: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + std::vector operands; + for (int i = 0; i < hlo->operand_count(); i++) { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(i))(index)); + operands.push_back(operand_value); + } + return EmitElementalMap(Cast(hlo), operands); + }; case HloOpcode::kReduceWindow: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { return EmitElementalReduceWindow( @@ -2473,6 +2484,17 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, return complex; } +StatusOr ElementalIrEmitter::EmitElementalMap( + const HloMapInstruction* map_instr, + absl::Span elemental_operands) { + TF_ASSIGN_OR_RETURN( + std::vector values, + EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands, + llvm_ir::IrName(map_instr))); + CHECK_EQ(values.size(), 1); + return values[0]; +} + StatusOr ElementalIrEmitter::EmitElementalReduceWindow( const HloReduceWindowInstruction* reduce_window, const llvm_ir::ElementGenerator& input_generator, diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 270ec358f5e..06a9d7b194c 100644 ---
a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" @@ -228,6 +229,10 @@ class ElementalIrEmitter : public IrBuilderMixin { const HloComputation& callee, absl::Span parameters, absl::string_view name) = 0; + StatusOr EmitElementalMap( + const HloMapInstruction* map_instr, + absl::Span elemental_operands); + StatusOr EmitElementalReduceWindow( const HloReduceWindowInstruction* reduce_window, const llvm_ir::ElementGenerator& input_generator, diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 0a44cd8cc69..1be0b1b4e7b 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -305,29 +305,5 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() { return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } -llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) { - switch (hlo->opcode()) { - case HloOpcode::kMap: - return [=, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - TF_RET_CHECK(!hlo->operands().empty()) - << "Zero operand map not implemented in GPU backend."; - TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0); - std::vector operand_elements; - for (HloInstruction* operand : hlo->operands()) { - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(index)); - operand_elements.push_back(value); - } - return compute_nested_(*hlo->to_apply(), operand_elements); - }; - default: - return ElementalIrEmitter::MakeElementGenerator(hlo, - operand_to_generator); - } -} - 
} // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index c846a6fa939..3c4e9f7c1e6 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -47,10 +47,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm::Module* module, llvm::IRBuilder<>* b, NestedComputer compute_nested); - llvm_ir::ElementGenerator MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) override; - protected: StatusOr EmitFloatBinaryOp(const HloInstruction* op, llvm::Value* lhs_value,