[XLA:CPU/GPU] Merge the emission of elemental kMap
There's not a lot of duplication here, but there is no need to have it twice.

PiperOrigin-RevId: 310910166
Change-Id: I6dfff87d56f4cc1788344300e826975cc38fe452
This commit is contained in:
parent
d32ec0bf0b
commit
d5c1743dde
|
@ -3846,6 +3846,7 @@ cc_library(
|
|||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:core",
|
||||
"@llvm-project//llvm:transform_utils",
|
||||
],
|
||||
|
|
|
@ -109,18 +109,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
|
|||
const HloInstruction* hlo,
|
||||
const HloToElementGeneratorMap& operand_to_generator) {
|
||||
switch (hlo->opcode()) {
|
||||
case HloOpcode::kMap:
|
||||
return [this, hlo, &operand_to_generator](
|
||||
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
std::vector<llvm::Value*> operands;
|
||||
for (int i = 0; i < hlo->operand_count(); i++) {
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
|
||||
operand_to_generator.at(hlo->operand(i))(index));
|
||||
operands.push_back(operand_value);
|
||||
}
|
||||
return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
|
||||
operands, llvm_ir::IrName(hlo));
|
||||
};
|
||||
case HloOpcode::kConvolution:
|
||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||
return ir_emitter_->EmitElementalConvolution(
|
||||
|
|
|
@ -695,13 +695,6 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Emits code that maps a single element for `map_instr`: the computation
// returned by `map_instr.to_apply()` is invoked on `elemental_operands` through
// EmitScalarReturningThreadLocalCall, and the value of that call is returned.
// `name` is forwarded unchanged to the call emitter.
llvm::Value* IrEmitter::EmitElementalMap(
    const HloMapInstruction& map_instr,
    absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
  const auto& mapped_computation = *map_instr.to_apply();
  return EmitScalarReturningThreadLocalCall(mapped_computation,
                                            elemental_operands, name);
}
|
||||
|
||||
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
|
||||
// Pseudo code for reduce window:
|
||||
//
|
||||
|
|
|
@ -115,11 +115,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
|||
// Emit an LLVM global variable for every constant buffer allocation.
|
||||
Status EmitConstantGlobals();
|
||||
|
||||
// Emit code to map one element according to `map_instr`.
|
||||
llvm::Value* EmitElementalMap(
|
||||
const HloMapInstruction& map_instr,
|
||||
absl::Span<llvm::Value* const> elemental_operands,
|
||||
absl::string_view name);
|
||||
// Emit code to emit the element at `index` for a convolution instruction.
|
||||
StatusOr<llvm::Value*> EmitElementalConvolution(
|
||||
const HloConvolutionInstruction* convolution,
|
||||
|
|
|
@ -2422,6 +2422,21 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||
-> StatusOr<llvm::Value*> {
|
||||
return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
|
||||
};
|
||||
case HloOpcode::kMap:
  // Element generator for kMap: evaluate every operand at `index`, then hand
  // the resulting scalars to EmitElementalMap, which applies the mapped
  // computation. `operand_to_generator` is captured by reference, so the
  // returned generator relies on the map staying alive while it is used.
  return [this, hlo, &operand_to_generator](
             const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    std::vector<llvm::Value*> operands;
    operands.reserve(hlo->operand_count());
    for (const HloInstruction* operand : hlo->operands()) {
      // Any failure while generating an operand element aborts the map.
      TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                          operand_to_generator.at(operand)(index));
      operands.push_back(operand_value);
    }
    return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
  };
|
||||
case HloOpcode::kReduceWindow:
|
||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||
return EmitElementalReduceWindow(
|
||||
|
@ -2473,6 +2488,17 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
|
|||
return complex;
|
||||
}
|
||||
|
||||
// Emits code that maps a single element for `map_instr`.
//
// The computation returned by `map_instr->to_apply()` is invoked via
// EmitThreadLocalCall with `elemental_operands` as its arguments, using the IR
// name derived from `map_instr`. An error from that call is propagated to the
// caller. The call must yield exactly one value (CHECK-enforced), which is
// returned.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
    const HloMapInstruction* map_instr,
    absl::Span<llvm::Value* const> elemental_operands) {
  const HloComputation& mapped_computation = *map_instr->to_apply();
  TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> results,
                      EmitThreadLocalCall(mapped_computation,
                                          elemental_operands,
                                          llvm_ir::IrName(map_instr)));
  CHECK_EQ(results.size(), 1);
  return results[0];
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
|
||||
const HloReduceWindowInstruction* reduce_window,
|
||||
const llvm_ir::ElementGenerator& input_generator,
|
||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
|
@ -228,6 +229,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
|||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||
absl::string_view name) = 0;
|
||||
|
||||
// Emits code to map one element according to `map_instr`: the computation
// that `map_instr` applies is called with `elemental_operands` as arguments
// and its single result value is returned. Errors from emitting the call are
// returned as a non-OK status.
StatusOr<llvm::Value*> EmitElementalMap(
|
||||
const HloMapInstruction* map_instr,
|
||||
absl::Span<llvm::Value* const> elemental_operands);
|
||||
|
||||
StatusOr<llvm::Value*> EmitElementalReduceWindow(
|
||||
const HloReduceWindowInstruction* reduce_window,
|
||||
const llvm_ir::ElementGenerator& input_generator,
|
||||
|
|
|
@ -305,29 +305,5 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
|
|||
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
|
||||
}
|
||||
|
||||
// Builds the element generator for `hlo` on the GPU backend.
//
// Only kMap is handled here; every other opcode is delegated to
// ElementalIrEmitter::MakeElementGenerator. For kMap the returned generator
// evaluates each operand at the requested index and passes the resulting
// values to `compute_nested_` together with the mapped computation. Maps with
// no operands are rejected with an error status.
//
// `operand_to_generator` is captured by reference by the returned generator.
llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const HloToElementGeneratorMap& operand_to_generator) {
  if (hlo->opcode() != HloOpcode::kMap) {
    return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
  }

  return [this, hlo, &operand_to_generator](
             const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    TF_RET_CHECK(!hlo->operands().empty())
        << "Zero operand map not implemented in GPU backend.";
    TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0);

    std::vector<llvm::Value*> operand_elements;
    for (HloInstruction* input : hlo->operands()) {
      TF_ASSIGN_OR_RETURN(llvm::Value * element,
                          operand_to_generator.at(input)(index));
      operand_elements.push_back(element);
    }
    return compute_nested_(*hlo->to_apply(), operand_elements);
  };
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
|
|
@ -47,10 +47,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
|||
llvm::Module* module, llvm::IRBuilder<>* b,
|
||||
NestedComputer compute_nested);
|
||||
|
||||
llvm_ir::ElementGenerator MakeElementGenerator(
|
||||
const HloInstruction* hlo,
|
||||
const HloToElementGeneratorMap& operand_to_generator) override;
|
||||
|
||||
protected:
|
||||
StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
|
||||
llvm::Value* lhs_value,
|
||||
|
|
Loading…
Reference in New Issue