[XLA:SPMD] Fix wrong offsets of windowed dot (reduce-scatter case)
PiperOrigin-RevId: 341724756 Change-Id: Ied736cea8260e29dcebbbb1d79194e06ee324713
This commit is contained in:
parent
e07f05f816
commit
982a5a15b3
tensorflow/compiler/xla/service/spmd
@ -289,6 +289,12 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(output_base_shape.element_type()))));
|
||||
}
|
||||
if (operands_sharded_at_contracting_dims) {
|
||||
auto zero = b->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::Zero(output_base_shape.element_type())));
|
||||
lhs = lhs.PadWithValue(zero);
|
||||
rhs = rhs.PadWithValue(zero);
|
||||
}
|
||||
auto result_buffer = CreateZero(padded_result_buffer_shape, b);
|
||||
auto iteration = b->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
|
||||
@ -333,57 +339,28 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
if (windowed_at_contracting_dims || windowed_at_batch_dims ||
|
||||
operands_sharded_at_contracting_dims) {
|
||||
// Slice the matching operand according to the partitioned dimensions on
|
||||
// the windowed operand.
|
||||
// the windowed operand or the output.
|
||||
auto slice_operand = matching_operand == 0 ? l : r;
|
||||
HloInstruction* slice;
|
||||
// We do this by treating the matching operand as replicated, and
|
||||
// resharding it to match the windowed operand or the output.
|
||||
slice_operand->set_sharding(HloSharding::Replicate());
|
||||
auto state = lhs.state();
|
||||
state.b = &body_b;
|
||||
state.partition_id = data_partition_id;
|
||||
const HloSharding* slice_sharding;
|
||||
if (operands_sharded_at_contracting_dims) {
|
||||
CHECK_NE(output_sharding_dim, -1);
|
||||
int64 output_sharding_dim_size =
|
||||
o->shape().dimensions(output_sharding_dim);
|
||||
int64 slice_dim = matching_operand == 0
|
||||
? output_to_lhs_indices[output_sharding_dim]
|
||||
: output_to_rhs_indices[output_sharding_dim];
|
||||
auto slice_shape = slice_operand->shape();
|
||||
slice_shape.set_dimensions(slice_dim, output_sharding_dim_size);
|
||||
std::vector<HloInstruction*> slice_offsets(slice_shape.rank());
|
||||
for (int64 i = 0; i < slice_offsets.size(); ++i) {
|
||||
if (i != slice_dim) {
|
||||
slice_offsets[i] =
|
||||
body_b.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR0<uint32>(0)));
|
||||
} else {
|
||||
auto stride = body_b.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR0<uint32>(output_sharding_dim_size)));
|
||||
slice_offsets[i] =
|
||||
body_b.AddInstruction(HloInstruction::CreateBinary(
|
||||
data_partition_id->shape(), HloOpcode::kMultiply,
|
||||
data_partition_id, stride));
|
||||
}
|
||||
}
|
||||
auto padded_shape = slice_operand->shape();
|
||||
padded_shape.set_dimensions(
|
||||
slice_dim,
|
||||
o->shape().dimensions(output_sharding_dim) * num_partitions);
|
||||
auto padded_slice_operand =
|
||||
PadToShape(slice_operand, padded_shape, &body_b);
|
||||
slice = body_b.AddInstruction(HloInstruction::CreateDynamicSlice(
|
||||
slice_shape, padded_slice_operand, slice_offsets,
|
||||
slice_shape.dimensions()));
|
||||
slice_sharding = windowing_operand == 0
|
||||
? &*output_sharding_transposed_to_match_rhs
|
||||
: &*output_sharding_transposed_to_match_lhs;
|
||||
} else {
|
||||
// For windowed operand that partitioned along contracting dimensions,
|
||||
// we do this by treating the matching operand as replicated, and
|
||||
// resharding it to match the windowed operand.
|
||||
slice_operand->set_sharding(HloSharding::Replicate());
|
||||
auto state = lhs.state();
|
||||
state.b = &body_b;
|
||||
state.partition_id = data_partition_id;
|
||||
slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
|
||||
.Reshard(windowing_operand == 0
|
||||
? *lhs_sharding_transposed_to_match_rhs
|
||||
: *rhs_sharding_transposed_to_match_lhs)
|
||||
.hlo();
|
||||
slice_operand->clear_sharding();
|
||||
slice_sharding = windowing_operand == 0
|
||||
? &*lhs_sharding_transposed_to_match_rhs
|
||||
: &*rhs_sharding_transposed_to_match_lhs;
|
||||
}
|
||||
auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
|
||||
.Reshard(*slice_sharding)
|
||||
.hlo();
|
||||
slice_operand->clear_sharding();
|
||||
if (matching_operand == 0) {
|
||||
dot_lhs = slice;
|
||||
} else {
|
||||
|
@ -3818,7 +3818,7 @@ ENTRY entry {
|
||||
auto ds =
|
||||
AllOf(op::DynamicSlice(
|
||||
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
|
||||
op::Constant(), op::Multiply(), op::Constant(), op::Constant()),
|
||||
op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
|
||||
op::Shape("f32[320,7,16,128]"));
|
||||
auto partial_output =
|
||||
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
|
||||
@ -3909,7 +3909,7 @@ ENTRY entry {
|
||||
auto ds =
|
||||
AllOf(op::DynamicSlice(
|
||||
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
|
||||
op::Constant(), op::Multiply(), op::Constant()),
|
||||
op::Constant(), op::Reshape(), op::Constant()),
|
||||
op::Shape("f32[4096,17,128]"));
|
||||
auto partial_output =
|
||||
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
|
||||
|
Loading…
Reference in New Issue
Block a user