[XLA:SPMD] Improve sharding propagation for scatter/gather
Try to pass shardings through between the operands and the outputs.

PiperOrigin-RevId: 324347628
Change-Id: I1ded92984c87c3d269316f90c6952102f3ec3c76
This commit is contained in:
parent 803e198f83
commit 8a449bdb65
@@ -473,6 +473,7 @@ cc_library(
        "//tensorflow/compiler/xla:array",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/types:optional",
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
@@ -331,6 +332,10 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
    }
  }
  Array<int64> new_tile_assignment = index_sharding.tile_assignment();
  if (new_tile_assignment.num_elements() !=
      Product(output_tile_assignment_dims)) {
    return HloSharding::Replicate();
  }
  new_tile_assignment.Reshape(output_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}
@@ -350,6 +355,10 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
    }
  }
  Array<int64> new_tile_assignment = output_sharding.tile_assignment();
  if (new_tile_assignment.num_elements() !=
      Product(index_tile_assignment_dims)) {
    return HloSharding::Replicate();
  }
  new_tile_assignment.Reshape(index_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}
@@ -422,6 +431,10 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
    index_tile_assignment_dims.push_back(1);
  }
  Array<int64> new_tile_assignment = data_sharding.tile_assignment();
  if (new_tile_assignment.num_elements() !=
      Product(index_tile_assignment_dims)) {
    return HloSharding::Replicate();
  }
  new_tile_assignment.Reshape(index_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}
@@ -444,6 +457,10 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding,
    }
  }
  Array<int64> new_tile_assignment = index_sharding.tile_assignment();
  if (new_tile_assignment.num_elements() !=
      Product(data_tile_assignment_dims)) {
    return HloSharding::Replicate();
  }
  new_tile_assignment.Reshape(data_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}
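As a concrete illustration of the dimension mapping behind ScatterDataSharding, here is a self-contained sketch. It is a simplified stand-in, not the real implementation: plain vectors replace XLA's HloSharding and Array<int64>, the function name is hypothetical, and index_vector_dim handling is omitted. Partitions on the scatter indices land on the update's scatter (batch) dims, while update window dims stay unpartitioned.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone model of the index->update dimension mapping.
std::vector<int64_t> UpdateTileFromIndexTile(
    const std::vector<int64_t>& index_tile,  // partitions per index dim
    const std::vector<int64_t>& update_window_dims, int64_t update_rank) {
  std::vector<int64_t> tile(update_rank, 1);
  int64_t next_index_dim = 0;
  for (int64_t i = 0; i < update_rank; ++i) {
    if (std::find(update_window_dims.begin(), update_window_dims.end(), i) ==
        update_window_dims.end()) {
      // Scatter (batch) dim: inherit the partitions of the next index dim.
      tile[i] = index_tile[next_index_dim++];
    }
  }
  return tile;
}

int main() {
  // Mirrors the ScatterIndexToUpdate test below: indices s32[3] tiled [2],
  // updates f32[3,9] with update_window_dims={1} -> update tile [2,1].
  for (int64_t d : UpdateTileFromIndexTile({2}, {1}, 2)) std::cout << d << ' ';
}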
@@ -533,6 +550,169 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
  return HloSharding::Tile(tile_assignment);
}

namespace {

// If partitioning in the operand happens only in passthrough dimensions
// (offset dimensions in the gather output (or scatter update) that have the
// same size as in the operand), returns the corresponding output (or update)
// sharding by passing through the input sharding.
absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
    const Shape& operand_shape, const HloSharding& operand_sharding,
    const Shape& update_or_gather_shape,
    absl::Span<const int64> collapsed_or_inserted_dims,
    absl::Span<const int64> index_map,
    absl::Span<const int64> offset_or_window_dims,
    absl::Span<const int64> slice_size) {
  if (operand_sharding.IsTileMaximal()) {
    return operand_sharding;
  }
  std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
  int64 collapsed = 0;
  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    int64 dim_partitions = operand_sharding.tile_assignment().dim(i);
    if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
        absl::c_linear_search(index_map, i)) {
      if (dim_partitions > 1) {
        return absl::nullopt;
      }
      collapsed++;
      continue;
    }
    if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
      return absl::nullopt;
    }
    int64 offset_dim = offset_or_window_dims[i - collapsed];
    if (i - collapsed > 0 &&
        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
      // Output offsets are transposed; we do not support this case.
      return absl::nullopt;
    }
    passthrough_tile[offset_dim] = dim_partitions;
  }
  Array<int64> tile_assignment = operand_sharding.tile_assignment();
  tile_assignment.Reshape(passthrough_tile);
  return HloSharding::Tile(tile_assignment);
}
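To make the pass-through condition concrete, the following self-contained sketch reproduces the loop above on the gather used in the tests below. It is a minimal model, not the real code: plain vectors replace XLA's Shape, HloSharding, and Array<int64>, the names are hypothetical, and the transposed-offsets check is omitted for brevity.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical standalone model: returns the tile dims of the gather output
// (or scatter update), or nullopt when pass-through is unsafe.
std::optional<std::vector<int64_t>> PassthroughTile(
    const std::vector<int64_t>& operand_dims,
    const std::vector<int64_t>& operand_tile,  // partitions per operand dim
    const std::vector<int64_t>& collapsed_or_inserted_dims,
    const std::vector<int64_t>& index_map,
    const std::vector<int64_t>& offset_or_window_dims,
    const std::vector<int64_t>& slice_size, int64_t output_rank) {
  auto contains = [](const std::vector<int64_t>& v, int64_t x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  std::vector<int64_t> tile(output_rank, 1);
  int64_t collapsed = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(operand_dims.size()); ++i) {
    int64_t parts = operand_tile[i];
    if (contains(collapsed_or_inserted_dims, i) || contains(index_map, i)) {
      if (parts > 1) return std::nullopt;  // can't pass a partitioned dim through
      ++collapsed;
      continue;
    }
    // A partitioned dim passes through only if the whole dim is sliced.
    if (slice_size[i] != operand_dims[i] && parts > 1) return std::nullopt;
    tile[offset_or_window_dims[i - collapsed]] = parts;
  }
  return tile;
}

int main() {
  // Gather from the tests: f32[2,9] operand tiled [1,2], collapsed_slice_dims
  // {0}, start_index_map {0}, offset_dims {1}, slice_sizes {1,9}.
  auto tile = PassthroughTile({2, 9}, {1, 2}, {0}, {0}, {1}, {1, 9}, 2);
  for (int64_t d : *tile) std::cout << d << ' ';  // prints "1 2"
}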

// Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
    const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
    absl::Span<const int64> collapsed_or_inserted_dims,
    absl::Span<const int64> index_map,
    absl::Span<const int64> offset_or_window_dims,
    absl::Span<const int64> slice_size) {
  if (update_or_gather_sharding.IsTileMaximal()) {
    return update_or_gather_sharding;
  }
  std::vector<int64> passthrough_tile(operand_shape.rank(), 1);
  int64 collapsed = 0;
  for (int64 i = 0; i < operand_shape.rank(); ++i) {
    if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
        absl::c_linear_search(index_map, i)) {
      collapsed++;
      continue;
    }
    int64 offset_dim = offset_or_window_dims[i - collapsed];
    int64 dim_partitions =
        update_or_gather_sharding.tile_assignment().dim(offset_dim);
    if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
      return absl::nullopt;
    }
    if (i - collapsed > 0 &&
        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
      // Output offsets are transposed; we do not support this case.
      return absl::nullopt;
    }
    passthrough_tile[i] = dim_partitions;
  }
  Array<int64> tile_assignment = update_or_gather_sharding.tile_assignment();
  if (tile_assignment.num_elements() != Product(passthrough_tile)) {
    return absl::nullopt;
  }
  tile_assignment.Reshape(passthrough_tile);
  return HloSharding::Tile(tile_assignment);
}
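The inverse direction reads partitions off the offset dims and writes them back onto the operand. A matching standalone sketch, with the same simplified stand-ins and hypothetical names as above (note the real code additionally bails out when the tile element counts differ, since partitions tied to batch dims cannot be mapped onto the operand):

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical standalone model of the output/update -> operand mapping.
std::optional<std::vector<int64_t>> InversePassthroughTile(
    const std::vector<int64_t>& operand_dims,
    const std::vector<int64_t>& output_tile,  // partitions per output dim
    const std::vector<int64_t>& collapsed_or_inserted_dims,
    const std::vector<int64_t>& index_map,
    const std::vector<int64_t>& offset_or_window_dims,
    const std::vector<int64_t>& slice_size) {
  auto contains = [](const std::vector<int64_t>& v, int64_t x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  std::vector<int64_t> tile(operand_dims.size(), 1);
  int64_t collapsed = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(operand_dims.size()); ++i) {
    if (contains(collapsed_or_inserted_dims, i) || contains(index_map, i)) {
      ++collapsed;  // collapsed/index dims stay unpartitioned in the operand
      continue;
    }
    int64_t parts = output_tile[offset_or_window_dims[i - collapsed]];
    if (slice_size[i] != operand_dims[i] && parts > 1) return std::nullopt;
    tile[i] = parts;
  }
  return tile;
}

int main() {
  // Same gather as before, now driven from the output side: output tiled
  // [1,2] maps back to an operand tile of [1,2].
  auto t = InversePassthroughTile({2, 9}, {1, 2}, {0}, {0}, {1}, {1, 9});
  return (t && *t == std::vector<int64_t>({1, 2})) ? 0 : 1;
}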

}  // namespace

absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
    const HloSharding& data_operand_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.gather_dimension_numbers();
  std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
                                          dnums.collapsed_slice_dims().end());
  std::vector<int64> start_index_map(dnums.start_index_map().begin(),
                                     dnums.start_index_map().end());
  std::vector<int64> offset_dims(dnums.offset_dims().begin(),
                                 dnums.offset_dims().end());
  return PassthroughOperandToGatherOutputOrScatterUpdate(
      hlo.operand(0)->shape(), data_operand_sharding, hlo.shape(),
      collapsed_slice_dims, start_index_map, offset_dims,
      hlo.gather_slice_sizes());
}

absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.gather_dimension_numbers();
  std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
                                          dnums.collapsed_slice_dims().end());
  std::vector<int64> start_index_map(dnums.start_index_map().begin(),
                                     dnums.start_index_map().end());
  std::vector<int64> offset_dims(dnums.offset_dims().begin(),
                                 dnums.offset_dims().end());
  return PassthroughGatherOutputOrScatterUpdateToOperand(
      hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
      start_index_map, offset_dims, hlo.gather_slice_sizes());
}

absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
    const HloSharding& update_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.scatter_dimension_numbers();
  std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
                                          dnums.inserted_window_dims().end());
  std::vector<int64> scatter_dims_to_operand_dims(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
                                        dnums.update_window_dims().end());
  std::vector<int64> slice_size(hlo.shape().rank(), 1);
  int64 num_update_window_dims = 0;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
      continue;
    }
    slice_size[i] = hlo.operand(2)->shape().dimensions(
        dnums.update_window_dims(num_update_window_dims++));
  }
  return PassthroughGatherOutputOrScatterUpdateToOperand(
      hlo.shape(), update_sharding, inserted_window_dims,
      scatter_dims_to_operand_dims, update_window_dims, slice_size);
}
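The slice_size reconstruction above deserves a worked example. This self-contained sketch, with a hypothetical standalone function in place of the loop, applies the same rule to the scatter from the tests below: inserted window dims contribute a slice of 1, and every other operand dim takes its size from the next update window dim, in order.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone model of the slice_size reconstruction.
std::vector<int64_t> ScatterSliceSizes(
    const std::vector<int64_t>& update_dims,  // shape of operand(2)
    const std::vector<int64_t>& inserted_window_dims,
    const std::vector<int64_t>& update_window_dims, int64_t operand_rank) {
  std::vector<int64_t> slice_size(operand_rank, 1);
  int64_t next_window_dim = 0;
  for (int64_t i = 0; i < operand_rank; ++i) {
    if (std::find(inserted_window_dims.begin(), inserted_window_dims.end(),
                  i) != inserted_window_dims.end()) {
      continue;  // scattered dim: slice size stays 1
    }
    slice_size[i] = update_dims[update_window_dims[next_window_dim++]];
  }
  return slice_size;
}

int main() {
  // Scatter from the tests: updates f32[3,9], inserted_window_dims={0},
  // update_window_dims={1}, operand f32[2,9] -> slice_size {1, 9}.
  for (int64_t s : ScatterSliceSizes({3, 9}, {0}, {1}, 2)) std::cout << s << ' ';
}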

absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.scatter_dimension_numbers();
  std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
                                          dnums.inserted_window_dims().end());
  std::vector<int64> scatter_dims_to_operand_dims(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
                                        dnums.update_window_dims().end());
  std::vector<int64> slice_size(hlo.shape().rank(), 1);
  int64 num_update_window_dims = 0;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
      continue;
    }
    slice_size[i] = hlo.operand(2)->shape().dimensions(
        dnums.update_window_dims(num_update_window_dims++));
  }
  return PassthroughOperandToGatherOutputOrScatterUpdate(
      hlo.shape(), output_sharding, hlo.operand(2)->shape(),
      inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims,
      slice_size);
}

StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter) {
@@ -127,6 +127,26 @@ HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo);

// Returns an output sharding of gather by passing through the data operand's
// sharding.
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
    const HloSharding& data_operand_sharding, const HloInstruction& hlo);

// Returns a data operand sharding of gather by passing through the output's
// sharding.
absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);

// Returns an output sharding of scatter by passing through the update
// operand's sharding.
absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
    const HloSharding& update_sharding, const HloInstruction& hlo);

// Returns an update operand sharding of scatter by passing through the
// output's sharding.
absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo);

// Returns an identity value and an HloOpcode for the reduce computation of a
// scatter instruction.
// - If computation is add/or, return 0/false with corresponding op code;
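A minimal sketch of how a client pass might consume one of these helpers. The wrapper function below is hypothetical; GatherOutputShardingFromDataOperand is the helper declared above, and has_sharding/sharding/set_sharding are standard HloInstruction accessors.

// Hypothetical caller: annotate a gather's output from its data operand.
void MaybePropagateGatherOutput(HloInstruction* gather) {
  const HloInstruction* data = gather->operand(0);
  if (!data->has_sharding()) {
    return;
  }
  absl::optional<HloSharding> maybe =
      hlo_sharding_util::GatherOutputShardingFromDataOperand(data->sharding(),
                                                             *gather);
  if (maybe) {
    gather->set_sharding(*maybe);  // absl::nullopt means pass-through is unsafe
  }
}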
@@ -899,20 +899,45 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       return propagate_slicing() || propagate_base();
     }
     case HloOpcode::kGather: {
-      if (!IsSpatiallyPartitioned(instruction->operand(1))) {
-        return false;
+      bool changed = false;
+      if (IsSpatiallyPartitioned(instruction->operand(1))) {
+        HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
+            instruction->operand(1)->sharding(), instruction);
+        changed |= MaybeImproveInstructionSharding(new_sharding, instruction);
       }
-      HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
-          instruction->operand(1)->sharding(), instruction);
-      return MaybeImproveInstructionSharding(new_sharding, instruction);
+      if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
+        auto maybe_from_data =
+            hlo_sharding_util::GatherOutputShardingFromDataOperand(
+                instruction->operand(0)->sharding(), *instruction);
+        if (maybe_from_data) {
+          changed |=
+              MaybeImproveInstructionSharding(*maybe_from_data, instruction);
+        }
+      }
+      return changed;
     }
     case HloOpcode::kScatter: {
+      bool changed = false;
+      if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
+        changed |= MaybeImproveInstructionSharding(
+            instruction->operand(0)->sharding(), instruction);
+      }
       if (!IsSpatiallyPartitioned(instruction->operand(1)) &&
           !IsSpatiallyPartitioned(instruction->operand(2))) {
         return false;
       }
-      return MaybeImproveInstructionSharding(HloSharding::Replicate(),
-                                             instruction);
+      if (is_spmd && IsSpatiallyPartitioned(instruction->operand(2))) {
+        auto maybe_from_update =
+            hlo_sharding_util::ScatterOutputShardingFromUpdate(
+                instruction->operand(2)->sharding(), *instruction);
+        if (maybe_from_update) {
+          changed |=
+              MaybeImproveInstructionSharding(*maybe_from_update, instruction);
+        }
+      }
+      changed |= MaybeImproveInstructionSharding(HloSharding::Replicate(),
+                                                 instruction);
+      return changed;
     }
     case HloOpcode::kWhile: {
       if (!instruction->operand(0)->has_sharding()) {
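The new cases accumulate `changed` over several candidate sources instead of returning after the first one. In toy form, with assumed semantics for the improvement check (the real MaybeImproveInstructionSharding operates on HloSharding; this standalone stand-in only shows why accumulation matters):

#include <iostream>
#include <optional>
#include <string>

// Toy model: a proposal is adopted only when it improves on what is already
// there (anything beats "none", a tiled sharding beats "replicated").
bool MaybeImprove(std::optional<std::string>& sharding,
                  const std::string& proposal) {
  if (!sharding || (*sharding == "replicated" && proposal != "replicated")) {
    sharding = proposal;
    return true;
  }
  return false;
}

int main() {
  std::optional<std::string> scatter_sharding;
  bool changed = false;
  // Candidate 1: pass through the scatter data operand's sharding.
  changed |= MaybeImprove(scatter_sharding, "{devices=[1,2]0,1}");
  // Fallback: replicate; no longer an improvement, so nothing changes.
  changed |= MaybeImprove(scatter_sharding, "replicated");
  std::cout << changed << ' ' << *scatter_sharding << '\n';
}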
@@ -1218,6 +1243,43 @@ absl::optional<HloSharding> GetShardingFromUser(
      return hlo_sharding_util::ReverseSharding(user.sharding(),
                                                user.dimensions());
    }
    case HloOpcode::kGather: {
      if (&instruction == user.operand(1)) {
        return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
      }
      if (is_spmd) {
        return hlo_sharding_util::GatherDataOperandShardingFromOutput(
            user.sharding(), user);
      }
      return absl::nullopt;
    }
    case HloOpcode::kScatter: {
      if (&instruction == user.operand(0)) {
        return user.sharding();
      }
      if (&instruction == user.operand(1)) {
        auto update = user.operand(2);
        if (!IsSpatiallyPartitioned(update)) {
          return absl::nullopt;
        }
        return hlo_sharding_util::ScatterIndexSharding(update->sharding(),
                                                       &user);
      }
      CHECK_EQ(&instruction, user.operand(2));
      auto indices = user.operand(1);
      if (IsSpatiallyPartitioned(indices)) {
        auto from_indices =
            hlo_sharding_util::ScatterDataSharding(indices->sharding(), &user);
        if (!from_indices.IsTileMaximal()) {
          return from_indices;
        }
      }
      if (is_spmd) {
        return hlo_sharding_util::ScatterUpdateShardingFromOutput(
            user.sharding(), user);
      }
      return absl::nullopt;
    }
    default: {
      // If the user output shape is compatible with the current instruction
      // shape excluding element type and the current instruction is supported
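In toy form, the preference order in the scatter branch above, with assumed semantics and standalone types: a tiled sharding derived from the indices wins, and only SPMD mode falls back to passing the user's output sharding through to the update operand.

#include <optional>
#include <string>

// Hypothetical standalone model of the update-operand branch.
std::optional<std::string> UpdateShardingFromScatterUser(
    const std::optional<std::string>& from_indices,  // empty if tile-maximal
    const std::string& output_sharding, bool is_spmd) {
  if (from_indices) {
    return from_indices;  // ScatterDataSharding produced a tiled result
  }
  if (is_spmd) {
    return output_sharding;  // ScatterUpdateShardingFromOutput analogue
  }
  return std::nullopt;
}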
@@ -1494,5 +1494,275 @@ ENTRY entry {
              op::Sharding("{devices=[2,1,1,1]0,1}"));
}

TEST_F(ShardingPropagationTest, GatherFromIndex) {
  const char* hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={replicated}
  %indices = s32[3] parameter(1), sharding={devices=[2]0,1}
  %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
    slice_sizes={1,9}
  ROOT %copy = f32[3,9] copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ShardingPropagation().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "gather"),
              op::Sharding("{devices=[2,1]0,1}"));
}

TEST_F(ShardingPropagationTest, GatherFromDataOperand) {
  const char* hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
  %indices = s32[3] parameter(1), sharding={replicated}
  %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
    slice_sizes={1,9}
  ROOT %copy = f32[3,9] copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "gather"),
              op::Sharding("{devices=[1,2]0,1}"));
}

TEST_F(ShardingPropagationTest, GatherToIndex) {
  const char* hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={replicated}
  %p1 = s32[3] parameter(1)
  %indices = s32[3] copy(%p1)
  ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
    slice_sizes={1,9}, sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ShardingPropagation().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "indices"),
              op::Sharding("{devices=[2]0,1}"));
}

TEST_F(ShardingPropagationTest, GatherToDataOperand) {
  const char* hlo_string = R"(
HloModule module

ENTRY entry {
  %p0 = f32[2,9] parameter(0)
  %input = f32[2,9] copy(%p0)
  %indices = s32[3] parameter(1), sharding={replicated}
  ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
    slice_sizes={1,9}, sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "input"),
              op::Sharding("{devices=[1,2]0,1}"));
}

TEST_F(ShardingPropagationTest, DataOperandToScatter) {
  const char* const hlo_string = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
  %indices = s32[3] parameter(1), sharding={replicated}
  %updates = f32[3,9] parameter(2), sharding={replicated}
  %scatter = f32[2,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
  ROOT %copy = f32[2,9] copy(%scatter)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "scatter"),
              op::Sharding("{devices=[1,2]0,1}"));
}

TEST_F(ShardingPropagationTest, UpdateOperandToScatter) {
  const char* const hlo_string = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={replicated}
  %indices = s32[3] parameter(1), sharding={replicated}
  %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1}
  %scatter = f32[2,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
  ROOT %copy = f32[2,9] copy(%scatter)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "scatter"),
              op::Sharding("{devices=[1,2]0,1}"));
}

TEST_F(ShardingPropagationTest, ScatterToDataOperand) {
  const char* const hlo_string = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %p0 = f32[2,9] parameter(0)
  %input = f32[2,9] copy(%p0)
  %indices = s32[3] parameter(1), sharding={replicated}
  %updates = f32[3,9] parameter(2), sharding={replicated}
  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1, sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ShardingPropagation().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "input"),
              op::Sharding("{devices=[1,2]0,1}"));
}

TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) {
  const char* const hlo_string = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9] parameter(0)
  %indices = s32[3] parameter(1), sharding={replicated}
  %p2 = f32[3,9] parameter(2)
  %updates = f32[3,9] copy(%p2)
  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1, sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "updates"),
              op::Sharding("{devices=[1,2]0,1}"));
}

TEST_F(ShardingPropagationTest, ScatterUpdateToIndex) {
  const char* const hlo_string = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={replicated}
  %p1 = s32[3] parameter(1), sharding={replicated}
  %indices = s32[3] copy(%p1)
  %updates = f32[3,9] parameter(2), sharding={devices=[2,1]0,1}
  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ShardingPropagation().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "indices"),
              op::Sharding("{devices=[2]0,1}"));
}

TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) {
  const char* const hlo_string = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={replicated}
  %indices = s32[3] parameter(1), sharding={devices=[2]0,1}
  %p2 = f32[3,9] parameter(2), sharding={replicated}
  %updates = f32[3,9] copy(%p2)
  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          ShardingPropagation().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "updates"),
              op::Sharding("{devices=[2,1]0,1}"));
}

}  // namespace
}  // namespace xla
@@ -1069,47 +1069,7 @@ Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
   return Status::OK();
 }
 
-// If partitioning in the operand only happens in dimensions in passthrough
-// dimensions (offset dimensions in the gather output (or scatter update) that
-// have the same size as the operand), returns the corresponding output (or
-// update) sharding by passing through the input sharding.
-absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
-    const PartitionedHlo& operand, const Shape& update_or_gather_shape,
-    absl::Span<const int64> collapsed_or_inserted_dims,
-    absl::Span<const int64> index_map,
-    absl::Span<const int64> offset_or_window_dims,
-    absl::Span<const int64> slice_size) {
-  if (operand.sharding().IsTileMaximal()) {
-    return operand.sharding();
-  }
-  std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
-  int64 collapsed = 0;
-  for (int64 i = 0; i < operand.base_shape().rank(); ++i) {
-    int64 dim_partitions = operand.sharding().tile_assignment().dim(i);
-    if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
-        absl::c_linear_search(index_map, i)) {
-      if (dim_partitions > 1) {
-        return absl::nullopt;
-      }
-      collapsed++;
-      continue;
-    }
-    if (slice_size[i] != operand.base_shape().dimensions(i) &&
-        dim_partitions > 1) {
-      return absl::nullopt;
-    }
-    int64 offset_dim = offset_or_window_dims[i - collapsed];
-    if (i - collapsed > 0 &&
-        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
-      // Output offsets are transposed, we do not support this case.
-      return absl::nullopt;
-    }
-    passthrough_tile[offset_dim] = dim_partitions;
-  }
-  Array<int64> tile_assignment = operand.sharding().tile_assignment();
-  tile_assignment.Reshape(passthrough_tile);
-  return HloSharding::Tile(tile_assignment);
-}
 namespace {
 
 // Returns whether partitioning in the operand only happens in dimensions with
 // gather/scatter slice size 1.
@@ -1204,6 +1164,8 @@ IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
   return {broadcast_min, broadcast_max};
 }
 
+}  // namespace
+
 Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
   auto scatter = Cast<HloScatterInstruction>(hlo);
   auto dnums = scatter->scatter_dimension_numbers();
@@ -1219,16 +1181,12 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
     slice_size[i] = updates.base_shape().dimensions(
         dnums.update_window_dims(num_update_window_dims++));
   }
-  std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
-                                          dnums.inserted_window_dims().end());
-  std::vector<int64> scatter_dims_to_operand_dims(
-      dnums.scatter_dims_to_operand_dims().begin(),
-      dnums.scatter_dims_to_operand_dims().end());
-  std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
-                                        dnums.update_window_dims().end());
   std::vector<int64> update_scatter_dims;
   for (int64 i = 0; i < updates.base_shape().rank(); ++i) {
-    if (!absl::c_linear_search(update_window_dims, i)) {
+    if (!absl::c_linear_search(dnums.update_window_dims(), i)) {
       update_scatter_dims.push_back(i);
     }
   }
@@ -1292,9 +1250,8 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
       return Status::OK();
     }
   } else {
-    auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate(
-        operand, updates.base_shape(), inserted_window_dims,
-        scatter_dims_to_operand_dims, update_window_dims, slice_size);
+    auto maybe_passthrough = hlo_sharding_util::ScatterUpdateShardingFromOutput(
+        operand.sharding(), *hlo);
     // Handle pass through cases if we can use compatible sharding for update.
     if (maybe_passthrough.has_value()) {
       indices = indices.Reshard(HloSharding::Replicate());
@@ -2148,15 +2105,11 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
   const auto& dnums = gather->gather_dimension_numbers();
   auto operand = GetPartitionedHlo(gather->operand(0));
   auto indices = GetPartitionedHlo(gather->operand(1));
-  std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
-                                          dnums.collapsed_slice_dims().end());
-  std::vector<int64> start_index_map(dnums.start_index_map().begin(),
-                                     dnums.start_index_map().end());
-  std::vector<int64> offset_dims(dnums.offset_dims().begin(),
-                                 dnums.offset_dims().end());
   std::vector<int64> batch_dims;
   for (int64 i = 0; i < gather->shape().rank(); ++i) {
-    if (!absl::c_linear_search(offset_dims, i)) {
+    if (!absl::c_linear_search(dnums.offset_dims(), i)) {
       batch_dims.push_back(i);
     }
  }
@@ -2193,9 +2146,9 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
       return Status::OK();
     }
   } else {
-    auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate(
-        operand, gather->shape(), collapsed_slice_dims, start_index_map,
-        offset_dims, gather->gather_slice_sizes());
+    auto maybe_passthrough =
+        hlo_sharding_util::GatherOutputShardingFromDataOperand(
+            operand.sharding(), *hlo);
     if (maybe_passthrough.has_value()) {
       indices = indices.Reshard(HloSharding::Replicate());
       auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough);