From 6e909825ed44655636170c739d43b6030c742201 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 7 Aug 2020 16:01:44 -0700 Subject: [PATCH] Expand simple gathers into dynamic-slice. Especially for (GPU) fusion, XLA prefers to call a slice a slice. PiperOrigin-RevId: 325526316 Change-Id: I12b98756eca017d520a9a40d03dc291a42b9eaa3 --- tensorflow/compiler/xla/service/BUILD | 1 + tensorflow/compiler/xla/service/cpu/BUILD | 1 + .../compiler/xla/service/cpu/cpu_compiler.cc | 2 + .../compiler/xla/service/gather_expander.cc | 31 +++++++--- .../compiler/xla/service/gather_expander.h | 27 +++++++- .../xla/service/gather_expander_test.cc | 62 ++++++++++++++++++- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../compiler/xla/service/gpu/gpu_compiler.cc | 3 + 8 files changed, 114 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 49431b19a69..bfcdf6fae34 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2259,6 +2259,7 @@ tf_cc_test( srcs = ["gather_expander_test.cc"], deps = [ ":gather_expander", + ":hlo_query", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_macros_header", diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6eaf43902fe..e0317574e59 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -183,6 +183,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:llvm_compiler", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:rng_expander", "//tensorflow/compiler/xla/service:sort_simplifier", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc 
b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 0826d7b8ce1..eb5d9e704f5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -77,6 +77,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -303,6 +304,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pass.AddPass(options); pass.AddPass(); pass.AddPass(); + pass.AddPass(GatherExpander::kEliminateSimpleGathers); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 1838f65e6ea..d38873a501d 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -269,6 +269,22 @@ static StatusOr PermuteBatchAndOffsetDims( return MakeTransposeHlo(accumulator, permutation); } +// Computes how many trips a loop implementing this gather op would take. 
+static int64 GatherLoopTripCount(HloInstruction* gather_instr) { + HloInstruction* start_indices = gather_instr->mutable_operand(1); + const Shape& start_indices_shape = start_indices->shape(); + const GatherDimensionNumbers& dim_numbers = + gather_instr->gather_dimension_numbers(); + + int64 trip_count = 1; + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + trip_count *= start_indices_shape.dimensions(i); + } + } + return trip_count; +} + // High Level Algorithm // // We follow the following steps in sequence: @@ -311,20 +327,13 @@ StatusOr GatherExpander::ExpandInstruction( HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); HloInstruction* start_indices = gather_instr->mutable_operand(1); - const Shape& start_indices_shape = start_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); const GatherDimensionNumbers& dim_numbers = gather_instr->gather_dimension_numbers(); - int64 gather_loop_trip_count = 1; - for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { - if (i != dim_numbers.index_vector_dim()) { - gather_loop_trip_count *= start_indices_shape.dimensions(i); - } - } - + int64 gather_loop_trip_count = GatherLoopTripCount(gather_instr); if (!IsInt32(gather_loop_trip_count)) { return Unimplemented( "Gather operations with more than 2147483647 gather indices are not " @@ -373,7 +382,11 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) { return inst->opcode() == HloOpcode::kGather && // Avoid expanding gather ops that produce zero sized tensors, // instead punt these to ZeroSizedHloElimination. - !ShapeUtil::IsZeroElementArray(inst->shape()); + !ShapeUtil::IsZeroElementArray(inst->shape()) && + // In kEliminateSimpleGathers mode, we only simplify instructions + // which can be represented without a loop -- i.e. 
we only simplify + // gathers which have a trip count of 1. + (mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 5625a37cb46..e665fcd713c 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -21,10 +21,30 @@ limitations under the License. namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic -// slices. This lets backends that don't support gather directly to -// nevertheless have a minimum level of support. +// slices. +// +// This pass can be used two ways: +// +// - kEliminateAllGathers: For backends that don't support gather, this pass +// can convert every gather to a loop. +// +// - kEliminateSimpleGathers: For backends that *do* support gather, this pass +// can strength-reduce "simple" gathers -- specifically, gathers that can be +// represented without a loop -- to dyanmic-slices. +// +// Note that even in kEliminateSimpleGathers mode, this pass may still expand a +// gather into a loop (with a trip-count of 1). It's up to other simplification +// passes to remove the loop. 
+// class GatherExpander : public OpExpanderPass { public: + enum Mode { + kEliminateAllGathers, + kEliminateSimpleGathers, + }; + + explicit GatherExpander(Mode m) : mode_(m) {} + absl::string_view name() const override { return "gather_expander"; } protected: @@ -32,6 +52,9 @@ class GatherExpander : public OpExpanderPass { StatusOr ExpandInstruction( HloInstruction* gather_inst) override; + + private: + Mode mode_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 706327091d9..4b0808e9aaf 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -42,7 +43,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - Status status = GatherExpander{}.Run(module.get()).status(); + Status status = GatherExpander{GatherExpander::kEliminateAllGathers} + .Run(module.get()) + .status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); ASSERT_THAT( @@ -68,7 +71,9 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get())); ASSERT_TRUE(changed); HloInstruction* while_instr = nullptr; @@ -129,7 +134,9 @@ ENTRY main { OpMetadata metadata; metadata.set_op_name("Gather"); module->entry_computation()->root_instruction()->set_metadata(metadata); - TF_ASSERT_OK_AND_ASSIGN(bool changed, 
GatherExpander{}.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get())); ASSERT_TRUE(changed); HloInstruction* while_instr = nullptr; @@ -147,5 +154,54 @@ ENTRY main { "after gather expansion"; EXPECT_EQ(while_instr->metadata().op_name(), "Gather"); } + +TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) { + const string hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1, 3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GatherExpander pass(GatherExpander::kEliminateSimpleGathers); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get())); + ASSERT_FALSE(changed); +} + +TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) { + const string hlo_text = R"( +HloModule test + +ENTRY main { + operand = s32[100] parameter(0) + indices = s32[1] parameter(1) + ROOT gather = s32[10] gather(operand, indices), + offset_dims={0}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=0, + slice_sizes={10} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GatherExpander pass(GatherExpander::kEliminateAllGathers); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get())); + ASSERT_TRUE(changed); + ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(), + {HloOpcode::kGather})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 8dfd73e9a6a..47af5756f87 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ 
-1177,6 +1177,7 @@ cc_library( "//tensorflow/compiler/xla/service:dynamic_padder", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 6d441903b25..225fa328f3d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" @@ -196,6 +197,8 @@ Status GpuCompiler::OptimizeHloModule( // elimination has to come after that pass. pass.AddPass(); + pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers); + AlgebraicSimplifierOptions options; // When transposes appear in a fusion node, we can easily adjust the // multi-dimensional index to create the one needed for the operand. This