[XLA:CPU/GPU] Merge the emission of elemental kMap
There's not a lot of duplication here, but no need to have it twice. PiperOrigin-RevId: 310910166 Change-Id: I6dfff87d56f4cc1788344300e826975cc38fe452
This commit is contained in:
parent
d32ec0bf0b
commit
d5c1743dde
|
@ -3846,6 +3846,7 @@ cc_library(
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm-project//llvm:core",
|
"@llvm-project//llvm:core",
|
||||||
"@llvm-project//llvm:transform_utils",
|
"@llvm-project//llvm:transform_utils",
|
||||||
],
|
],
|
||||||
|
|
|
@ -109,18 +109,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
|
||||||
const HloInstruction* hlo,
|
const HloInstruction* hlo,
|
||||||
const HloToElementGeneratorMap& operand_to_generator) {
|
const HloToElementGeneratorMap& operand_to_generator) {
|
||||||
switch (hlo->opcode()) {
|
switch (hlo->opcode()) {
|
||||||
case HloOpcode::kMap:
|
|
||||||
return [this, hlo, &operand_to_generator](
|
|
||||||
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
|
||||||
std::vector<llvm::Value*> operands;
|
|
||||||
for (int i = 0; i < hlo->operand_count(); i++) {
|
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
|
|
||||||
operand_to_generator.at(hlo->operand(i))(index));
|
|
||||||
operands.push_back(operand_value);
|
|
||||||
}
|
|
||||||
return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
|
|
||||||
operands, llvm_ir::IrName(hlo));
|
|
||||||
};
|
|
||||||
case HloOpcode::kConvolution:
|
case HloOpcode::kConvolution:
|
||||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||||
return ir_emitter_->EmitElementalConvolution(
|
return ir_emitter_->EmitElementalConvolution(
|
||||||
|
|
|
@ -695,13 +695,6 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Value* IrEmitter::EmitElementalMap(
|
|
||||||
const HloMapInstruction& map_instr,
|
|
||||||
absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
|
|
||||||
return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(),
|
|
||||||
elemental_operands, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
|
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
|
||||||
// Pseudo code for reduce window:
|
// Pseudo code for reduce window:
|
||||||
//
|
//
|
||||||
|
|
|
@ -115,11 +115,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||||
// Emit an LLVM global variable for every constant buffer allocation.
|
// Emit an LLVM global variable for every constant buffer allocation.
|
||||||
Status EmitConstantGlobals();
|
Status EmitConstantGlobals();
|
||||||
|
|
||||||
// Emit code to map one element according to `map_instr`.
|
|
||||||
llvm::Value* EmitElementalMap(
|
|
||||||
const HloMapInstruction& map_instr,
|
|
||||||
absl::Span<llvm::Value* const> elemental_operands,
|
|
||||||
absl::string_view name);
|
|
||||||
// Emit code to emit the element at `index` for a convolution instruction.
|
// Emit code to emit the element at `index` for a convolution instruction.
|
||||||
StatusOr<llvm::Value*> EmitElementalConvolution(
|
StatusOr<llvm::Value*> EmitElementalConvolution(
|
||||||
const HloConvolutionInstruction* convolution,
|
const HloConvolutionInstruction* convolution,
|
||||||
|
|
|
@ -2422,6 +2422,21 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||||
-> StatusOr<llvm::Value*> {
|
-> StatusOr<llvm::Value*> {
|
||||||
return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
|
return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
|
||||||
};
|
};
|
||||||
|
case HloOpcode::kMap:
|
||||||
|
return [this, hlo, &operand_to_generator](
|
||||||
|
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||||
|
std::vector<llvm::Value*> operands;
|
||||||
|
for (int i = 0; i < hlo->operand_count(); i++) {
|
||||||
|
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
|
||||||
|
operand_to_generator.at(hlo->operand(i))(index));
|
||||||
|
operands.push_back(operand_value);
|
||||||
|
}
|
||||||
|
std::vector<llvm_ir::ElementGenerator> input_generators;
|
||||||
|
for (const HloInstruction* instr : hlo->operands()) {
|
||||||
|
input_generators.push_back(operand_to_generator.at(instr));
|
||||||
|
}
|
||||||
|
return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
|
||||||
|
};
|
||||||
case HloOpcode::kReduceWindow:
|
case HloOpcode::kReduceWindow:
|
||||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||||
return EmitElementalReduceWindow(
|
return EmitElementalReduceWindow(
|
||||||
|
@ -2473,6 +2488,17 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
|
||||||
return complex;
|
return complex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
|
||||||
|
const HloMapInstruction* map_instr,
|
||||||
|
absl::Span<llvm::Value* const> elemental_operands) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::vector<llvm::Value*> values,
|
||||||
|
EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
|
||||||
|
llvm_ir::IrName(map_instr)));
|
||||||
|
CHECK_EQ(values.size(), 1);
|
||||||
|
return values[0];
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
|
||||||
const HloReduceWindowInstruction* reduce_window,
|
const HloReduceWindowInstruction* reduce_window,
|
||||||
const llvm_ir::ElementGenerator& input_generator,
|
const llvm_ir::ElementGenerator& input_generator,
|
||||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/types/span.h"
|
||||||
#include "llvm/IR/IRBuilder.h"
|
#include "llvm/IR/IRBuilder.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "llvm/IR/Value.h"
|
#include "llvm/IR/Value.h"
|
||||||
|
@ -228,6 +229,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||||
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
|
||||||
absl::string_view name) = 0;
|
absl::string_view name) = 0;
|
||||||
|
|
||||||
|
StatusOr<llvm::Value*> EmitElementalMap(
|
||||||
|
const HloMapInstruction* map_instr,
|
||||||
|
absl::Span<llvm::Value* const> elemental_operands);
|
||||||
|
|
||||||
StatusOr<llvm::Value*> EmitElementalReduceWindow(
|
StatusOr<llvm::Value*> EmitElementalReduceWindow(
|
||||||
const HloReduceWindowInstruction* reduce_window,
|
const HloReduceWindowInstruction* reduce_window,
|
||||||
const llvm_ir::ElementGenerator& input_generator,
|
const llvm_ir::ElementGenerator& input_generator,
|
||||||
|
|
|
@ -305,29 +305,5 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
|
||||||
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
|
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
|
|
||||||
const HloInstruction* hlo,
|
|
||||||
const HloToElementGeneratorMap& operand_to_generator) {
|
|
||||||
switch (hlo->opcode()) {
|
|
||||||
case HloOpcode::kMap:
|
|
||||||
return [=, &operand_to_generator](
|
|
||||||
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
|
||||||
TF_RET_CHECK(!hlo->operands().empty())
|
|
||||||
<< "Zero operand map not implemented in GPU backend.";
|
|
||||||
TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0);
|
|
||||||
std::vector<llvm::Value*> operand_elements;
|
|
||||||
for (HloInstruction* operand : hlo->operands()) {
|
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value * value,
|
|
||||||
operand_to_generator.at(operand)(index));
|
|
||||||
operand_elements.push_back(value);
|
|
||||||
}
|
|
||||||
return compute_nested_(*hlo->to_apply(), operand_elements);
|
|
||||||
};
|
|
||||||
default:
|
|
||||||
return ElementalIrEmitter::MakeElementGenerator(hlo,
|
|
||||||
operand_to_generator);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
@ -47,10 +47,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
||||||
llvm::Module* module, llvm::IRBuilder<>* b,
|
llvm::Module* module, llvm::IRBuilder<>* b,
|
||||||
NestedComputer compute_nested);
|
NestedComputer compute_nested);
|
||||||
|
|
||||||
llvm_ir::ElementGenerator MakeElementGenerator(
|
|
||||||
const HloInstruction* hlo,
|
|
||||||
const HloToElementGeneratorMap& operand_to_generator) override;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
|
StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
|
||||||
llvm::Value* lhs_value,
|
llvm::Value* lhs_value,
|
||||||
|
|
Loading…
Reference in New Issue