[XLA:SPMD] Support Scatter in Partial Replicate.

PiperOrigin-RevId: 329023692
Change-Id: I28450c59dd259a573db2ee692201e6b1441a09aa
A. Unique TensorFlower 2020-08-28 15:50:27 -07:00 committed by TensorFlower Gardener
parent 49b58c7b7f
commit ebce61dc84
5 changed files with 329 additions and 6 deletions
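For context, "partial replicate" means a sharding whose tile assignment carries one extra trailing dimension of devices that hold replicas of the same shard, written with the last_tile_dim_replicate suffix. A worked reading of the annotation used throughout the tests below, for an f32[2,9] operand on 4 devices (my example, not part of the commit):

    {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
      tile grid   : 1 x 2  (rows untiled, columns split into shards of 5 and 4)
      replication : trailing dim of size 2
      columns 0-4 : replicated on devices {0,1}
      columns 5-8 : replicated on devices {2,3}

Before this change, scatter sharding propagation and SPMD partitioning dropped that trailing dimension and fell back to full replication; the hunks below thread it through instead.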


@@ -470,13 +470,19 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
   if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
     index_tile_assignment_dims.push_back(1);
   }
+  if (data_sharding.ReplicateOnLastTileDim()) {
+    index_tile_assignment_dims.push_back(
+        data_sharding.tile_assignment().dimensions().back());
+  }
   Array<int64> new_tile_assignment = data_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(index_tile_assignment_dims)) {
     return HloSharding::Replicate();
   }
   new_tile_assignment.Reshape(index_tile_assignment_dims);
-  return HloSharding::Tile(new_tile_assignment);
+  return data_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(new_tile_assignment)
+             : HloSharding::Tile(new_tile_assignment);
 }
 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
@@ -496,13 +502,19 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding,
       index_dim++;
     }
   }
+  if (index_sharding.ReplicateOnLastTileDim()) {
+    data_tile_assignment_dims.push_back(
+        index_sharding.tile_assignment().dimensions().back());
+  }
   Array<int64> new_tile_assignment = index_sharding.tile_assignment();
   if (new_tile_assignment.num_elements() !=
       Product(data_tile_assignment_dims)) {
     return HloSharding::Replicate();
   }
   new_tile_assignment.Reshape(data_tile_assignment_dims);
-  return HloSharding::Tile(new_tile_assignment);
+  return index_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(new_tile_assignment)
+             : HloSharding::Tile(new_tile_assignment);
 }
 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
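Both helpers now append the replication factor as one more tile dimension before the element-count check, and wrap the result in HloSharding::PartialTile instead of HloSharding::Tile. A walkthrough of ScatterIndexSharding on concrete numbers (mine, assuming the loop above the hunk produced index_tile_assignment_dims = {2}):

    data_sharding      : {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}  (4 devices)
    before this change : dims = {2},    Product = 2 != 4  ->  HloSharding::Replicate()
    after this change  : dims = {2, 2}, Product = 4 == 4  ->  Reshape to [2,2]
    result             : PartialTile -> {devices=[2,2]0,1,2,3 last_tile_dim_replicate}

Without the trailing replication dimension the element-count check can never pass for a partially replicated input, so the sharding silently degraded to full replication; ScatterDataSharding is the symmetric fix in the other direction.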


@@ -2025,6 +2025,38 @@ ENTRY entry {
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, DataOperandToScatter_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0),
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2), sharding={replicated}
%scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
ROOT %copy = f32[2,9] copy(%scatter)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "scatter"),
op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
}
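The scatter output has the same shape as the %input operand and the window dimensions line up one-to-one here, so the operand's sharding, including the trailing replication dimension, is expected to transfer to %scatter unchanged.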
TEST_F(ShardingPropagationTest, UpdateOperandToScatter) {
const char* const hlo_string = R"(
HloModule module
@@ -2056,6 +2088,70 @@ ENTRY entry {
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, UpdateOperandToScatter_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={replicated}
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2),
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
%scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
ROOT %copy = f32[2,9] copy(%scatter)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "scatter"),
op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest, ScatterToDataOperand_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%p0 = f32[2,9] parameter(0)
%input = f32[2,9] copy(%p0)
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2), sharding={replicated}
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1,
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "input"),
op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest, ScatterToDataOperand) {
const char* const hlo_string = R"(
HloModule module
@@ -2087,6 +2183,38 @@ ENTRY entry {
op::Sharding("{devices=[1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, ScatterToUpdateOperand_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0)
%indices = s32[3] parameter(1), sharding={replicated}
%p2 = f32[3,9] parameter(2)
%updates = f32[3,9] copy(%p2)
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1,
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "updates"),
op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) {
const char* const hlo_string = R"(
HloModule module
@@ -2149,6 +2277,38 @@ ENTRY entry {
op::Sharding("{devices=[2]0,1}"));
}
TEST_F(ShardingPropagationTest, ScatterUpdateToIndex_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={replicated}
%p1 = s32[3] parameter(1), sharding={replicated}
%indices = s32[3] copy(%p1)
%updates = f32[3,9] parameter(2),
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "indices"),
op::Sharding("{devices=[2,2]0,1,2,3 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) {
const char* const hlo_string = R"(
HloModule module
@@ -2180,6 +2340,38 @@ ENTRY entry {
op::Sharding("{devices=[2,1]0,1}"));
}
TEST_F(ShardingPropagationTest, ScatterIndexToUpdate_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0), sharding={replicated}
%indices = s32[3] parameter(1),
sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate}
%p2 = f32[3,9] parameter(2), sharding={replicated}
%updates = f32[3,9] copy(%p2)
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "updates"),
op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest, PartialShardingOnElementwise) {
const char* const hlo_string = R"(
HloModule module


@@ -1451,10 +1451,23 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
           update_dim_to_index_dim);
       CHECK(new_updates_sharding.has_value());
       updates = updates.Reshard(*new_updates_sharding);
+      // Update collective_ops_creator and partition_id for partial replicate.
+      auto collective_ops_creator = collective_ops_creator_;
+      auto partition_id = partition_id_;
+      if (indices.sharding().ReplicateOnLastTileDim()) {
+        auto sharding_grouped = GroupShardingOnDims(
+            indices.sharding(),
+            {indices.sharding().tile_assignment().num_dimensions() - 1});
+        auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+            indices.state(), sharding_grouped.device_groups, &b_);
+        collective_ops_creator =
+            per_group_partitioner_state.collective_ops_creator;
+        partition_id = per_group_partitioner_state.partition_id;
+      }
       // To avoid accumulating the initial operand multiple times during
       // all-reduce, we use identity operands for all non-zero partitions.
       auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert(
-          ShapeUtil::MakeScalarShape(PRED), partition_id_));
+          ShapeUtil::MakeScalarShape(PRED), partition_id));
       not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast(
           ShapeUtil::ChangeElementType(identity->shape(), PRED),
           not_partition_zero, {}));
@@ -1465,7 +1478,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
       auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands(
           scatter->shape(), {select_operand, indices.hlo(), updates.hlo()}));
       auto all_reduce =
-          collective_ops_creator_.create_cross_partition_all_reduce(
+          collective_ops_creator.create_cross_partition_all_reduce(
               &b_, pscatter, scatter->to_apply(), {}, NewChannel());
       all_reduce->set_sharding(HloSharding::Replicate());
       SetPartitionedHlo(hlo, [&]() {
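What the grouping buys (my reading of the hunk above): with partially replicated indices, only the partitions that hold different index shards should participate in one all-reduce, and exactly one partition per group may contribute the initial operand. In the IndexPassthroughScatter_PartialReplicate test below, indices are sharded {devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}; grouping on the trailing replication dimension yields

    replica index 0: devices {0, 2, 4, 6}, per-group partition ids 0..3
    replica index 1: devices {1, 3, 5, 7}, per-group partition ids 0..3

so the all-reduce created from the grouped collective_ops_creator runs within each group, and the per-group partition_id makes not_partition_zero pick the real operand on exactly one partition per group while the rest feed in the combiner's identity value.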


@@ -4070,6 +4070,39 @@ ENTRY entry {
op::Shape("f32[2,5]")));
}
TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9] parameter(0),
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2),
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1,
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
op::Parameter(2)),
op::Shape("f32[2,5]")));
}
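The root shape f32[2,5] is the per-device shard: 9 columns tiled two ways gives ceil(9/2) = 5, with the second shard padded. Since input, updates, and output all share the same column tiling and replication, the scatter stays fully local and no collectives are expected.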
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) {
const char* const hlo_string = R"(
HloModule module
@@ -4104,6 +4137,42 @@ ENTRY entry {
op::Shape("f32[2,9,8]")));
}
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[2,9,8] parameter(0), sharding={replicated}
%indices = s32[4,2,4] parameter(1),
sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
%updates = f32[4,4,8] parameter(2),
sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={2},
inserted_window_dims={0,1},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/8));
VLOG(1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root,
AllOf(op::AllReduce(op::Scatter(
op::Select(op::Broadcast(op::Convert(op::Reshape())),
op::Broadcast(op::Constant()), op::Parameter(0)),
op::Parameter(1), op::Parameter(2))),
op::Shape("f32[2,9,8]")));
}
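This matcher tree is the HandleScatter hunk above in miniature: op::Select(op::Broadcast(op::Convert(op::Reshape())), op::Broadcast(op::Constant()), op::Parameter(0)) is the not_partition_zero mask choosing between the combiner's identity (zeros, for add) and the real operand, and the enclosing op::AllReduce is the cross-partition all-reduce built from the grouped collective_ops_creator; the inner op::Reshape is, on my reading, the per-group partition id derived from the global one.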
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) {
const char* const hlo_string = R"(
HloModule module
@@ -4172,6 +4241,43 @@ ENTRY entry {
op::Shape("f32[9,9]")));
}
TEST_F(SpmdPartitioningTest,
ScatterPartitionedOnTrivialSliceDims_PartialReplicate) {
const char* const hlo_string = R"(
HloModule module
add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}
ENTRY entry {
%input = f32[17,9] parameter(0),
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
%indices = s32[2,3] parameter(1), sharding={replicated}
%updates = f32[2,3,9] parameter(2), sharding={replicated}
ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={2},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2,
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto offset =
op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
auto indices = op::Subtract(
op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)),
op::Shape("f32[9,9]")));
}
TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
const char* const hlo_string = R"(
HloModule module


@@ -362,8 +362,8 @@ absl::optional<HloInstruction*> PadFromPartialReplicateShape(
 // dimensions by dynamic slice.
 // For example, if partial_sharding is
 // {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
-// Target tile dims is {2, 2}, the returned compatible sharding will be
-// sharding={devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}.
+// Target sharding is {devices=[2,2]0,1,2,3}, the returned compatible sharding
+// will be sharding={devices=[2,2]0,2,1,3}.
 // If partial_sharding is not partial replicate or can't reshard to
 // target_tile_dims by dynamic slice, return absl::nullopt.
 // If target_sharding is already compatible, returns it.
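Tracing the comment's example (my walkthrough): the source sharding puts column shard 0 on replicas {0,1} and column shard 1 on {2,3}. A purely local dynamic-slice can only reach a target tile whose data the device already holds, which forces the permuted assignment:

    target tile (row, col) -> device
    (0, 0) -> 0   (0, 1) -> 2
    (1, 0) -> 1   (1, 1) -> 3

Reading that tile assignment in row-major order gives 0,2,1,3, i.e. {devices=[2,2]0,2,1,3}: the returned compatible sharding differs from the requested {devices=[2,2]0,1,2,3} only by a device permutation, which is exactly what makes the reshard slice-only.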