Expand simple gathers into dynamic-slice.
Especially for (GPU) fusion, XLA prefers to call a slice a slice. PiperOrigin-RevId: 325526316 Change-Id: I12b98756eca017d520a9a40d03dc291a42b9eaa3
This commit is contained in:
		
							parent
							
								
									0d2f665129
								
							
						
					
					
						commit
						6e909825ed
					
				@ -2259,6 +2259,7 @@ tf_cc_test(
 | 
			
		||||
    srcs = ["gather_expander_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":gather_expander",
 | 
			
		||||
        ":hlo_query",
 | 
			
		||||
        "//tensorflow/compiler/xla:test",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:hlo_test_base",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:test_macros_header",
 | 
			
		||||
 | 
			
		||||
@ -183,6 +183,7 @@ cc_library(
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_verifier",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:indexed_array_analysis",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:llvm_compiler",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:gather_expander",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:reshape_mover",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:rng_expander",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:sort_simplifier",
 | 
			
		||||
 | 
			
		||||
@ -77,6 +77,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gather_expander.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
 | 
			
		||||
@ -303,6 +304,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
 | 
			
		||||
    pass.AddPass<AlgebraicSimplifier>(options);
 | 
			
		||||
    pass.AddPass<SortSimplifier>();
 | 
			
		||||
    pass.AddPass<HloDCE>();
 | 
			
		||||
    pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
 | 
			
		||||
 | 
			
		||||
    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
 | 
			
		||||
    // elimination has to come after that pass.
 | 
			
		||||
 | 
			
		||||
@ -269,6 +269,22 @@ static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
 | 
			
		||||
  return MakeTransposeHlo(accumulator, permutation);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Computes how many trips a loop implementing this gather op would take.
 | 
			
		||||
static int64 GatherLoopTripCount(HloInstruction* gather_instr) {
 | 
			
		||||
  HloInstruction* start_indices = gather_instr->mutable_operand(1);
 | 
			
		||||
  const Shape& start_indices_shape = start_indices->shape();
 | 
			
		||||
  const GatherDimensionNumbers& dim_numbers =
 | 
			
		||||
      gather_instr->gather_dimension_numbers();
 | 
			
		||||
 | 
			
		||||
  int64 trip_count = 1;
 | 
			
		||||
  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
 | 
			
		||||
    if (i != dim_numbers.index_vector_dim()) {
 | 
			
		||||
      trip_count *= start_indices_shape.dimensions(i);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return trip_count;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// High Level Algorithm
 | 
			
		||||
//
 | 
			
		||||
// We follow the following steps in sequence:
 | 
			
		||||
@ -311,20 +327,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandInstruction(
 | 
			
		||||
  HloComputation* computation = gather_instr->parent();
 | 
			
		||||
  HloInstruction* operand = gather_instr->mutable_operand(0);
 | 
			
		||||
  HloInstruction* start_indices = gather_instr->mutable_operand(1);
 | 
			
		||||
  const Shape& start_indices_shape = start_indices->shape();
 | 
			
		||||
  const Shape& output_shape = gather_instr->shape();
 | 
			
		||||
  int64 output_rank = output_shape.dimensions_size();
 | 
			
		||||
 | 
			
		||||
  const GatherDimensionNumbers& dim_numbers =
 | 
			
		||||
      gather_instr->gather_dimension_numbers();
 | 
			
		||||
 | 
			
		||||
  int64 gather_loop_trip_count = 1;
 | 
			
		||||
  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
 | 
			
		||||
    if (i != dim_numbers.index_vector_dim()) {
 | 
			
		||||
      gather_loop_trip_count *= start_indices_shape.dimensions(i);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int64 gather_loop_trip_count = GatherLoopTripCount(gather_instr);
 | 
			
		||||
  if (!IsInt32(gather_loop_trip_count)) {
 | 
			
		||||
    return Unimplemented(
 | 
			
		||||
        "Gather operations with more than 2147483647 gather indices are not "
 | 
			
		||||
@ -373,7 +382,11 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) {
 | 
			
		||||
  return inst->opcode() == HloOpcode::kGather &&
 | 
			
		||||
         // Avoid expanding gather ops that produce zero sized tensors,
 | 
			
		||||
         // instead punt these to ZeroSizedHloElimination.
 | 
			
		||||
         !ShapeUtil::IsZeroElementArray(inst->shape());
 | 
			
		||||
         !ShapeUtil::IsZeroElementArray(inst->shape()) &&
 | 
			
		||||
         // In kEliminateSimpleGathers mode, we only simplify instructions
 | 
			
		||||
         // which can be represented without a loop -- i.e. we only simplify
 | 
			
		||||
         // gathers which have a trip count of 1.
 | 
			
		||||
         (mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
@ -21,10 +21,30 @@ limitations under the License.
 | 
			
		||||
namespace xla {
 | 
			
		||||
 | 
			
		||||
// This pass rewrites gather operations into (roughly) while loops of dynamic
 | 
			
		||||
// slices.  This lets backends that don't support gather directly to
 | 
			
		||||
// nevertheless have a minimum level of support.
 | 
			
		||||
// slices.
 | 
			
		||||
//
 | 
			
		||||
// This pass can be used two ways:
 | 
			
		||||
//
 | 
			
		||||
//  - kEliminateAllGathers: For backends that don't support gather, this pass
 | 
			
		||||
//    can convert every gather to a loop.
 | 
			
		||||
//
 | 
			
		||||
//  - kEliminateSimpleGathers: For backends that *do* support gather, this pass
 | 
			
		||||
//    can strength-reduce "simple" gathers -- specifically, gathers that can be
 | 
			
		||||
//    represented without a loop -- to dynamic-slices.
 | 
			
		||||
//
 | 
			
		||||
// Note that even in kEliminateSimpleGathers mode, this pass may still expand a
 | 
			
		||||
// gather into a loop (with a trip-count of 1).  It's up to other simplification
 | 
			
		||||
// passes to remove the loop.
 | 
			
		||||
//
 | 
			
		||||
class GatherExpander : public OpExpanderPass {
 | 
			
		||||
 public:
 | 
			
		||||
  enum Mode {
 | 
			
		||||
    kEliminateAllGathers,
 | 
			
		||||
    kEliminateSimpleGathers,
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  explicit GatherExpander(Mode m) : mode_(m) {}
 | 
			
		||||
 | 
			
		||||
  absl::string_view name() const override { return "gather_expander"; }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
@ -32,6 +52,9 @@ class GatherExpander : public OpExpanderPass {
 | 
			
		||||
 | 
			
		||||
  StatusOr<HloInstruction*> ExpandInstruction(
 | 
			
		||||
      HloInstruction* gather_inst) override;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  Mode mode_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@ limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gather_expander.h"
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_query.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/test.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
 | 
			
		||||
@ -42,7 +43,9 @@ ENTRY main {
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(hlo_text));
 | 
			
		||||
 | 
			
		||||
  Status status = GatherExpander{}.Run(module.get()).status();
 | 
			
		||||
  Status status = GatherExpander{GatherExpander::kEliminateAllGathers}
 | 
			
		||||
                      .Run(module.get())
 | 
			
		||||
                      .status();
 | 
			
		||||
  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(
 | 
			
		||||
@ -68,7 +71,9 @@ ENTRY main {
 | 
			
		||||
)";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(hlo_text));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      bool changed,
 | 
			
		||||
      GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(changed);
 | 
			
		||||
 | 
			
		||||
  HloInstruction* while_instr = nullptr;
 | 
			
		||||
@ -129,7 +134,9 @@ ENTRY main {
 | 
			
		||||
  OpMetadata metadata;
 | 
			
		||||
  metadata.set_op_name("Gather");
 | 
			
		||||
  module->entry_computation()->root_instruction()->set_metadata(metadata);
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
      bool changed,
 | 
			
		||||
      GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(changed);
 | 
			
		||||
 | 
			
		||||
  HloInstruction* while_instr = nullptr;
 | 
			
		||||
@ -147,5 +154,54 @@ ENTRY main {
 | 
			
		||||
         "after gather expansion";
 | 
			
		||||
  EXPECT_EQ(while_instr->metadata().op_name(), "Gather");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) {
  // This gather has two start indices, so a loop implementing it would need a
  // trip count of 2.  In kEliminateSimpleGathers mode the pass must therefore
  // leave the instruction untouched (changed == false).
  const string hlo_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 3}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GatherExpander pass(GatherExpander::kEliminateSimpleGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  ASSERT_FALSE(changed);
}
 | 
			
		||||
 | 
			
		||||
TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) {
  // A gather with a single start index has a loop trip count of 1, so even
  // kEliminateAllGathers-style expansion should leave no kGather behind once
  // the pass runs.
  const string hlo_text = R"(
HloModule test

ENTRY main {
  operand = s32[100] parameter(0)
  indices = s32[1] parameter(1)
  ROOT gather = s32[10] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={},
      start_index_map={0},
      index_vector_dim=0,
      slice_sizes={10}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GatherExpander pass(GatherExpander::kEliminateAllGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  ASSERT_TRUE(changed);
  // The gather itself must be gone from the entry computation.
  ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(),
                                                  {HloOpcode::kGather}));
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
@ -1177,6 +1177,7 @@ cc_library(
 | 
			
		||||
        "//tensorflow/compiler/xla/service:dynamic_padder",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:executable",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:flatten_call_graph",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:gather_expander",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_constant_folding",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_cse",
 | 
			
		||||
 | 
			
		||||
@ -43,6 +43,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gather_expander.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
 | 
			
		||||
@ -196,6 +197,8 @@ Status GpuCompiler::OptimizeHloModule(
 | 
			
		||||
      // elimination has to come after that pass.
 | 
			
		||||
      pass.AddPass<ZeroSizedHloElimination>();
 | 
			
		||||
 | 
			
		||||
      pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
 | 
			
		||||
 | 
			
		||||
      AlgebraicSimplifierOptions options;
 | 
			
		||||
      // When transposes appear in a fusion node, we can easily adjust the
 | 
			
		||||
      // multi-dimensional index to create the one needed for the operand. This
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user