Expand simple gathers into dynamic-slice.

Especially for (GPU) fusion, XLA prefers to call a slice a slice.

PiperOrigin-RevId: 325526316
Change-Id: I12b98756eca017d520a9a40d03dc291a42b9eaa3
This commit is contained in:
Justin Lebar 2020-08-07 16:01:44 -07:00 committed by TensorFlower Gardener
parent 0d2f665129
commit 6e909825ed
8 changed files with 114 additions and 14 deletions

View File

@ -2259,6 +2259,7 @@ tf_cc_test(
srcs = ["gather_expander_test.cc"],
deps = [
":gather_expander",
":hlo_query",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_macros_header",

View File

@ -183,6 +183,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:gather_expander",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:rng_expander",
"//tensorflow/compiler/xla/service:sort_simplifier",

View File

@ -77,6 +77,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
@ -303,6 +304,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pass.AddPass<AlgebraicSimplifier>(options);
pass.AddPass<SortSimplifier>();
pass.AddPass<HloDCE>();
pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
// elimination has to come after that pass.

View File

@ -269,6 +269,22 @@ static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
return MakeTransposeHlo(accumulator, permutation);
}
// Computes how many trips a loop implementing this gather op would take.
// Computes how many trips a loop implementing this gather op would take.
static int64 GatherLoopTripCount(HloInstruction* gather_instr) {
  const Shape& indices_shape = gather_instr->mutable_operand(1)->shape();
  const GatherDimensionNumbers& dnums =
      gather_instr->gather_dimension_numbers();
  // One trip per start index, i.e. the product of all dimensions of the
  // start_indices operand other than the index vector dimension.
  const int64 rank = indices_shape.dimensions_size();
  int64 product = 1;
  for (int64 dim = 0; dim < rank; ++dim) {
    if (dim == dnums.index_vector_dim()) {
      continue;
    }
    product *= indices_shape.dimensions(dim);
  }
  return product;
}
// High Level Algorithm
//
// We follow the following steps in sequence:
@ -311,20 +327,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandInstruction(
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
HloInstruction* start_indices = gather_instr->mutable_operand(1);
const Shape& start_indices_shape = start_indices->shape();
const Shape& output_shape = gather_instr->shape();
int64 output_rank = output_shape.dimensions_size();
const GatherDimensionNumbers& dim_numbers =
gather_instr->gather_dimension_numbers();
int64 gather_loop_trip_count = 1;
for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
if (i != dim_numbers.index_vector_dim()) {
gather_loop_trip_count *= start_indices_shape.dimensions(i);
}
}
int64 gather_loop_trip_count = GatherLoopTripCount(gather_instr);
if (!IsInt32(gather_loop_trip_count)) {
return Unimplemented(
"Gather operations with more than 2147483647 gather indices are not "
@ -373,7 +382,11 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) {
return inst->opcode() == HloOpcode::kGather &&
// Avoid expanding gather ops that produce zero sized tensors,
// instead punt these to ZeroSizedHloElimination.
!ShapeUtil::IsZeroElementArray(inst->shape());
!ShapeUtil::IsZeroElementArray(inst->shape()) &&
// In kEliminateSimpleGathers mode, we only simplify instructions
// which can be represented without a loop -- i.e. we only simplify
// gathers which have a trip count of 1.
(mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1);
}
} // namespace xla

View File

@ -21,10 +21,30 @@ limitations under the License.
namespace xla {
// This pass rewrites gather operations into (roughly) while loops of dynamic
// slices. This lets backends that don't support gather directly to
// nevertheless have a minimum level of support.
// slices.
//
// This pass can be used two ways:
//
// - kEliminateAllGathers: For backends that don't support gather, this pass
// can convert every gather to a loop.
//
// - kEliminateSimpleGathers: For backends that *do* support gather, this pass
// can strength-reduce "simple" gathers -- specifically, gathers that can be
// represented without a loop -- to dynamic-slices.
//
// Note that even in kEliminateSimpleGathers mode, this pass may still expand a
// gather into a loop (with a trip-count of 1). It's up to other simplification
// passes to remove the loop.
//
class GatherExpander : public OpExpanderPass {
public:
enum Mode {
kEliminateAllGathers,
kEliminateSimpleGathers,
};
explicit GatherExpander(Mode m) : mode_(m) {}
absl::string_view name() const override { return "gather_expander"; }
protected:
@ -32,6 +52,9 @@ class GatherExpander : public OpExpanderPass {
StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* gather_inst) override;
private:
Mode mode_;
};
} // namespace xla

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@ -42,7 +43,9 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
Status status = GatherExpander{}.Run(module.get()).status();
Status status = GatherExpander{GatherExpander::kEliminateAllGathers}
.Run(module.get())
.status();
EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
ASSERT_THAT(
@ -68,7 +71,9 @@ ENTRY main {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
ASSERT_TRUE(changed);
HloInstruction* while_instr = nullptr;
@ -129,7 +134,9 @@ ENTRY main {
OpMetadata metadata;
metadata.set_op_name("Gather");
module->entry_computation()->root_instruction()->set_metadata(metadata);
TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
ASSERT_TRUE(changed);
HloInstruction* while_instr = nullptr;
@ -147,5 +154,54 @@ ENTRY main {
"after gather expansion";
EXPECT_EQ(while_instr->metadata().op_name(), "Gather");
}
TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) {
  // This gather has two start indices, so expanding it would require a loop
  // with two trips; kEliminateSimpleGathers must therefore leave it alone.
  const string hlo = R"(
HloModule TensorFlowGatherV1
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
ROOT gather = s32[2,3] gather(operand, indices),
offset_dims={1},
collapsed_slice_dims={0},
start_index_map={0},
index_vector_dim=1,
slice_sizes={1, 3}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo));
  GatherExpander pass{GatherExpander::kEliminateSimpleGathers};
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  ASSERT_FALSE(changed);
}
// Verifies that kEliminateSimpleGathers strength-reduces a "trivial" gather:
// indices is s32[1] with index_vector_dim=0, so every start_indices dimension
// is the index vector dimension and the loop trip count is 1.
TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) {
  const string hlo_text = R"(
HloModule test
ENTRY main {
operand = s32[100] parameter(0)
indices = s32[1] parameter(1)
ROOT gather = s32[10] gather(operand, indices),
offset_dims={0},
collapsed_slice_dims={},
start_index_map={0},
index_vector_dim=0,
slice_sizes={10}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // BUG FIX: this previously constructed the pass with kEliminateAllGathers,
  // which rewrites *every* gather and therefore made the test pass without
  // exercising the mode named in the test. Use kEliminateSimpleGathers so we
  // actually verify that a trip-count-1 gather is expanded in the restricted
  // mode.
  GatherExpander pass(GatherExpander::kEliminateSimpleGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  ASSERT_TRUE(changed);
  // The gather must be gone from the entry computation after expansion.
  ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(),
                                                  {HloOpcode::kGather}));
}
} // namespace
} // namespace xla

View File

@ -1177,6 +1177,7 @@ cc_library(
"//tensorflow/compiler/xla/service:dynamic_padder",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:gather_expander",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
@ -196,6 +197,8 @@ Status GpuCompiler::OptimizeHloModule(
// elimination has to come after that pass.
pass.AddPass<ZeroSizedHloElimination>();
pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
AlgebraicSimplifierOptions options;
// When transposes appear in a fusion node, we can easily adjust the
// multi-dimensional index to create the one needed for the operand. This