[XLA:SPMD] Support Scatter in Partial Replicate.
PiperOrigin-RevId: 329023692 Change-Id: I28450c59dd259a573db2ee692201e6b1441a09aa
This commit is contained in:
parent
49b58c7b7f
commit
ebce61dc84
@ -470,13 +470,19 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
|
||||
if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
|
||||
index_tile_assignment_dims.push_back(1);
|
||||
}
|
||||
if (data_sharding.ReplicateOnLastTileDim()) {
|
||||
index_tile_assignment_dims.push_back(
|
||||
data_sharding.tile_assignment().dimensions().back());
|
||||
}
|
||||
Array<int64> new_tile_assignment = data_sharding.tile_assignment();
|
||||
if (new_tile_assignment.num_elements() !=
|
||||
Product(index_tile_assignment_dims)) {
|
||||
return HloSharding::Replicate();
|
||||
}
|
||||
new_tile_assignment.Reshape(index_tile_assignment_dims);
|
||||
return HloSharding::Tile(new_tile_assignment);
|
||||
return data_sharding.ReplicateOnLastTileDim()
|
||||
? HloSharding::PartialTile(new_tile_assignment)
|
||||
: HloSharding::Tile(new_tile_assignment);
|
||||
}
|
||||
|
||||
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
|
||||
@ -496,13 +502,19 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding,
|
||||
index_dim++;
|
||||
}
|
||||
}
|
||||
if (index_sharding.ReplicateOnLastTileDim()) {
|
||||
data_tile_assignment_dims.push_back(
|
||||
index_sharding.tile_assignment().dimensions().back());
|
||||
}
|
||||
Array<int64> new_tile_assignment = index_sharding.tile_assignment();
|
||||
if (new_tile_assignment.num_elements() !=
|
||||
Product(data_tile_assignment_dims)) {
|
||||
return HloSharding::Replicate();
|
||||
}
|
||||
new_tile_assignment.Reshape(data_tile_assignment_dims);
|
||||
return HloSharding::Tile(new_tile_assignment);
|
||||
return index_sharding.ReplicateOnLastTileDim()
|
||||
? HloSharding::PartialTile(new_tile_assignment)
|
||||
: HloSharding::Tile(new_tile_assignment);
|
||||
}
|
||||
|
||||
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
|
||||
|
||||
@ -2025,6 +2025,38 @@ ENTRY entry {
|
||||
op::Sharding("{devices=[1,2]0,1}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, DataOperandToScatter_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%input = f32[2,9] parameter(0),
|
||||
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
%indices = s32[3] parameter(1), sharding={replicated}
|
||||
%updates = f32[3,9] parameter(2), sharding={replicated}
|
||||
%scatter = f32[2,9] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1
|
||||
ROOT %copy = f32[2,9] copy(%scatter)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(FindInstruction(module.get(), "scatter"),
|
||||
op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, UpdateOperandToScatter) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
@ -2056,6 +2088,70 @@ ENTRY entry {
|
||||
op::Sharding("{devices=[1,2]0,1}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, UpdateOperandToScatter_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%input = f32[2,9] parameter(0), sharding={replicated}
|
||||
%indices = s32[3] parameter(1), sharding={replicated}
|
||||
%updates = f32[3,9] parameter(2),
|
||||
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
%scatter = f32[2,9] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1
|
||||
ROOT %copy = f32[2,9] copy(%scatter)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(FindInstruction(module.get(), "scatter"),
|
||||
op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, ScatterToDataOperand_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%p0 = f32[2,9] parameter(0)
|
||||
%input = f32[2,9] copy(%p0)
|
||||
%indices = s32[3] parameter(1), sharding={replicated}
|
||||
%updates = f32[3,9] parameter(2), sharding={replicated}
|
||||
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1,
|
||||
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
ShardingPropagation().Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(FindInstruction(module.get(), "input"),
|
||||
op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, ScatterToDataOperand) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
@ -2087,6 +2183,38 @@ ENTRY entry {
|
||||
op::Sharding("{devices=[1,2]0,1}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, ScatterToUpdateOperand_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%input = f32[2,9] parameter(0)
|
||||
%indices = s32[3] parameter(1), sharding={replicated}
|
||||
%p2 = f32[3,9] parameter(2)
|
||||
%updates = f32[3,9] copy(%p2)
|
||||
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1,
|
||||
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(FindInstruction(module.get(), "updates"),
|
||||
op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
@ -2149,6 +2277,38 @@ ENTRY entry {
|
||||
op::Sharding("{devices=[2]0,1}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, ScatterUpdateToIndex_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%input = f32[2,9] parameter(0), sharding={replicated}
|
||||
%p1 = s32[3] parameter(1), sharding={replicated}
|
||||
%indices = s32[3] copy(%p1)
|
||||
%updates = f32[3,9] parameter(2),
|
||||
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
|
||||
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1, sharding={replicated}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
ShardingPropagation().Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(FindInstruction(module.get(), "indices"),
|
||||
op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
@ -2180,6 +2340,38 @@ ENTRY entry {
|
||||
op::Sharding("{devices=[2,1]0,1}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, ScatterIndexToUpdate_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%input = f32[2,9] parameter(0), sharding={replicated}
|
||||
%indices = s32[3] parameter(1),
|
||||
sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
%p2 = f32[3,9] parameter(2), sharding={replicated}
|
||||
%updates = f32[3,9] copy(%p2)
|
||||
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1, sharding={replicated}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
ShardingPropagation().Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
EXPECT_THAT(FindInstruction(module.get(), "updates"),
|
||||
op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}"));
|
||||
}
|
||||
|
||||
TEST_F(ShardingPropagationTest, PartialShardingOnElementwise) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
@ -1451,10 +1451,23 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
|
||||
update_dim_to_index_dim);
|
||||
CHECK(new_updates_sharding.has_value());
|
||||
updates = updates.Reshard(*new_updates_sharding);
|
||||
// Update collective_ops_creator and partition_id for partial replicate.
|
||||
auto collective_ops_creator = collective_ops_creator_;
|
||||
auto partition_id = partition_id_;
|
||||
if (indices.sharding().ReplicateOnLastTileDim()) {
|
||||
auto sharding_grouped = GroupShardingOnDims(
|
||||
indices.sharding(),
|
||||
{indices.sharding().tile_assignment().num_dimensions() - 1});
|
||||
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
|
||||
indices.state(), sharding_grouped.device_groups, &b_);
|
||||
collective_ops_creator =
|
||||
per_group_partitioner_state.collective_ops_creator;
|
||||
partition_id = per_group_partitioner_state.partition_id;
|
||||
}
|
||||
// To avoid accumulating the initial operand multiple times during
|
||||
// all-reduce, we use identity operands for all non-zero partitions.
|
||||
auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::MakeScalarShape(PRED), partition_id_));
|
||||
ShapeUtil::MakeScalarShape(PRED), partition_id));
|
||||
not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::ChangeElementType(identity->shape(), PRED),
|
||||
not_partition_zero, {}));
|
||||
@ -1465,7 +1478,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
|
||||
auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands(
|
||||
scatter->shape(), {select_operand, indices.hlo(), updates.hlo()}));
|
||||
auto all_reduce =
|
||||
collective_ops_creator_.create_cross_partition_all_reduce(
|
||||
collective_ops_creator.create_cross_partition_all_reduce(
|
||||
&b_, pscatter, scatter->to_apply(), {}, NewChannel());
|
||||
all_reduce->set_sharding(HloSharding::Replicate());
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
|
||||
@ -4070,6 +4070,39 @@ ENTRY entry {
|
||||
op::Shape("f32[2,5]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%input = f32[2,9] parameter(0),
|
||||
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
%indices = s32[3] parameter(1), sharding={replicated}
|
||||
%updates = f32[3,9] parameter(2),
|
||||
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={1},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1,
|
||||
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/4));
|
||||
VLOG(1) << module->ToString();
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
|
||||
op::Parameter(2)),
|
||||
op::Shape("f32[2,5]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
@ -4104,6 +4137,42 @@ ENTRY entry {
|
||||
op::Shape("f32[2,9,8]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%input = f32[2,9,8] parameter(0), sharding={replicated}
|
||||
%indices = s32[4,2,4] parameter(1),
|
||||
sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
|
||||
%updates = f32[4,4,8] parameter(2),
|
||||
sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
|
||||
ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={2},
|
||||
inserted_window_dims={0,1},
|
||||
scatter_dims_to_operand_dims={0,1},
|
||||
index_vector_dim=1, sharding={replicated}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/8));
|
||||
VLOG(1) << module->ToString();
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(
|
||||
root,
|
||||
AllOf(op::AllReduce(op::Scatter(
|
||||
op::Select(op::Broadcast(op::Convert(op::Reshape())),
|
||||
op::Broadcast(op::Constant()), op::Parameter(0)),
|
||||
op::Parameter(1), op::Parameter(2))),
|
||||
op::Shape("f32[2,9,8]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
@ -4172,6 +4241,43 @@ ENTRY entry {
|
||||
op::Shape("f32[9,9]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest,
|
||||
ScatterPartitionedOnTrivialSliceDims_PartialReplicate) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT sum = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
%input = f32[17,9] parameter(0),
|
||||
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
|
||||
%indices = s32[2,3] parameter(1), sharding={replicated}
|
||||
%updates = f32[2,3,9] parameter(2), sharding={replicated}
|
||||
ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
|
||||
to_apply=add,
|
||||
update_window_dims={2},
|
||||
inserted_window_dims={0},
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=2,
|
||||
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/4));
|
||||
VLOG(1) << module->ToString();
|
||||
auto offset =
|
||||
op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
|
||||
auto indices = op::Subtract(
|
||||
op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root,
|
||||
AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)),
|
||||
op::Shape("f32[9,9]")));
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
@ -362,8 +362,8 @@ absl::optional<HloInstruction*> PadFromPartialReplicateShape(
|
||||
// dimensions by dynamic slice.
|
||||
// For example, if partial_sharding is
|
||||
// {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
|
||||
// Target tile dims is {2, 2}, the returned compatible sharding will be
|
||||
// sharding={devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}.
|
||||
// Target sharding is {devices=[2,2]0,1,2,3}, the returned compatible sharding
|
||||
// will be sharding={devices=[2,2]0,2,1,3}.
|
||||
// If partial_sharding is not partial replicate or can't reshard to
|
||||
// target_tile_dims by dynamic slice, return absl::nullopt.
|
||||
// If target_sharding is already compatible, returns it.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user