[XLA:SPMD] Fix wrong offsets of windowed dot (reduce-scatter case)

PiperOrigin-RevId: 341724756
Change-Id: Ied736cea8260e29dcebbbb1d79194e06ee324713
This commit is contained in:
Yuanzhong Xu 2020-11-10 16:44:04 -08:00 committed by TensorFlower Gardener
parent e07f05f816
commit 982a5a15b3
2 changed files with 26 additions and 49 deletions
tensorflow/compiler/xla/service/spmd

View File

@ -289,6 +289,12 @@ StatusOr<HloInstruction*> PartitionBaseCase(
to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(output_base_shape.element_type()))));
}
if (operands_sharded_at_contracting_dims) {
auto zero = b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(output_base_shape.element_type())));
lhs = lhs.PadWithValue(zero);
rhs = rhs.PadWithValue(zero);
}
auto result_buffer = CreateZero(padded_result_buffer_shape, b);
auto iteration = b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
@ -333,57 +339,28 @@ StatusOr<HloInstruction*> PartitionBaseCase(
if (windowed_at_contracting_dims || windowed_at_batch_dims ||
operands_sharded_at_contracting_dims) {
// Slice the matching operand according to the partitioned dimensions on
// the windowed operand.
// the windowed operand or the output.
auto slice_operand = matching_operand == 0 ? l : r;
HloInstruction* slice;
// We do this by treating the matching operand as replicated, and
// resharding it to match the windowed operand or the output.
slice_operand->set_sharding(HloSharding::Replicate());
auto state = lhs.state();
state.b = &body_b;
state.partition_id = data_partition_id;
const HloSharding* slice_sharding;
if (operands_sharded_at_contracting_dims) {
CHECK_NE(output_sharding_dim, -1);
int64 output_sharding_dim_size =
o->shape().dimensions(output_sharding_dim);
int64 slice_dim = matching_operand == 0
? output_to_lhs_indices[output_sharding_dim]
: output_to_rhs_indices[output_sharding_dim];
auto slice_shape = slice_operand->shape();
slice_shape.set_dimensions(slice_dim, output_sharding_dim_size);
std::vector<HloInstruction*> slice_offsets(slice_shape.rank());
for (int64 i = 0; i < slice_offsets.size(); ++i) {
if (i != slice_dim) {
slice_offsets[i] =
body_b.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<uint32>(0)));
} else {
auto stride = body_b.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<uint32>(output_sharding_dim_size)));
slice_offsets[i] =
body_b.AddInstruction(HloInstruction::CreateBinary(
data_partition_id->shape(), HloOpcode::kMultiply,
data_partition_id, stride));
}
}
auto padded_shape = slice_operand->shape();
padded_shape.set_dimensions(
slice_dim,
o->shape().dimensions(output_sharding_dim) * num_partitions);
auto padded_slice_operand =
PadToShape(slice_operand, padded_shape, &body_b);
slice = body_b.AddInstruction(HloInstruction::CreateDynamicSlice(
slice_shape, padded_slice_operand, slice_offsets,
slice_shape.dimensions()));
slice_sharding = windowing_operand == 0
? &*output_sharding_transposed_to_match_rhs
: &*output_sharding_transposed_to_match_lhs;
} else {
// For windowed operand that partitioned along contracting dimensions,
// we do this by treating the matching operand as replicated, and
// resharding it to match the windowed operand.
slice_operand->set_sharding(HloSharding::Replicate());
auto state = lhs.state();
state.b = &body_b;
state.partition_id = data_partition_id;
slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
.Reshard(windowing_operand == 0
? *lhs_sharding_transposed_to_match_rhs
: *rhs_sharding_transposed_to_match_lhs)
.hlo();
slice_operand->clear_sharding();
slice_sharding = windowing_operand == 0
? &*lhs_sharding_transposed_to_match_rhs
: &*rhs_sharding_transposed_to_match_lhs;
}
auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
.Reshard(*slice_sharding)
.hlo();
slice_operand->clear_sharding();
if (matching_operand == 0) {
dot_lhs = slice;
} else {

View File

@ -3818,7 +3818,7 @@ ENTRY entry {
auto ds =
AllOf(op::DynamicSlice(
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
op::Constant(), op::Multiply(), op::Constant(), op::Constant()),
op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
op::Shape("f32[320,7,16,128]"));
auto partial_output =
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
@ -3909,7 +3909,7 @@ ENTRY entry {
auto ds =
AllOf(op::DynamicSlice(
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
op::Constant(), op::Multiply(), op::Constant()),
op::Constant(), op::Reshape(), op::Constant()),
op::Shape("f32[4096,17,128]"));
auto partial_output =
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),