Expand simple gathers into dynamic-slice.

Especially for (GPU) fusion, XLA prefers to call a slice a slice.

PiperOrigin-RevId: 325526316
Change-Id: I12b98756eca017d520a9a40d03dc291a42b9eaa3
parent 0d2f665129
commit 6e909825ed
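To see the intent concretely: a gather whose implementation loop would run exactly once is semantically just a dynamic-slice. The illustrative HLO below is hand-written against the trivial gather from the new test in this commit (it is not itself part of the change); the pass still emits a trip-count-1 loop and relies on later simplification passes to fold it into the plain slice.

// Before: a trivial gather (one scalar index, one slice).
//   gather = s32[10] gather(s32[100] operand, s32[1] indices),
//       offset_dims={0}, collapsed_slice_dims={}, start_index_map={0},
//       index_vector_dim=0, slice_sizes={10}
// After simplification, roughly:
//   index = s32[] reshape(indices)
//   slice = s32[10] dynamic-slice(operand, index), dynamic_slice_sizes={10}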
@@ -2259,6 +2259,7 @@ tf_cc_test(
     srcs = ["gather_expander_test.cc"],
     deps = [
         ":gather_expander",
+        ":hlo_query",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:test_macros_header",

@@ -183,6 +183,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_verifier",
         "//tensorflow/compiler/xla/service:indexed_array_analysis",
         "//tensorflow/compiler/xla/service:llvm_compiler",
+        "//tensorflow/compiler/xla/service:gather_expander",
         "//tensorflow/compiler/xla/service:reshape_mover",
         "//tensorflow/compiler/xla/service:rng_expander",
         "//tensorflow/compiler/xla/service:sort_simplifier",

@@ -77,6 +77,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+#include "tensorflow/compiler/xla/service/gather_expander.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"

@@ -303,6 +304,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
     pass.AddPass<AlgebraicSimplifier>(options);
     pass.AddPass<SortSimplifier>();
     pass.AddPass<HloDCE>();
+    pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);

     // BatchNormExpander can create zero-sized ops, so zero-sized HLO
     // elimination has to come after that pass.

@@ -269,6 +269,22 @@ static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
   return MakeTransposeHlo(accumulator, permutation);
 }

+// Computes how many trips a loop implementing this gather op would take.
+static int64 GatherLoopTripCount(HloInstruction* gather_instr) {
+  HloInstruction* start_indices = gather_instr->mutable_operand(1);
+  const Shape& start_indices_shape = start_indices->shape();
+  const GatherDimensionNumbers& dim_numbers =
+      gather_instr->gather_dimension_numbers();
+
+  int64 trip_count = 1;
+  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
+    if (i != dim_numbers.index_vector_dim()) {
+      trip_count *= start_indices_shape.dimensions(i);
+    }
+  }
+  return trip_count;
+}
+
 // High Level Algorithm
 //
 // We follow the following steps in sequence:

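For readers unfamiliar with GatherDimensionNumbers, here is a self-contained sketch of the same computation over plain integers (TripCountSketch and its parameter names are illustrative, not part of the change): every start_indices dimension except index_vector_dim is a batch of independent gather indices, and their product is the number of loop trips.

#include <cstdint>
#include <vector>

// Product of all start_indices dimensions except index_vector_dim.
// E.g. dims {3, 7, 5} with index_vector_dim = 1 yields 3 * 5 = 15 trips;
// the size-7 dimension holds the components of each index vector, not a
// batch dimension.
int64_t TripCountSketch(const std::vector<int64_t>& start_indices_dims,
                        int64_t index_vector_dim) {
  int64_t trip_count = 1;
  for (int64_t i = 0; i < static_cast<int64_t>(start_indices_dims.size());
       ++i) {
    if (i != index_vector_dim) trip_count *= start_indices_dims[i];
  }
  return trip_count;
}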
@@ -311,20 +327,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandInstruction(
   HloComputation* computation = gather_instr->parent();
   HloInstruction* operand = gather_instr->mutable_operand(0);
   HloInstruction* start_indices = gather_instr->mutable_operand(1);
-  const Shape& start_indices_shape = start_indices->shape();
   const Shape& output_shape = gather_instr->shape();
   int64 output_rank = output_shape.dimensions_size();

   const GatherDimensionNumbers& dim_numbers =
       gather_instr->gather_dimension_numbers();

-  int64 gather_loop_trip_count = 1;
-  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
-    if (i != dim_numbers.index_vector_dim()) {
-      gather_loop_trip_count *= start_indices_shape.dimensions(i);
-    }
-  }
-
+  int64 gather_loop_trip_count = GatherLoopTripCount(gather_instr);
   if (!IsInt32(gather_loop_trip_count)) {
     return Unimplemented(
         "Gather operations with more than 2147483647 gather indices are not "

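The 2147483647 in the error message is INT32_MAX (2^31 - 1); the IsInt32 guard exists because, presumably, the expanded while loop counts its trips with a 32-bit s32 value, so a gather with 2^31 or more indices cannot be lowered by this pass.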
@@ -373,7 +382,11 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) {
   return inst->opcode() == HloOpcode::kGather &&
          // Avoid expanding gather ops that produce zero sized tensors,
          // instead punt these to ZeroSizedHloElimination.
-         !ShapeUtil::IsZeroElementArray(inst->shape());
+         !ShapeUtil::IsZeroElementArray(inst->shape()) &&
+         // In kEliminateSimpleGathers mode, we only simplify instructions
+         // which can be represented without a loop -- i.e. we only simplify
+         // gathers which have a trip count of 1.
+         (mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1);
 }

 }  // namespace xla

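Note the asymmetry this predicate creates: in kEliminateAllGathers mode every non-zero-sized gather matches, while in kEliminateSimpleGathers mode only gathers whose loop would run exactly once do. For such a gather the loop body executes a single dynamic-slice, so expansion is a strength reduction; a gather with a larger trip count is left for the backend, which (per the header comment below) is assumed to support gather natively in that mode.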
@@ -21,10 +21,30 @@ limitations under the License.
 namespace xla {

 // This pass rewrites gather operations into (roughly) while loops of dynamic
-// slices. This lets backends that don't support gather directly to
-// nevertheless have a minimum level of support.
+// slices.
+//
+// This pass can be used two ways:
+//
+// - kEliminateAllGathers: For backends that don't support gather, this pass
+//   can convert every gather to a loop.
+//
+// - kEliminateSimpleGathers: For backends that *do* support gather, this pass
+//   can strength-reduce "simple" gathers -- specifically, gathers that can be
+//   represented without a loop -- to dynamic-slices.
+//
+// Note that even in kEliminateSimpleGathers mode, this pass may still expand a
+// gather into a loop (with a trip-count of 1). It's up to other simplification
+// passes to remove the loop.
+//
 class GatherExpander : public OpExpanderPass {
  public:
+  enum Mode {
+    kEliminateAllGathers,
+    kEliminateSimpleGathers,
+  };
+
+  explicit GatherExpander(Mode m) : mode_(m) {}
+
   absl::string_view name() const override { return "gather_expander"; }

  protected:

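A minimal usage sketch of the new two-mode interface (the pipeline variable is hypothetical; the mode choices mirror the CPU and GPU pipeline hunks elsewhere in this commit):

// Backend supports gather natively: only strength-reduce loop-free gathers
// into dynamic-slices.
pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);

// Backend has no gather support: lower every gather to a loop of
// dynamic-slices.
pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateAllGathers);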
@@ -32,6 +52,9 @@ class GatherExpander : public OpExpanderPass {

   StatusOr<HloInstruction*> ExpandInstruction(
       HloInstruction* gather_inst) override;
+
+ private:
+  Mode mode_;
 };

 }  // namespace xla

@@ -15,6 +15,7 @@ limitations under the License.

 #include "tensorflow/compiler/xla/service/gather_expander.h"

+#include "tensorflow/compiler/xla/service/hlo_query.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"

@@ -42,7 +43,9 @@ ENTRY main {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));

-  Status status = GatherExpander{}.Run(module.get()).status();
+  Status status = GatherExpander{GatherExpander::kEliminateAllGathers}
+                      .Run(module.get())
+                      .status();
   EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);

   ASSERT_THAT(

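This hunk and the two below it only adapt existing tests to the constructor's new mandatory Mode argument; kEliminateAllGathers reproduces the pass's previous unconditional-expansion behavior.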
@@ -68,7 +71,9 @@ ENTRY main {
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
   ASSERT_TRUE(changed);

   HloInstruction* while_instr = nullptr;

@@ -129,7 +134,9 @@ ENTRY main {
   OpMetadata metadata;
   metadata.set_op_name("Gather");
   module->entry_computation()->root_instruction()->set_metadata(metadata);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
   ASSERT_TRUE(changed);

   HloInstruction* while_instr = nullptr;

@@ -147,5 +154,54 @@ ENTRY main {
          "after gather expansion";
   EXPECT_EQ(while_instr->metadata().op_name(), "Gather");
 }
+
+TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) {
+  const string hlo_text = R"(
+HloModule TensorFlowGatherV1
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  ROOT gather = s32[2,3] gather(operand, indices),
+      offset_dims={1},
+      collapsed_slice_dims={0},
+      start_index_map={0},
+      index_vector_dim=1,
+      slice_sizes={1, 3}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GatherExpander pass(GatherExpander::kEliminateSimpleGathers);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
+  ASSERT_FALSE(changed);
+}
+
+TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) {
+  const string hlo_text = R"(
+HloModule test
+
+ENTRY main {
+  operand = s32[100] parameter(0)
+  indices = s32[1] parameter(1)
+  ROOT gather = s32[10] gather(operand, indices),
+      offset_dims={0},
+      collapsed_slice_dims={},
+      start_index_map={0},
+      index_vector_dim=0,
+      slice_sizes={10}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GatherExpander pass(GatherExpander::kEliminateAllGathers);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
+  ASSERT_TRUE(changed);
+  ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(),
+                                                  {HloOpcode::kGather}));
+}
+
 }  // namespace
 }  // namespace xla

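The two new tests walk both sides of the trip-count predicate. In the first, indices has shape s32[2] with index_vector_dim=1 (equal to the rank of indices, so each index is a scalar): the trip count is 2 and kEliminateSimpleGathers leaves the gather untouched. In the second, indices has shape s32[1] with index_vector_dim=0, so the only dimension is the index-vector dimension and the trip count is 1; the gather is expanded away completely, and hlo_query confirms that no kGather instruction survives.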
@@ -1177,6 +1177,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:dynamic_padder",
         "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
+        "//tensorflow/compiler/xla/service:gather_expander",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_constant_folding",
         "//tensorflow/compiler/xla/service:hlo_cse",

@@ -43,6 +43,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+#include "tensorflow/compiler/xla/service/gather_expander.h"
 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"

@@ -196,6 +197,8 @@ Status GpuCompiler::OptimizeHloModule(
     // elimination has to come after that pass.
     pass.AddPass<ZeroSizedHloElimination>();

+    pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
+
     AlgebraicSimplifierOptions options;
     // When transposes appear in a fusion node, we can easily adjust the
     // multi-dimensional index to create the one needed for the operand. This

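With this hunk, both backend pipelines run the expander in kEliminateSimpleGathers mode: CPU and GPU handle gather natively, so only gathers that reduce to a single dynamic-slice are rewritten -- exactly the case the commit message targets, since GPU fusion handles an explicit slice better than an equivalent gather.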