[XLA] Expand simple scatter operations into dynamic-update-slice.
This allows them to be fused. PiperOrigin-RevId: 327291810 Change-Id: I8e706a6add56e5e9fb4e9262e886f19ee11ac2df
This commit is contained in:
parent
e99f1d3efb
commit
ad7e6583cd
@ -1843,6 +1843,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_creation_utils",
|
||||
":hlo_pass",
|
||||
":op_expander_pass",
|
||||
":while_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -290,7 +290,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
/*expansion_type=*/LogisticExpansionType::kExp);
|
||||
pipeline.AddPass<ConditionalCanonicalizer>();
|
||||
pipeline.AddPass<DynamicPadder>();
|
||||
pipeline.AddPass<ScatterExpander>();
|
||||
pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
|
||||
pipeline.AddPass<ConvCanonicalization>(target_machine_features);
|
||||
{
|
||||
auto& pass =
|
||||
|
@ -201,6 +201,7 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
pass.AddPass<ZeroSizedHloElimination>();
|
||||
|
||||
pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
|
||||
pass.AddPass<ScatterExpander>(ScatterExpander::kEliminateSimpleScatters);
|
||||
|
||||
AlgebraicSimplifierOptions options;
|
||||
// When transposes appear in a fusion node, we can easily adjust the
|
||||
|
@ -23,26 +23,11 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<bool> GpuScatterExpander::Run(HloModule* module) {
|
||||
auto is_nontrivial_scatter = [](HloInstruction* inst) {
|
||||
// TODO(b/129698548): Scattering elements larger than 64 bits is not
|
||||
// supported by XLA:GPU.
|
||||
return inst->opcode() == HloOpcode::kScatter &&
|
||||
inst->shape().element_type() == C128;
|
||||
};
|
||||
|
||||
std::vector<HloInstruction*> scatter_instrs;
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
absl::c_copy_if(computation->instructions(),
|
||||
std::back_inserter(scatter_instrs), is_nontrivial_scatter);
|
||||
}
|
||||
|
||||
for (HloInstruction* inst : scatter_instrs) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(inst));
|
||||
TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root));
|
||||
}
|
||||
|
||||
return !scatter_instrs.empty();
|
||||
bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
|
||||
// TODO(b/129698548): Scattering elements larger than 64 bits is not
|
||||
// supported by XLA:GPU.
|
||||
return inst->opcode() == HloOpcode::kScatter &&
|
||||
primitive_util::BitWidth(inst->shape().element_type()) > 64;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -20,10 +20,17 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Legalizes scatters on the GPU.
|
||||
class GpuScatterExpander : public ScatterExpander {
|
||||
public:
|
||||
// Although we pass kEliminateAllScatters, we override this behavior in
|
||||
// InstruuctionMatchesPattern and select only some scatters to expand.
|
||||
GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {}
|
||||
|
||||
absl::string_view name() const override { return "gpu_scatter_expander"; }
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
protected:
|
||||
bool InstructionMatchesPattern(HloInstruction* inst) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -325,6 +325,22 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
|
||||
{updated_operand, scatter_indices, updates}};
|
||||
}
|
||||
|
||||
static int64 ScatterTripCount(HloInstruction* scatter) {
|
||||
// Compute the trip count for the while loop to be used for scatter. This
|
||||
// should be the number of indices we should scatter into the operand.
|
||||
HloInstruction* scatter_indices = scatter->mutable_operand(1);
|
||||
const Shape& scatter_indices_shape = scatter_indices->shape();
|
||||
const ScatterDimensionNumbers& dim_numbers =
|
||||
scatter->scatter_dimension_numbers();
|
||||
int64 scatter_loop_trip_count = 1;
|
||||
for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
|
||||
if (i != dim_numbers.index_vector_dim()) {
|
||||
scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
|
||||
}
|
||||
}
|
||||
return scatter_loop_trip_count;
|
||||
}
|
||||
|
||||
// High Level Algorithm.
|
||||
//
|
||||
// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
|
||||
@ -342,7 +358,7 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
|
||||
// from c. and d. using the update_computation of scatter.
|
||||
// f. Write the updated value of the slice into the operand tensor.
|
||||
|
||||
StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
|
||||
StatusOr<HloInstruction*> ScatterExpander::ExpandInstruction(
|
||||
HloInstruction* scatter) {
|
||||
HloInstruction* operand = scatter->mutable_operand(0);
|
||||
HloInstruction* scatter_indices = scatter->mutable_operand(1);
|
||||
@ -358,13 +374,7 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
|
||||
|
||||
// Compute the trip count for the while loop to be used for scatter. This
|
||||
// should be the number of indices we should scatter into the operand.
|
||||
const Shape& scatter_indices_shape = scatter_indices->shape();
|
||||
int64 scatter_loop_trip_count = 1;
|
||||
for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
|
||||
if (i != dim_numbers.index_vector_dim()) {
|
||||
scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
|
||||
}
|
||||
}
|
||||
int64 scatter_loop_trip_count = ScatterTripCount(scatter);
|
||||
if (!IsInt32(scatter_loop_trip_count)) {
|
||||
return Unimplemented(
|
||||
"Scatter operations with more than 2147483647 scatter indices are not "
|
||||
@ -408,23 +418,9 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
|
||||
return scatter_loop_result.front();
|
||||
}
|
||||
|
||||
StatusOr<bool> ScatterExpander::Run(HloModule* module) {
|
||||
std::vector<HloInstruction*> scatter_instrs;
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
for (HloInstruction* instr : computation->instructions()) {
|
||||
if (instr->opcode() == HloOpcode::kScatter) {
|
||||
scatter_instrs.push_back(instr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto instr : scatter_instrs) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr));
|
||||
TF_RETURN_IF_ERROR(
|
||||
instr->parent()->ReplaceInstruction(instr, expanded_root));
|
||||
}
|
||||
|
||||
return !scatter_instrs.empty();
|
||||
bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
|
||||
return inst->opcode() == HloOpcode::kScatter &&
|
||||
(mode_ == kEliminateAllScatters || ScatterTripCount(inst) == 1);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -16,17 +16,43 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
class ScatterExpander : public HloModulePass {
|
||||
// This pass rewrites scatter operations into (roughly) while loops of
|
||||
// dynamic-update-slices.
|
||||
//
|
||||
// This pass can be used in two ways:
|
||||
//
|
||||
// - kEliminateAllScatters: For backends that don't support scatter, this pass
|
||||
// can convert every scatter into a loop.
|
||||
//
|
||||
// - kEliminateSimpleScatters: For backends that *do* support scatter, this
|
||||
// pass can strength-reduce "simple" scatters -- specifically, scatters that
|
||||
// can be represented without a loop -- to dynamic-update-slices.
|
||||
//
|
||||
// Note that even in kEliminateSimpleScatters mode, this pass may still expand a
|
||||
// scatter into a loop (with a trip-count of 1). It's up to other
|
||||
// simplification passes to remove the loop.
|
||||
class ScatterExpander : public OpExpanderPass {
|
||||
public:
|
||||
enum Mode {
|
||||
kEliminateAllScatters,
|
||||
kEliminateSimpleScatters,
|
||||
};
|
||||
|
||||
explicit ScatterExpander(Mode m) : mode_(m) {}
|
||||
|
||||
absl::string_view name() const override { return "scatter_expander"; }
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
protected:
|
||||
StatusOr<HloInstruction*> ExpandScatter(HloInstruction* scatter);
|
||||
bool InstructionMatchesPattern(HloInstruction* inst) override;
|
||||
|
||||
StatusOr<HloInstruction*> ExpandInstruction(HloInstruction* scatter) override;
|
||||
|
||||
private:
|
||||
Mode mode_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -57,11 +57,79 @@ TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) {
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
|
||||
// The HLO parser changes all no layout shapes from the input to have a
|
||||
// default layout, clear the layout of the scatter operand for testing.
|
||||
// default layout. Clear the layout of the scatter operand for testing.
|
||||
HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
|
||||
scatter_operand->mutable_shape()->clear_layout();
|
||||
|
||||
ScatterExpander scatter_expander;
|
||||
ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool result,
|
||||
RunHloPass(&scatter_expander, module.get()));
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
TEST_F(ScatterExpanderTest, EliminateSimpleScattersSkipsNontrivialScatter) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule scatter_expander
|
||||
|
||||
scatter_computation {
|
||||
parameter0 = s32[] parameter(0)
|
||||
ROOT parameter1 = s32[] parameter(1)
|
||||
}
|
||||
|
||||
ENTRY kernel_entry {
|
||||
operand = s32[3,3] parameter(0)
|
||||
indices = s32[2] parameter(1)
|
||||
updates = s32[2,3] parameter(2)
|
||||
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
|
||||
to_apply=scatter_computation,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
|
||||
// The HLO parser changes all no layout shapes from the input to have a
|
||||
// default layout. Clear the layout of the scatter operand for testing.
|
||||
HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
|
||||
scatter_operand->mutable_shape()->clear_layout();
|
||||
|
||||
ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool result,
|
||||
RunHloPass(&scatter_expander, module.get()));
|
||||
EXPECT_FALSE(result);
|
||||
}
|
||||
|
||||
TEST_F(ScatterExpanderTest, EliminateSimpleScattersRewritesTrivialScatter) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule scatter_expander
|
||||
|
||||
scatter_computation {
|
||||
parameter0 = s32[] parameter(0)
|
||||
ROOT parameter1 = s32[] parameter(1)
|
||||
}
|
||||
|
||||
ENTRY kernel_entry {
|
||||
operand = s32[5] iota(), iota_dimension=0
|
||||
indices = s32[1] parameter(0)
|
||||
update = s32[] constant(0)
|
||||
ROOT scatter = s32[5]{0} scatter(operand, indices, update),
|
||||
update_window_dims={}, inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0}, index_vector_dim=0,
|
||||
to_apply=scatter_computation
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(kModuleStr));
|
||||
|
||||
// The HLO parser changes all no layout shapes from the input to have a
|
||||
// default layout. Clear the layout of the scatter operand for testing.
|
||||
HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
|
||||
scatter_operand->mutable_shape()->clear_layout();
|
||||
|
||||
ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool result,
|
||||
RunHloPass(&scatter_expander, module.get()));
|
||||
EXPECT_TRUE(result);
|
||||
|
Loading…
x
Reference in New Issue
Block a user