diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index d51462ba073..dd16bd32dd1 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1843,6 +1843,7 @@ cc_library(
         ":hlo",
         ":hlo_creation_utils",
         ":hlo_pass",
+        ":op_expander_pass",
        ":while_util",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 45cb18c4de6..7b72d7ade54 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -290,7 +290,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
       /*expansion_type=*/LogisticExpansionType::kExp);
   pipeline.AddPass();
   pipeline.AddPass();
-  pipeline.AddPass<ScatterExpander>();
+  pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
   pipeline.AddPass(target_machine_features);
   {
     auto& pass =
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index b2caa2ddcf4..77fcf2c59f7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -201,6 +201,7 @@ Status GpuCompiler::OptimizeHloModule(
 
     pass.AddPass();
     pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
+    pass.AddPass<ScatterExpander>(ScatterExpander::kEliminateSimpleScatters);
 
     AlgebraicSimplifierOptions options;
     // When transposes appear in a fusion node, we can easily adjust the
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc
index 6287f1e3ca2..31f011fa734 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc
@@ -23,26 +23,11 @@ limitations under the License.
 
 namespace xla {
 
-StatusOr<bool> GpuScatterExpander::Run(HloModule* module) {
-  auto is_nontrivial_scatter = [](HloInstruction* inst) {
-    // TODO(b/129698548): Scattering elements larger than 64 bits is not
-    // supported by XLA:GPU.
-    return inst->opcode() == HloOpcode::kScatter &&
-           inst->shape().element_type() == C128;
-  };
-
-  std::vector<HloInstruction*> scatter_instrs;
-  for (HloComputation* computation : module->MakeNonfusionComputations()) {
-    absl::c_copy_if(computation->instructions(),
-                    std::back_inserter(scatter_instrs), is_nontrivial_scatter);
-  }
-
-  for (HloInstruction* inst : scatter_instrs) {
-    TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(inst));
-    TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root));
-  }
-
-  return !scatter_instrs.empty();
+bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
+  // TODO(b/129698548): Scattering elements larger than 64 bits is not
+  // supported by XLA:GPU.
+  return inst->opcode() == HloOpcode::kScatter &&
+         primitive_util::BitWidth(inst->shape().element_type()) > 64;
 }
 
 }  // namespace xla
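For reference, the new GPU predicate can be distilled into a standalone check. This is an illustrative sketch, not part of the patch, and `NeedsLoopExpansion` is a hypothetical name; the point is that the old test (`element_type() == C128`) matched only complex128, while the bit-width test also catches any other element type wider than 64 bits:

    #include "tensorflow/compiler/xla/primitive_util.h"
    #include "tensorflow/compiler/xla/xla_data.pb.h"

    namespace xla {
    // True for element types XLA:GPU cannot scatter natively (b/129698548).
    // BitWidth(C128) == 128 -> expand; BitWidth(F64) == 64 -> native scatter.
    bool NeedsLoopExpansion(PrimitiveType type) {
      return primitive_util::BitWidth(type) > 64;
    }
    }  // namespace xla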
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h
index 0818b32474f..92acb909729 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h
@@ -20,10 +20,17 @@
 
 namespace xla {
 
+// Legalizes scatters on the GPU.
 class GpuScatterExpander : public ScatterExpander {
  public:
+  // Although we pass kEliminateAllScatters, we override this behavior in
+  // InstructionMatchesPattern and select only some scatters to expand.
+  GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {}
+
   absl::string_view name() const override { return "gpu_scatter_expander"; }
-  StatusOr<bool> Run(HloModule* module) override;
+
+ protected:
+  bool InstructionMatchesPattern(HloInstruction* inst) override;
 };
 
 }  // namespace xla
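The two Run overrides deleted in this patch shared the same driver loop; OpExpanderPass now owns it, and subclasses only supply the predicate and the rewrite. Roughly, the base class behaves as below. This is a paraphrase inferred from the deleted code, not the verbatim implementation:

    StatusOr<bool> OpExpanderPass::Run(HloModule* module) {
      // Collect every instruction the subclass wants to rewrite...
      std::vector<HloInstruction*> matching;
      for (HloComputation* computation : module->MakeNonfusionComputations()) {
        absl::c_copy_if(
            computation->instructions(), std::back_inserter(matching),
            [&](HloInstruction* inst) { return InstructionMatchesPattern(inst); });
      }
      // ...then splice in whatever ExpandInstruction produces for each one.
      for (HloInstruction* inst : matching) {
        TF_ASSIGN_OR_RETURN(HloInstruction * expanded, ExpandInstruction(inst));
        TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded));
      }
      return !matching.empty();
    }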
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index e3a3feb8640..bd99f920ea0 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -325,6 +325,22 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
       {updated_operand, scatter_indices, updates}};
 }
 
+static int64 ScatterTripCount(HloInstruction* scatter) {
+  // Compute the trip count for the while loop to be used for scatter. This
+  // should be the number of indices we should scatter into the operand.
+  HloInstruction* scatter_indices = scatter->mutable_operand(1);
+  const Shape& scatter_indices_shape = scatter_indices->shape();
+  const ScatterDimensionNumbers& dim_numbers =
+      scatter->scatter_dimension_numbers();
+  int64 scatter_loop_trip_count = 1;
+  for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
+    if (i != dim_numbers.index_vector_dim()) {
+      scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
+    }
+  }
+  return scatter_loop_trip_count;
+}
+
 // High Level Algorithm.
 //
 // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
@@ -342,7 +358,7 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
 //         from c. and d. using the update_computation of scatter.
 //      f. Write the updated value of the slice into the operand tensor.
 
-StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
+StatusOr<HloInstruction*> ScatterExpander::ExpandInstruction(
     HloInstruction* scatter) {
   HloInstruction* operand = scatter->mutable_operand(0);
   HloInstruction* scatter_indices = scatter->mutable_operand(1);
@@ -358,13 +374,7 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
 
   // Compute the trip count for the while loop to be used for scatter. This
   // should be the number of indices we should scatter into the operand.
-  const Shape& scatter_indices_shape = scatter_indices->shape();
-  int64 scatter_loop_trip_count = 1;
-  for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
-    if (i != dim_numbers.index_vector_dim()) {
-      scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
-    }
-  }
+  int64 scatter_loop_trip_count = ScatterTripCount(scatter);
   if (!IsInt32(scatter_loop_trip_count)) {
     return Unimplemented(
         "Scatter operations with more than 2147483647 scatter indices are not "
@@ -408,23 +418,9 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
   return scatter_loop_result.front();
 }
 
-StatusOr<bool> ScatterExpander::Run(HloModule* module) {
-  std::vector<HloInstruction*> scatter_instrs;
-  for (HloComputation* computation : module->MakeNonfusionComputations()) {
-    for (HloInstruction* instr : computation->instructions()) {
-      if (instr->opcode() == HloOpcode::kScatter) {
-        scatter_instrs.push_back(instr);
-      }
-    }
-  }
-
-  for (auto instr : scatter_instrs) {
-    TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr));
-    TF_RETURN_IF_ERROR(
-        instr->parent()->ReplaceInstruction(instr, expanded_root));
-  }
-
-  return !scatter_instrs.empty();
+bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
+  return inst->opcode() == HloOpcode::kScatter &&
+         (mode_ == kEliminateAllScatters || ScatterTripCount(inst) == 1);
 }
 
 }  // namespace xla
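To make the trip-count criterion in InstructionMatchesPattern concrete: the trip count is the product of the scatter_indices dimensions, skipping index_vector_dim. A self-contained toy version (a hypothetical helper using plain C++ types in place of XLA shapes):

    #include <cstdint>
    #include <vector>

    // Product of all scatter_indices dimensions except index_vector_dim.
    int64_t ToyScatterTripCount(const std::vector<int64_t>& index_dims,
                                int64_t index_vector_dim) {
      int64_t trip_count = 1;
      for (int64_t i = 0; i < static_cast<int64_t>(index_dims.size()); ++i) {
        if (i != index_vector_dim) trip_count *= index_dims[i];
      }
      return trip_count;
    }

    // ToyScatterTripCount({3, 4, 5}, /*index_vector_dim=*/2) == 12: needs a loop.
    // ToyScatterTripCount({1}, /*index_vector_dim=*/0) == 1: a "simple" scatter,
    // which is expanded even in kEliminateSimpleScatters mode.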
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 533af060bc9..aa59e7ec3b0 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -16,17 +16,43 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
 
-#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/op_expander_pass.h"
 
 namespace xla {
 
-class ScatterExpander : public HloModulePass {
+// This pass rewrites scatter operations into (roughly) while loops of
+// dynamic-update-slices.
+//
+// This pass can be used in two ways:
+//
+//  - kEliminateAllScatters: For backends that don't support scatter, this
+//    pass can convert every scatter into a loop.
+//
+//  - kEliminateSimpleScatters: For backends that *do* support scatter, this
+//    pass can strength-reduce "simple" scatters -- specifically, scatters
+//    that can be represented without a loop -- to dynamic-update-slices.
+//
+// Note that even in kEliminateSimpleScatters mode, this pass may still
+// expand a scatter into a loop (with a trip-count of 1). It's up to other
+// simplification passes to remove the loop.
+class ScatterExpander : public OpExpanderPass {
  public:
+  enum Mode {
+    kEliminateAllScatters,
+    kEliminateSimpleScatters,
+  };
+
+  explicit ScatterExpander(Mode m) : mode_(m) {}
+
   absl::string_view name() const override { return "scatter_expander"; }
-  StatusOr<bool> Run(HloModule* module) override;
 
  protected:
-  StatusOr<HloInstruction*> ExpandScatter(HloInstruction* scatter);
+  bool InstructionMatchesPattern(HloInstruction* inst) override;
+
+  StatusOr<HloInstruction*> ExpandInstruction(HloInstruction* scatter) override;
+
+ private:
+  Mode mode_;
 };
 
 }  // namespace xla
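Usage-wise, the mode is now fixed at construction time, mirroring the compiler changes above. A sketch only, assuming an HloPassPipeline named `pipeline` as in cpu_compiler.cc and gpu_compiler.cc:

    // CPU has no native scatter: lower every scatter to a loop.
    pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);

    // GPU scatters natively: only rewrite scatters whose loop would be trivial,
    // leaving a trip-count-1 while loop for later simplification passes.
    pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateSimpleScatters);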
diff --git a/tensorflow/compiler/xla/service/scatter_expander_test.cc b/tensorflow/compiler/xla/service/scatter_expander_test.cc
index 3852b82c1ef..9f4cc5406d8 100644
--- a/tensorflow/compiler/xla/service/scatter_expander_test.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander_test.cc
@@ -57,11 +57,79 @@ TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) {
                           ParseAndReturnVerifiedModule(kModuleStr));
 
   // The HLO parser changes all no layout shapes from the input to have a
-  // default layout, clear the layout of the scatter operand for testing.
+  // default layout. Clear the layout of the scatter operand for testing.
   HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
   scatter_operand->mutable_shape()->clear_layout();
 
-  ScatterExpander scatter_expander;
+  ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters);
   TF_ASSERT_OK_AND_ASSIGN(bool result,
                           RunHloPass(&scatter_expander, module.get()));
   EXPECT_TRUE(result);
+}
+
+TEST_F(ScatterExpanderTest, EliminateSimpleScattersSkipsNontrivialScatter) {
+  const char* kModuleStr = R"(
+    HloModule scatter_expander
+
+    scatter_computation {
+      parameter0 = s32[] parameter(0)
+      ROOT parameter1 = s32[] parameter(1)
+    }
+
+    ENTRY kernel_entry {
+      operand = s32[3,3] parameter(0)
+      indices = s32[2] parameter(1)
+      updates = s32[2,3] parameter(2)
+      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+          to_apply=scatter_computation,
+          update_window_dims={1},
+          inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0},
+          index_vector_dim=1
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+
+  // The HLO parser changes all no layout shapes from the input to have a
+  // default layout. Clear the layout of the scatter operand for testing.
+  HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
+  scatter_operand->mutable_shape()->clear_layout();
+
+  ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters);
+  TF_ASSERT_OK_AND_ASSIGN(bool result,
+                          RunHloPass(&scatter_expander, module.get()));
+  EXPECT_FALSE(result);
+}
+
+TEST_F(ScatterExpanderTest, EliminateSimpleScattersRewritesTrivialScatter) {
+  const char* kModuleStr = R"(
+    HloModule scatter_expander
+
+    scatter_computation {
+      parameter0 = s32[] parameter(0)
+      ROOT parameter1 = s32[] parameter(1)
+    }
+
+    ENTRY kernel_entry {
+      operand = s32[5] iota(), iota_dimension=0
+      indices = s32[1] parameter(0)
+      update = s32[] constant(0)
+      ROOT scatter = s32[5]{0} scatter(operand, indices, update),
+          update_window_dims={}, inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0}, index_vector_dim=0,
+          to_apply=scatter_computation
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr));
+
+  // The HLO parser changes all no layout shapes from the input to have a
+  // default layout. Clear the layout of the scatter operand for testing.
+  HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
+  scatter_operand->mutable_shape()->clear_layout();
+
+  ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters);
+  TF_ASSERT_OK_AND_ASSIGN(bool result,
+                          RunHloPass(&scatter_expander, module.get()));
+  EXPECT_TRUE(result);
+}