[XLA] Expand simple scatter operations into dynamic-update-slice.

This allows them to be fused.

PiperOrigin-RevId: 327291810
Change-Id: I8e706a6add56e5e9fb4e9262e886f19ee11ac2df
This commit is contained in:
Justin Lebar 2020-08-18 13:11:30 -07:00 committed by TensorFlower Gardener
parent e99f1d3efb
commit ad7e6583cd
8 changed files with 137 additions and 53 deletions

View File

@ -1843,6 +1843,7 @@ cc_library(
":hlo",
":hlo_creation_utils",
":hlo_pass",
":op_expander_pass",
":while_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",

View File

@ -290,7 +290,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
/*expansion_type=*/LogisticExpansionType::kExp);
pipeline.AddPass<ConditionalCanonicalizer>();
pipeline.AddPass<DynamicPadder>();
pipeline.AddPass<ScatterExpander>();
pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =

View File

@ -201,6 +201,7 @@ Status GpuCompiler::OptimizeHloModule(
pass.AddPass<ZeroSizedHloElimination>();
pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);
pass.AddPass<ScatterExpander>(ScatterExpander::kEliminateSimpleScatters);
AlgebraicSimplifierOptions options;
// When transposes appear in a fusion node, we can easily adjust the

View File

@ -23,26 +23,11 @@ limitations under the License.
namespace xla {
StatusOr<bool> GpuScatterExpander::Run(HloModule* module) {
auto is_nontrivial_scatter = [](HloInstruction* inst) {
// TODO(b/129698548): Scattering elements larger than 64 bits is not
// supported by XLA:GPU.
return inst->opcode() == HloOpcode::kScatter &&
inst->shape().element_type() == C128;
};
std::vector<HloInstruction*> scatter_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
absl::c_copy_if(computation->instructions(),
std::back_inserter(scatter_instrs), is_nontrivial_scatter);
}
for (HloInstruction* inst : scatter_instrs) {
TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(inst));
TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root));
}
return !scatter_instrs.empty();
bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
// TODO(b/129698548): Scattering elements larger than 64 bits is not
// supported by XLA:GPU.
return inst->opcode() == HloOpcode::kScatter &&
primitive_util::BitWidth(inst->shape().element_type()) > 64;
}
} // namespace xla

View File

@ -20,10 +20,17 @@ limitations under the License.
namespace xla {
// Legalizes scatters on the GPU.
class GpuScatterExpander : public ScatterExpander {
public:
// Although we pass kEliminateAllScatters, we override this behavior in
// InstruuctionMatchesPattern and select only some scatters to expand.
GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {}
absl::string_view name() const override { return "gpu_scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
protected:
bool InstructionMatchesPattern(HloInstruction* inst) override;
};
} // namespace xla

View File

@ -325,6 +325,22 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
{updated_operand, scatter_indices, updates}};
}
static int64 ScatterTripCount(HloInstruction* scatter) {
// Compute the trip count for the while loop to be used for scatter. This
// should be the number of indices we should scatter into the operand.
HloInstruction* scatter_indices = scatter->mutable_operand(1);
const Shape& scatter_indices_shape = scatter_indices->shape();
const ScatterDimensionNumbers& dim_numbers =
scatter->scatter_dimension_numbers();
int64 scatter_loop_trip_count = 1;
for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
if (i != dim_numbers.index_vector_dim()) {
scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
}
}
return scatter_loop_trip_count;
}
// High Level Algorithm.
//
// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
@ -342,7 +358,7 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
// from c. and d. using the update_computation of scatter.
// f. Write the updated value of the slice into the operand tensor.
StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
StatusOr<HloInstruction*> ScatterExpander::ExpandInstruction(
HloInstruction* scatter) {
HloInstruction* operand = scatter->mutable_operand(0);
HloInstruction* scatter_indices = scatter->mutable_operand(1);
@ -358,13 +374,7 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
// Compute the trip count for the while loop to be used for scatter. This
// should be the number of indices we should scatter into the operand.
const Shape& scatter_indices_shape = scatter_indices->shape();
int64 scatter_loop_trip_count = 1;
for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
if (i != dim_numbers.index_vector_dim()) {
scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
}
}
int64 scatter_loop_trip_count = ScatterTripCount(scatter);
if (!IsInt32(scatter_loop_trip_count)) {
return Unimplemented(
"Scatter operations with more than 2147483647 scatter indices are not "
@ -408,23 +418,9 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
return scatter_loop_result.front();
}
StatusOr<bool> ScatterExpander::Run(HloModule* module) {
std::vector<HloInstruction*> scatter_instrs;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instr : computation->instructions()) {
if (instr->opcode() == HloOpcode::kScatter) {
scatter_instrs.push_back(instr);
}
}
}
for (auto instr : scatter_instrs) {
TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr));
TF_RETURN_IF_ERROR(
instr->parent()->ReplaceInstruction(instr, expanded_root));
}
return !scatter_instrs.empty();
bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) {
return inst->opcode() == HloOpcode::kScatter &&
(mode_ == kEliminateAllScatters || ScatterTripCount(inst) == 1);
}
} // namespace xla

View File

@ -16,17 +16,43 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
namespace xla {
class ScatterExpander : public HloModulePass {
// This pass rewrites scatter operations into (roughly) while loops of
// dynamic-update-slices.
//
// This pass can be used in two ways:
//
// - kEliminateAllScatters: For backends that don't support scatter, this pass
// can convert every scatter into a loop.
//
// - kEliminateSimpleScatters: For backends that *do* support scatter, this
// pass can strength-reduce "simple" scatters -- specifically, scatters that
// can be represented without a loop -- to dynamic-update-slices.
//
// Note that even in kEliminateSimpleScatters mode, this pass may still expand a
// scatter into a loop (with a trip-count of 1). It's up to other
// simplification passes to remove the loop.
class ScatterExpander : public OpExpanderPass {
public:
enum Mode {
kEliminateAllScatters,
kEliminateSimpleScatters,
};
explicit ScatterExpander(Mode m) : mode_(m) {}
absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
protected:
StatusOr<HloInstruction*> ExpandScatter(HloInstruction* scatter);
bool InstructionMatchesPattern(HloInstruction* inst) override;
StatusOr<HloInstruction*> ExpandInstruction(HloInstruction* scatter) override;
private:
Mode mode_;
};
} // namespace xla

View File

@ -57,11 +57,79 @@ TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) {
ParseAndReturnVerifiedModule(kModuleStr));
// The HLO parser changes all no layout shapes from the input to have a
// default layout, clear the layout of the scatter operand for testing.
// default layout. Clear the layout of the scatter operand for testing.
HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
scatter_operand->mutable_shape()->clear_layout();
ScatterExpander scatter_expander;
ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters);
TF_ASSERT_OK_AND_ASSIGN(bool result,
RunHloPass(&scatter_expander, module.get()));
EXPECT_TRUE(result);
}
TEST_F(ScatterExpanderTest, EliminateSimpleScattersSkipsNontrivialScatter) {
const char* kModuleStr = R"(
HloModule scatter_expander
scatter_computation {
parameter0 = s32[] parameter(0)
ROOT parameter1 = s32[] parameter(1)
}
ENTRY kernel_entry {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=scatter_computation,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
// The HLO parser changes all no layout shapes from the input to have a
// default layout. Clear the layout of the scatter operand for testing.
HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
scatter_operand->mutable_shape()->clear_layout();
ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters);
TF_ASSERT_OK_AND_ASSIGN(bool result,
RunHloPass(&scatter_expander, module.get()));
EXPECT_FALSE(result);
}
TEST_F(ScatterExpanderTest, EliminateSimpleScattersRewritesTrivialScatter) {
const char* kModuleStr = R"(
HloModule scatter_expander
scatter_computation {
parameter0 = s32[] parameter(0)
ROOT parameter1 = s32[] parameter(1)
}
ENTRY kernel_entry {
operand = s32[5] iota(), iota_dimension=0
indices = s32[1] parameter(0)
update = s32[] constant(0)
ROOT scatter = s32[5]{0} scatter(operand, indices, update),
update_window_dims={}, inserted_window_dims={0},
scatter_dims_to_operand_dims={0}, index_vector_dim=0,
to_apply=scatter_computation
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
// The HLO parser changes all no layout shapes from the input to have a
// default layout. Clear the layout of the scatter operand for testing.
HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
scatter_operand->mutable_shape()->clear_layout();
ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters);
TF_ASSERT_OK_AND_ASSIGN(bool result,
RunHloPass(&scatter_expander, module.get()));
EXPECT_TRUE(result);