[XLA:SPMD] Improve sharding propagation for scatter/gather

Try to pass shardings through between the operands and the outputs of gather/scatter.
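For example, if a gather's f32[2,9] data operand is tiled as {devices=[1,2]0,1} along the offset (passthrough) dimension, the f32[3,9] output can now adopt the same [1,2] tiling instead of falling back to replication; the new GatherFromDataOperand and DataOperandToScatter tests below cover this.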

PiperOrigin-RevId: 324347628
Change-Id: I1ded92984c87c3d269316f90c6952102f3ec3c76
Authored by Yuanzhong Xu on 2020-07-31 20:59:29 -07:00; committed by TensorFlower Gardener
Parent: 803e198f83
Commit: 8a449bdb65
6 changed files with 550 additions and 64 deletions

tensorflow/compiler/xla/service/BUILD

@@ -473,6 +473,7 @@ cc_library(
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/types:optional",

tensorflow/compiler/xla/service/hlo_sharding_util.cc

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -331,6 +332,10 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
}
}
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(output_tile_assignment_dims)) {
return HloSharding::Replicate();
}
new_tile_assignment.Reshape(output_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
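The new num_elements guard protects the Reshape below it: a tile assignment can only be reshaped to dimensions whose product matches its element count, so on a mismatch the code now falls back to replication instead of failing. A standalone illustration (not part of the change; Product here is a local stand-in for the helper pulled in via xla/util.h):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Local stand-in for the Product() helper from xla/util.h.
int64_t Product(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  int64_t num_tiles = 2;  // e.g. an index sharding across 2 devices
  std::vector<int64_t> output_tile_dims = {2, 2};  // would need 4 tiles
  if (num_tiles != Product(output_tile_dims)) {
    std::cout << "fall back to HloSharding::Replicate()\n";
  }
}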
@@ -350,6 +355,10 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
}
}
Array<int64> new_tile_assignment = output_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(index_tile_assignment_dims)) {
return HloSharding::Replicate();
}
new_tile_assignment.Reshape(index_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
@@ -422,6 +431,10 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
index_tile_assignment_dims.push_back(1);
}
Array<int64> new_tile_assignment = data_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(index_tile_assignment_dims)) {
return HloSharding::Replicate();
}
new_tile_assignment.Reshape(index_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
@@ -444,6 +457,10 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding,
}
}
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(data_tile_assignment_dims)) {
return HloSharding::Replicate();
}
new_tile_assignment.Reshape(data_tile_assignment_dims);
return HloSharding::Tile(new_tile_assignment);
}
@@ -533,6 +550,169 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
return HloSharding::Tile(tile_assignment);
}
namespace {
// If the operand is partitioned only along passthrough dimensions (offset
// dimensions in the gather output (or scatter update) that have the same size
// as the corresponding operand dimensions), returns the output (or update)
// sharding obtained by passing through the input sharding.
absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
const Shape& operand_shape, const HloSharding& operand_sharding,
const Shape& update_or_gather_shape,
absl::Span<const int64> collapsed_or_inserted_dims,
absl::Span<const int64> index_map,
absl::Span<const int64> offset_or_window_dims,
absl::Span<const int64> slice_size) {
if (operand_sharding.IsTileMaximal()) {
return operand_sharding;
}
std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
int64 collapsed = 0;
for (int64 i = 0; i < operand_shape.rank(); ++i) {
int64 dim_partitions = operand_sharding.tile_assignment().dim(i);
if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
absl::c_linear_search(index_map, i)) {
if (dim_partitions > 1) {
return absl::nullopt;
}
collapsed++;
continue;
}
if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
return absl::nullopt;
}
int64 offset_dim = offset_or_window_dims[i - collapsed];
if (i - collapsed > 0 &&
offset_dim < offset_or_window_dims[i - collapsed - 1]) {
// Output offsets are transposed; we do not support this case.
return absl::nullopt;
}
passthrough_tile[offset_dim] = dim_partitions;
}
Array<int64> tile_assignment = operand_sharding.tile_assignment();
tile_assignment.Reshape(passthrough_tile);
return HloSharding::Tile(tile_assignment);
}
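To make the dimension mapping concrete, here is a self-contained sketch of the same loop using plain vectors in place of HloSharding and Array<int64>; the names and shapes are illustrative, mirroring the GatherFromDataOperand test below:

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// True iff x occurs in v (stand-in for absl::c_linear_search).
bool Contains(const std::vector<int64_t>& v, int64_t x) {
  for (int64_t e : v) {
    if (e == x) return true;
  }
  return false;
}

// Maps per-dimension partition counts on the operand to partition counts on
// the gather output (or scatter update); nullopt when a collapsed/indexed or
// sliced dimension is partitioned, matching the function above.
std::optional<std::vector<int64_t>> PassthroughTile(
    const std::vector<int64_t>& operand_dims,
    const std::vector<int64_t>& operand_tile, int64_t output_rank,
    const std::vector<int64_t>& collapsed_or_inserted_dims,
    const std::vector<int64_t>& index_map,
    const std::vector<int64_t>& offset_or_window_dims,
    const std::vector<int64_t>& slice_size) {
  std::vector<int64_t> passthrough(output_rank, 1);
  int64_t collapsed = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(operand_dims.size()); ++i) {
    int64_t dim_partitions = operand_tile[i];
    if (Contains(collapsed_or_inserted_dims, i) || Contains(index_map, i)) {
      if (dim_partitions > 1) return std::nullopt;
      ++collapsed;
      continue;
    }
    if (slice_size[i] != operand_dims[i] && dim_partitions > 1) {
      return std::nullopt;
    }
    int64_t offset_dim = offset_or_window_dims[i - collapsed];
    if (i - collapsed > 0 &&
        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
      return std::nullopt;  // transposed offsets: unsupported
    }
    passthrough[offset_dim] = dim_partitions;
  }
  return passthrough;
}

int main() {
  // f32[2,9] operand tiled [1,2]; gather with offset_dims={1},
  // collapsed_slice_dims={0}, start_index_map={0}, slice_sizes={1,9}.
  auto tile = PassthroughTile({2, 9}, {1, 2}, /*output_rank=*/2,
                              /*collapsed_or_inserted_dims=*/{0},
                              /*index_map=*/{0},
                              /*offset_or_window_dims=*/{1},
                              /*slice_size=*/{1, 9});
  if (tile) {
    for (int64_t t : *tile) std::cout << t << ' ';  // prints: 1 2
    std::cout << '\n';
  }
}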
// Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
absl::Span<const int64> collapsed_or_inserted_dims,
absl::Span<const int64> index_map,
absl::Span<const int64> offset_or_window_dims,
absl::Span<const int64> slice_size) {
if (update_or_gather_sharding.IsTileMaximal()) {
return update_or_gather_sharding;
}
std::vector<int64> passthrough_tile(operand_shape.rank(), 1);
int64 collapsed = 0;
for (int64 i = 0; i < operand_shape.rank(); ++i) {
if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
absl::c_linear_search(index_map, i)) {
collapsed++;
continue;
}
int64 offset_dim = offset_or_window_dims[i - collapsed];
int64 dim_partitions =
update_or_gather_sharding.tile_assignment().dim(offset_dim);
if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
return absl::nullopt;
}
if (i - collapsed > 0 &&
offset_dim < offset_or_window_dims[i - collapsed - 1]) {
// Output offsets are transposed; we do not support this case.
return absl::nullopt;
}
passthrough_tile[i] = dim_partitions;
}
Array<int64> tile_assignment = update_or_gather_sharding.tile_assignment();
if (tile_assignment.num_elements() != Product(passthrough_tile)) {
return absl::nullopt;
}
tile_assignment.Reshape(passthrough_tile);
return HloSharding::Tile(tile_assignment);
}
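The inverse direction can be sketched the same way; this extends the standalone example above and reuses its Contains helper (illustrative only):

// Inverse of PassthroughTile above: read partition counts off the output
// (or update) tile and place them on the operand dimensions. The real
// function additionally checks that the tile element count matches the
// product of the target dims before reshaping (the guard above).
std::optional<std::vector<int64_t>> InversePassthroughTile(
    const std::vector<int64_t>& operand_dims,
    const std::vector<int64_t>& output_tile,
    const std::vector<int64_t>& collapsed_or_inserted_dims,
    const std::vector<int64_t>& index_map,
    const std::vector<int64_t>& offset_or_window_dims,
    const std::vector<int64_t>& slice_size) {
  std::vector<int64_t> passthrough(operand_dims.size(), 1);
  int64_t collapsed = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(operand_dims.size()); ++i) {
    if (Contains(collapsed_or_inserted_dims, i) || Contains(index_map, i)) {
      ++collapsed;
      continue;
    }
    int64_t offset_dim = offset_or_window_dims[i - collapsed];
    int64_t dim_partitions = output_tile[offset_dim];
    if (slice_size[i] != operand_dims[i] && dim_partitions > 1) {
      return std::nullopt;
    }
    if (i - collapsed > 0 &&
        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
      return std::nullopt;  // transposed offsets: unsupported
    }
    passthrough[i] = dim_partitions;
  }
  return passthrough;
}

// InversePassthroughTile({2, 9}, {1, 2}, {0}, {0}, {1}, {1, 9}) -> {1, 2},
// matching the GatherToDataOperand test below.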
} // namespace
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
const HloSharding& data_operand_sharding, const HloInstruction& hlo) {
const auto& dnums = hlo.gather_dimension_numbers();
std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
dnums.collapsed_slice_dims().end());
std::vector<int64> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
std::vector<int64> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
return PassthroughOperandToGatherOutputOrScatterUpdate(
hlo.operand(0)->shape(), data_operand_sharding, hlo.shape(),
collapsed_slice_dims, start_index_map, offset_dims,
hlo.gather_slice_sizes());
}
absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
const HloSharding& output_sharding, const HloInstruction& hlo) {
const auto& dnums = hlo.gather_dimension_numbers();
std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
dnums.collapsed_slice_dims().end());
std::vector<int64> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
std::vector<int64> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
return PassthroughGatherOutputOrScatterUpdateToOperand(
hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
start_index_map, offset_dims, hlo.gather_slice_sizes());
}
absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
const HloSharding& update_sharding, const HloInstruction& hlo) {
const auto& dnums = hlo.scatter_dimension_numbers();
std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
dnums.inserted_window_dims().end());
std::vector<int64> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
dnums.update_window_dims().end());
std::vector<int64> slice_size(hlo.shape().rank(), 1);
int64 num_update_window_dims = 0;
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
continue;
}
slice_size[i] = hlo.operand(2)->shape().dimensions(
dnums.update_window_dims(num_update_window_dims++));
}
return PassthroughGatherOutputOrScatterUpdateToOperand(
hlo.shape(), update_sharding, inserted_window_dims,
scatter_dims_to_operand_dims, update_window_dims, slice_size);
}
absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
const HloSharding& output_sharding, const HloInstruction& hlo) {
const auto& dnums = hlo.scatter_dimension_numbers();
std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
dnums.inserted_window_dims().end());
std::vector<int64> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
dnums.update_window_dims().end());
std::vector<int64> slice_size(hlo.shape().rank(), 1);
int64 num_update_window_dims = 0;
for (int64 i = 0; i < hlo.shape().rank(); ++i) {
if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
continue;
}
slice_size[i] = hlo.operand(2)->shape().dimensions(
dnums.update_window_dims(num_update_window_dims++));
}
return PassthroughOperandToGatherOutputOrScatterUpdate(
hlo.shape(), output_sharding, hlo.operand(2)->shape(),
inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims,
slice_size);
}
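The subtle part above is slice_size: scatter carries no explicit slice_sizes, so they are reconstructed from the update shape, with inserted window dimensions pinned to 1. A standalone worked example using the shapes from the scatter tests below:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> operand_shape = {2, 9};  // scatter operand/output
  std::vector<int64_t> update_shape = {3, 9};   // updates
  std::vector<int64_t> inserted_window_dims = {0};
  std::vector<int64_t> update_window_dims = {1};

  // Mirror of the loop above: inserted dims keep slice size 1; the rest take
  // the sizes of the update's window dimensions, in order.
  std::vector<int64_t> slice_size(operand_shape.size(), 1);
  int64_t num_update_window_dims = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(operand_shape.size()); ++i) {
    bool inserted = false;
    for (int64_t d : inserted_window_dims) inserted = inserted || (d == i);
    if (inserted) continue;  // scattered dim: slice size stays 1
    slice_size[i] = update_shape[update_window_dims[num_update_window_dims++]];
  }
  assert(slice_size == (std::vector<int64_t>{1, 9}));
}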
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
const HloScatterInstruction& scatter) {

tensorflow/compiler/xla/service/hlo_sharding_util.h

@@ -127,6 +127,26 @@ HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
const HloInstruction& hlo);
// Returns an output sharding for a gather by passing through the data
// operand's sharding.
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
    const HloSharding& data_operand_sharding, const HloInstruction& hlo);

// Returns a data operand sharding for a gather by passing through the
// output's sharding.
absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);

// Returns an output sharding for a scatter by passing through the update
// operand's sharding.
absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
    const HloSharding& update_sharding, const HloInstruction& hlo);

// Returns an update operand sharding for a scatter by passing through the
// output's sharding.
absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);
// Returns an identity value and an HloOpcode for the reduce computation of a
// scatter instruction.
// - If computation is add/or, return 0/false with corresponding op code;

tensorflow/compiler/xla/service/sharding_propagation.cc

@@ -899,20 +899,45 @@ bool InferShardingFromOperands(HloInstruction* instruction,
return propagate_slicing() || propagate_base();
}
case HloOpcode::kGather: {
-      if (!IsSpatiallyPartitioned(instruction->operand(1))) {
-        return false;
+      bool changed = false;
+      if (IsSpatiallyPartitioned(instruction->operand(1))) {
+        HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
+            instruction->operand(1)->sharding(), instruction);
+        changed |= MaybeImproveInstructionSharding(new_sharding, instruction);
       }
-      HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
-          instruction->operand(1)->sharding(), instruction);
-      return MaybeImproveInstructionSharding(new_sharding, instruction);
+      if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
+        auto maybe_from_data =
+            hlo_sharding_util::GatherOutputShardingFromDataOperand(
+                instruction->operand(0)->sharding(), *instruction);
+        if (maybe_from_data) {
+          changed |=
+              MaybeImproveInstructionSharding(*maybe_from_data, instruction);
+        }
+      }
+      return changed;
     }
     case HloOpcode::kScatter: {
+      bool changed = false;
+      if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
+        changed |= MaybeImproveInstructionSharding(
+            instruction->operand(0)->sharding(), instruction);
+      }
       if (!IsSpatiallyPartitioned(instruction->operand(1)) &&
           !IsSpatiallyPartitioned(instruction->operand(2))) {
        return false;
       }
-      return MaybeImproveInstructionSharding(HloSharding::Replicate(),
-                                             instruction);
+      if (is_spmd && IsSpatiallyPartitioned(instruction->operand(2))) {
+        auto maybe_from_update =
+            hlo_sharding_util::ScatterOutputShardingFromUpdate(
+                instruction->operand(2)->sharding(), *instruction);
+        if (maybe_from_update) {
+          changed |=
+              MaybeImproveInstructionSharding(*maybe_from_update, instruction);
+        }
+      }
+      changed |= MaybeImproveInstructionSharding(HloSharding::Replicate(),
+                                                 instruction);
+      return changed;
}
case HloOpcode::kWhile: {
if (!instruction->operand(0)->has_sharding()) {
@@ -1218,6 +1243,43 @@ absl::optional<HloSharding> GetShardingFromUser(
return hlo_sharding_util::ReverseSharding(user.sharding(),
user.dimensions());
}
case HloOpcode::kGather: {
if (&instruction == user.operand(1)) {
return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
}
if (is_spmd) {
return hlo_sharding_util::GatherDataOperandShardingFromOutput(
user.sharding(), user);
}
return absl::nullopt;
}
case HloOpcode::kScatter: {
if (&instruction == user.operand(0)) {
return user.sharding();
}
if (&instruction == user.operand(1)) {
auto update = user.operand(2);
if (!IsSpatiallyPartitioned(update)) {
return absl::nullopt;
}
return hlo_sharding_util::ScatterIndexSharding(update->sharding(),
&user);
}
CHECK_EQ(&instruction, user.operand(2));
auto indices = user.operand(1);
if (IsSpatiallyPartitioned(indices)) {
auto from_indices =
hlo_sharding_util::ScatterDataSharding(indices->sharding(), &user);
if (!from_indices.IsTileMaximal()) {
return from_indices;
}
}
if (is_spmd) {
return hlo_sharding_util::ScatterUpdateShardingFromOutput(
user.sharding(), user);
}
return absl::nullopt;
}
default: {
// If the user output shape is compatible with the current instruction
// shape excluding element type and the current instruction is supported

tensorflow/compiler/xla/service/sharding_propagation_test.cc

@@ -1494,5 +1494,275 @@ ENTRY entry {
op::Sharding("{devices=[2,1,1,1]0,1}"));
}
TEST_F(ShardingPropagationTest, GatherFromIndex) {
const char* hlo_string = R"(
HloModule module
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={replicated}
%indices = s32[3] parameter(1), sharding={devices=[2]0,1}
%gather = f32[3,9] gather(%input, %indices), offset_dims={1},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
slice_sizes={1,9}
ROOT %copy = f32[3,9] copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "gather"),
op::Sharding("{devices=[2,1]0,1}"));
}
TEST_F(ShardingPropagationTest, GatherFromDataOperand) {
const char* hlo_string = R"(
HloModule module
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
%indices = s32[3] parameter(1), sharding={replicated}
%gather = f32[3,9] gather(%input, %indices), offset_dims={1},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
slice_sizes={1,9}
ROOT %copy = f32[3,9] copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "gather"),
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, GatherToIndex) {
const char* hlo_string = R"(
HloModule module
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={replicated}
%p1 = s32[3] parameter(1)
%indices = s32[3] copy(%p1)
ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
slice_sizes={1,9}, sharding={devices=[2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "indices"),
op::Sharding("{devices=[2]0,1}"));
}
TEST_F(ShardingPropagationTest, GatherToDataOperand) {
const char* hlo_string = R"(
HloModule module
ENTRY entry {
%p0 = f32[2,9] parameter(0)
%input = f32[2,9] copy(%p0)
%indices = s32[3] parameter(1), sharding={replicated}
ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
slice_sizes={1,9}, sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "input"),
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, DataOperandToScatter) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2), sharding={replicated}
%scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
ROOT %copy = f32[2,9] copy(%scatter)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "scatter"),
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, UpdateOperandToScatter) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={replicated}
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1}
%scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
ROOT %copy = f32[2,9] copy(%scatter)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "scatter"),
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, ScatterToDataOperand) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%p0 = f32[2,9] parameter(0)
%input = f32[2,9] copy(%p0)
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2), sharding={replicated}
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "input"),
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0)
%indices = s32[3] parameter(1), sharding={replicated}
%p2 = f32[3,9] parameter(2)
%updates = f32[3,9] copy(%p2)
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, sharding={devices=[1,2]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "updates"),
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, ScatterUpdateToIndex) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={replicated}
%p1 = s32[3] parameter(1), sharding={replicated}
%indices = s32[3] copy(%p1)
%updates = f32[3,9] parameter(2), sharding={devices=[2,1]0,1}
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "indices"),
op::Sharding("{devices=[2]0,1}"));
}
TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={replicated}
%indices = s32[3] parameter(1), sharding={devices=[2]0,1}
%p2 = f32[3,9] parameter(2), sharding={replicated}
%updates = f32[3,9] copy(%p2)
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "updates"),
op::Sharding("{devices=[2,1]0,1}"));
}
} // namespace
} // namespace xla

tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc

@@ -1069,47 +1069,7 @@ Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
return Status::OK();
}
-// If partitioning in the operand only happens in dimensions in passthrough
-// dimensions (offset dimensions in the gather output (or scatter update) that
-// have the same size as the operand), returns the corresponding output (or
-// update) sharding by passing through the input sharding.
-absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
-    const PartitionedHlo& operand, const Shape& update_or_gather_shape,
-    absl::Span<const int64> collapsed_or_inserted_dims,
-    absl::Span<const int64> index_map,
-    absl::Span<const int64> offset_or_window_dims,
-    absl::Span<const int64> slice_size) {
-  if (operand.sharding().IsTileMaximal()) {
-    return operand.sharding();
-  }
-  std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
-  int64 collapsed = 0;
-  for (int64 i = 0; i < operand.base_shape().rank(); ++i) {
-    int64 dim_partitions = operand.sharding().tile_assignment().dim(i);
-    if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
-        absl::c_linear_search(index_map, i)) {
-      if (dim_partitions > 1) {
-        return absl::nullopt;
-      }
-      collapsed++;
-      continue;
-    }
-    if (slice_size[i] != operand.base_shape().dimensions(i) &&
-        dim_partitions > 1) {
-      return absl::nullopt;
-    }
-    int64 offset_dim = offset_or_window_dims[i - collapsed];
-    if (i - collapsed > 0 &&
-        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
-      // Output offsets are transposed, we do not support this case.
-      return absl::nullopt;
-    }
-    passthrough_tile[offset_dim] = dim_partitions;
-  }
-  Array<int64> tile_assignment = operand.sharding().tile_assignment();
-  tile_assignment.Reshape(passthrough_tile);
-  return HloSharding::Tile(tile_assignment);
-}
namespace {
// Returns whether partitioning in the operand only happens in dimensions with
// gather/scatter slice size 1.
@@ -1204,6 +1164,8 @@ IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
return {broadcast_min, broadcast_max};
}
} // namespace
Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
auto scatter = Cast<HloScatterInstruction>(hlo);
auto dnums = scatter->scatter_dimension_numbers();
@@ -1219,16 +1181,12 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
slice_size[i] = updates.base_shape().dimensions(
dnums.update_window_dims(num_update_window_dims++));
}
-  std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
-                                          dnums.inserted_window_dims().end());
-  std::vector<int64> scatter_dims_to_operand_dims(
-      dnums.scatter_dims_to_operand_dims().begin(),
-      dnums.scatter_dims_to_operand_dims().end());
-  std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
-                                        dnums.update_window_dims().end());
   std::vector<int64> update_scatter_dims;
   for (int64 i = 0; i < updates.base_shape().rank(); ++i) {
-    if (!absl::c_linear_search(update_window_dims, i)) {
+    if (!absl::c_linear_search(dnums.update_window_dims(), i)) {
update_scatter_dims.push_back(i);
}
}
@@ -1292,9 +1250,8 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
return Status::OK();
}
} else {
-    auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate(
-        operand, updates.base_shape(), inserted_window_dims,
-        scatter_dims_to_operand_dims, update_window_dims, slice_size);
+    auto maybe_passthrough = hlo_sharding_util::ScatterUpdateShardingFromOutput(
+        operand.sharding(), *hlo);
// Handle pass through cases if we can use compatible sharding for update.
if (maybe_passthrough.has_value()) {
indices = indices.Reshard(HloSharding::Replicate());
@@ -2148,15 +2105,11 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
const auto& dnums = gather->gather_dimension_numbers();
auto operand = GetPartitionedHlo(gather->operand(0));
auto indices = GetPartitionedHlo(gather->operand(1));
-  std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
-                                          dnums.collapsed_slice_dims().end());
-  std::vector<int64> start_index_map(dnums.start_index_map().begin(),
-                                     dnums.start_index_map().end());
-  std::vector<int64> offset_dims(dnums.offset_dims().begin(),
-                                 dnums.offset_dims().end());
   std::vector<int64> batch_dims;
   for (int64 i = 0; i < gather->shape().rank(); ++i) {
-    if (!absl::c_linear_search(offset_dims, i)) {
+    if (!absl::c_linear_search(dnums.offset_dims(), i)) {
batch_dims.push_back(i);
}
}
@@ -2193,9 +2146,9 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
return Status::OK();
}
} else {
-    auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate(
-        operand, gather->shape(), collapsed_slice_dims, start_index_map,
-        offset_dims, gather->gather_slice_sizes());
+    auto maybe_passthrough =
+        hlo_sharding_util::GatherOutputShardingFromDataOperand(
+            operand.sharding(), *hlo);
if (maybe_passthrough.has_value()) {
indices = indices.Reshard(HloSharding::Replicate());
auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough);