Expand simple gathers into dynamic-slice.

Especially for (GPU) fusion, XLA prefers to call a slice a slice.

PiperOrigin-RevId: 325526316
Change-Id: I12b98756eca017d520a9a40d03dc291a42b9eaa3

commit 6e909825ed (parent 0d2f665129)
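
[Editor's note: for orientation, a minimal sketch of running the new mode standalone over an HLO module. The GatherExpander(Mode) constructor and the Run(...).status() call are taken from the diff below; the wrapper function name ExpandSimpleGathers and the exact header set are assumptions, not part of the commit.]

#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {

// Hypothetical helper (not in this commit): rewrite only the gathers that
// are really dynamic-slices, i.e. whose implementing loop would run once.
Status ExpandSimpleGathers(HloModule* module) {
  GatherExpander expander(GatherExpander::kEliminateSimpleGathers);
  return expander.Run(module).status();
}

}  // namespace xla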
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2259,6 +2259,7 @@ tf_cc_test(
     srcs = ["gather_expander_test.cc"],
     deps = [
         ":gather_expander",
+        ":hlo_query",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:test_macros_header",
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -183,6 +183,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_verifier",
         "//tensorflow/compiler/xla/service:indexed_array_analysis",
         "//tensorflow/compiler/xla/service:llvm_compiler",
+        "//tensorflow/compiler/xla/service:gather_expander",
         "//tensorflow/compiler/xla/service:reshape_mover",
         "//tensorflow/compiler/xla/service:rng_expander",
         "//tensorflow/compiler/xla/service:sort_simplifier",
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -77,6 +77,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+#include "tensorflow/compiler/xla/service/gather_expander.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
@@ -303,6 +304,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
     pass.AddPass<AlgebraicSimplifier>(options);
     pass.AddPass<SortSimplifier>();
     pass.AddPass<HloDCE>();
+    pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
 
     // BatchNormExpander can create zero-sized ops, so zero-sized HLO
     // elimination has to come after that pass.
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -269,6 +269,22 @@ static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims(
   return MakeTransposeHlo(accumulator, permutation);
 }
 
+// Computes how many trips a loop implementing this gather op would take.
+static int64 GatherLoopTripCount(HloInstruction* gather_instr) {
+  HloInstruction* start_indices = gather_instr->mutable_operand(1);
+  const Shape& start_indices_shape = start_indices->shape();
+  const GatherDimensionNumbers& dim_numbers =
+      gather_instr->gather_dimension_numbers();
+
+  int64 trip_count = 1;
+  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
+    if (i != dim_numbers.index_vector_dim()) {
+      trip_count *= start_indices_shape.dimensions(i);
+    }
+  }
+  return trip_count;
+}
+
 // High Level Algorithm
 //
 // We follow the following steps in sequence:
@@ -311,20 +327,13 @@ StatusOr<HloInstruction*> GatherExpander::ExpandInstruction(
   HloComputation* computation = gather_instr->parent();
   HloInstruction* operand = gather_instr->mutable_operand(0);
   HloInstruction* start_indices = gather_instr->mutable_operand(1);
-  const Shape& start_indices_shape = start_indices->shape();
   const Shape& output_shape = gather_instr->shape();
   int64 output_rank = output_shape.dimensions_size();
 
   const GatherDimensionNumbers& dim_numbers =
       gather_instr->gather_dimension_numbers();
 
-  int64 gather_loop_trip_count = 1;
-  for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) {
-    if (i != dim_numbers.index_vector_dim()) {
-      gather_loop_trip_count *= start_indices_shape.dimensions(i);
-    }
-  }
-
+  int64 gather_loop_trip_count = GatherLoopTripCount(gather_instr);
   if (!IsInt32(gather_loop_trip_count)) {
     return Unimplemented(
         "Gather operations with more than 2147483647 gather indices are not "
@@ -373,7 +382,11 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) {
   return inst->opcode() == HloOpcode::kGather &&
          // Avoid expanding gather ops that produce zero sized tensors,
          // instead punt these to ZeroSizedHloElimination.
-         !ShapeUtil::IsZeroElementArray(inst->shape());
+         !ShapeUtil::IsZeroElementArray(inst->shape()) &&
+         // In kEliminateSimpleGathers mode, we only simplify instructions
+         // which can be represented without a loop -- i.e. we only simplify
+         // gathers which have a trip count of 1.
+         (mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1);
 }
 
 }  // namespace xla
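
[Editor's note: to make the trip-count rule concrete, a standalone sketch mirroring GatherLoopTripCount's arithmetic without the HLO types: the trip count is the product of all start_indices dimensions except index_vector_dim. The helper name TripCount and the std::vector signature are assumptions for illustration, not XLA API. The two commented cases match the s32[2]/index_vector_dim=1 and s32[1]/index_vector_dim=0 gathers exercised by the new tests further down.]

#include <cstdint>
#include <vector>

// Mirrors GatherLoopTripCount: multiply every start_indices dimension
// except the index vector dimension.
int64_t TripCount(const std::vector<int64_t>& start_indices_dims,
                  int64_t index_vector_dim) {
  int64_t trip_count = 1;
  for (int64_t i = 0; i < static_cast<int64_t>(start_indices_dims.size());
       ++i) {
    if (i != index_vector_dim) trip_count *= start_indices_dims[i];
  }
  return trip_count;
}

// TripCount({2}, /*index_vector_dim=*/1) == 2: a real loop, left alone in
// kEliminateSimpleGathers mode.
// TripCount({1}, /*index_vector_dim=*/0) == 1: a single dynamic-slice, so
// InstructionMatchesPattern accepts it in either mode.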
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -21,10 +21,30 @@ limitations under the License.
 namespace xla {
 
 // This pass rewrites gather operations into (roughly) while loops of dynamic
-// slices. This lets backends that don't support gather directly to
-// nevertheless have a minimum level of support.
+// slices.
+//
+// This pass can be used two ways:
+//
+//   - kEliminateAllGathers: For backends that don't support gather, this pass
+//     can convert every gather to a loop.
+//
+//   - kEliminateSimpleGathers: For backends that *do* support gather, this pass
+//     can strength-reduce "simple" gathers -- specifically, gathers that can be
+//     represented without a loop -- to dynamic-slices.
+//
+// Note that even in kEliminateSimpleGathers mode, this pass may still expand a
+// gather into a loop (with a trip-count of 1). It's up to other simplification
+// passes to remove the loop.
+//
 class GatherExpander : public OpExpanderPass {
  public:
+  enum Mode {
+    kEliminateAllGathers,
+    kEliminateSimpleGathers,
+  };
+
+  explicit GatherExpander(Mode m) : mode_(m) {}
+
   absl::string_view name() const override { return "gather_expander"; }
 
  protected:
@@ -32,6 +52,9 @@ class GatherExpander : public OpExpanderPass {
 
   StatusOr<HloInstruction*> ExpandInstruction(
       HloInstruction* gather_inst) override;
+
+ private:
+  Mode mode_;
 };
 
 }  // namespace xla
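
[Editor's note: a usage sketch tying the two Mode values to the two kinds of backends the header comment describes. The AddPass calls are the ones this commit adds to the CPU and GPU pipelines below; the pipeline variable and the backend_supports_gather flag are hypothetical stand-ins for illustration.]

// `pipeline` is an HloPassPipeline; `backend_supports_gather` is a
// hypothetical flag standing in for a backend property.
if (backend_supports_gather) {
  // CPU and GPU take this path in this commit: strength-reduce only the
  // gathers that need no loop, and let the backend handle the rest.
  pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
} else {
  // A backend with no native gather lowers every gather to a while loop
  // of dynamic-slices.
  pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateAllGathers);
}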
--- a/tensorflow/compiler/xla/service/gather_expander_test.cc
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/gather_expander.h"
 
+#include "tensorflow/compiler/xla/service/hlo_query.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -42,7 +43,9 @@ ENTRY main {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
 
-  Status status = GatherExpander{}.Run(module.get()).status();
+  Status status = GatherExpander{GatherExpander::kEliminateAllGathers}
+                      .Run(module.get())
+                      .status();
   EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
 
   ASSERT_THAT(
@@ -68,7 +71,9 @@ ENTRY main {
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(hlo_text));
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
   ASSERT_TRUE(changed);
 
   HloInstruction* while_instr = nullptr;
@@ -129,7 +134,9 @@ ENTRY main {
   OpMetadata metadata;
   metadata.set_op_name("Gather");
   module->entry_computation()->root_instruction()->set_metadata(metadata);
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get()));
   ASSERT_TRUE(changed);
 
   HloInstruction* while_instr = nullptr;
@@ -147,5 +154,54 @@ ENTRY main {
          "after gather expansion";
   EXPECT_EQ(while_instr->metadata().op_name(), "Gather");
 }
 
+TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) {
+  const string hlo_text = R"(
+HloModule TensorFlowGatherV1
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  ROOT gather = s32[2,3] gather(operand, indices),
+      offset_dims={1},
+      collapsed_slice_dims={0},
+      start_index_map={0},
+      index_vector_dim=1,
+      slice_sizes={1, 3}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GatherExpander pass(GatherExpander::kEliminateSimpleGathers);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
+  ASSERT_FALSE(changed);
+}
+
+TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) {
+  const string hlo_text = R"(
+HloModule test
+
+ENTRY main {
+  operand = s32[100] parameter(0)
+  indices = s32[1] parameter(1)
+  ROOT gather = s32[10] gather(operand, indices),
+      offset_dims={0},
+      collapsed_slice_dims={},
+      start_index_map={0},
+      index_vector_dim=0,
+      slice_sizes={10}
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  GatherExpander pass(GatherExpander::kEliminateAllGathers);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
+  ASSERT_TRUE(changed);
+  ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(),
+                                                  {HloOpcode::kGather}));
+}
+
 }  // namespace
 }  // namespace xla
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1177,6 +1177,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:dynamic_padder",
         "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
+        "//tensorflow/compiler/xla/service:gather_expander",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_constant_folding",
         "//tensorflow/compiler/xla/service:hlo_cse",
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -43,6 +43,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+#include "tensorflow/compiler/xla/service/gather_expander.h"
 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
@@ -196,6 +197,8 @@ Status GpuCompiler::OptimizeHloModule(
     // elimination has to come after that pass.
     pass.AddPass<ZeroSizedHloElimination>();
 
+    pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
+
     AlgebraicSimplifierOptions options;
     // When transposes appear in a fusion node, we can easily adjust the
     // multi-dimensional index to create the one needed for the operand. This