[XLA:SPMD] Fix wrong offsets of windowed dot (reduce-scatter case)
PiperOrigin-RevId: 341724756
Change-Id: Ied736cea8260e29dcebbbb1d79194e06ee324713
This commit is contained in:
parent: e07f05f816
commit: 982a5a15b3
@ -289,6 +289,12 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
|||||||
to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
|
to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
|
||||||
LiteralUtil::Zero(output_base_shape.element_type()))));
|
LiteralUtil::Zero(output_base_shape.element_type()))));
|
||||||
}
|
}
|
||||||
|
if (operands_sharded_at_contracting_dims) {
|
||||||
|
auto zero = b->AddInstruction(HloInstruction::CreateConstant(
|
||||||
|
LiteralUtil::Zero(output_base_shape.element_type())));
|
||||||
|
lhs = lhs.PadWithValue(zero);
|
||||||
|
rhs = rhs.PadWithValue(zero);
|
||||||
|
}
|
||||||
auto result_buffer = CreateZero(padded_result_buffer_shape, b);
|
auto result_buffer = CreateZero(padded_result_buffer_shape, b);
|
||||||
auto iteration = b->AddInstruction(
|
auto iteration = b->AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
|
||||||
@ -333,57 +339,28 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
|||||||
if (windowed_at_contracting_dims || windowed_at_batch_dims ||
|
if (windowed_at_contracting_dims || windowed_at_batch_dims ||
|
||||||
operands_sharded_at_contracting_dims) {
|
operands_sharded_at_contracting_dims) {
|
||||||
// Slice the matching operand according to the partitioned dimensions on
|
// Slice the matching operand according to the partitioned dimensions on
|
||||||
// the windowed operand.
|
// the windowed operand or the output.
|
||||||
auto slice_operand = matching_operand == 0 ? l : r;
|
auto slice_operand = matching_operand == 0 ? l : r;
|
||||||
HloInstruction* slice;
|
// We do this by treating the matching operand as replicated, and
|
||||||
|
// resharding it to match the windowed operand or the output.
|
||||||
|
slice_operand->set_sharding(HloSharding::Replicate());
|
||||||
|
auto state = lhs.state();
|
||||||
|
state.b = &body_b;
|
||||||
|
state.partition_id = data_partition_id;
|
||||||
|
const HloSharding* slice_sharding;
|
||||||
if (operands_sharded_at_contracting_dims) {
|
if (operands_sharded_at_contracting_dims) {
|
||||||
CHECK_NE(output_sharding_dim, -1);
|
slice_sharding = windowing_operand == 0
|
||||||
int64 output_sharding_dim_size =
|
? &*output_sharding_transposed_to_match_rhs
|
||||||
o->shape().dimensions(output_sharding_dim);
|
: &*output_sharding_transposed_to_match_lhs;
|
||||||
int64 slice_dim = matching_operand == 0
|
|
||||||
? output_to_lhs_indices[output_sharding_dim]
|
|
||||||
: output_to_rhs_indices[output_sharding_dim];
|
|
||||||
auto slice_shape = slice_operand->shape();
|
|
||||||
slice_shape.set_dimensions(slice_dim, output_sharding_dim_size);
|
|
||||||
std::vector<HloInstruction*> slice_offsets(slice_shape.rank());
|
|
||||||
for (int64 i = 0; i < slice_offsets.size(); ++i) {
|
|
||||||
if (i != slice_dim) {
|
|
||||||
slice_offsets[i] =
|
|
||||||
body_b.AddInstruction(HloInstruction::CreateConstant(
|
|
||||||
LiteralUtil::CreateR0<uint32>(0)));
|
|
||||||
} else {
|
|
||||||
auto stride = body_b.AddInstruction(HloInstruction::CreateConstant(
|
|
||||||
LiteralUtil::CreateR0<uint32>(output_sharding_dim_size)));
|
|
||||||
slice_offsets[i] =
|
|
||||||
body_b.AddInstruction(HloInstruction::CreateBinary(
|
|
||||||
data_partition_id->shape(), HloOpcode::kMultiply,
|
|
||||||
data_partition_id, stride));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto padded_shape = slice_operand->shape();
|
|
||||||
padded_shape.set_dimensions(
|
|
||||||
slice_dim,
|
|
||||||
o->shape().dimensions(output_sharding_dim) * num_partitions);
|
|
||||||
auto padded_slice_operand =
|
|
||||||
PadToShape(slice_operand, padded_shape, &body_b);
|
|
||||||
slice = body_b.AddInstruction(HloInstruction::CreateDynamicSlice(
|
|
||||||
slice_shape, padded_slice_operand, slice_offsets,
|
|
||||||
slice_shape.dimensions()));
|
|
||||||
} else {
|
} else {
|
||||||
// For windowed operand that partitioned along contracting dimensions,
|
slice_sharding = windowing_operand == 0
|
||||||
// we do this by treating the matching operand as replicated, and
|
? &*lhs_sharding_transposed_to_match_rhs
|
||||||
// resharding it to match the windowed operand.
|
: &*rhs_sharding_transposed_to_match_lhs;
|
||||||
slice_operand->set_sharding(HloSharding::Replicate());
|
|
||||||
auto state = lhs.state();
|
|
||||||
state.b = &body_b;
|
|
||||||
state.partition_id = data_partition_id;
|
|
||||||
slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
|
|
||||||
.Reshard(windowing_operand == 0
|
|
||||||
? *lhs_sharding_transposed_to_match_rhs
|
|
||||||
: *rhs_sharding_transposed_to_match_lhs)
|
|
||||||
.hlo();
|
|
||||||
slice_operand->clear_sharding();
|
|
||||||
}
|
}
|
||||||
|
auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
|
||||||
|
.Reshard(*slice_sharding)
|
||||||
|
.hlo();
|
||||||
|
slice_operand->clear_sharding();
|
||||||
if (matching_operand == 0) {
|
if (matching_operand == 0) {
|
||||||
dot_lhs = slice;
|
dot_lhs = slice;
|
||||||
} else {
|
} else {
|
||||||
|
@ -3818,7 +3818,7 @@ ENTRY entry {
|
|||||||
auto ds =
|
auto ds =
|
||||||
AllOf(op::DynamicSlice(
|
AllOf(op::DynamicSlice(
|
||||||
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
|
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
|
||||||
op::Constant(), op::Multiply(), op::Constant(), op::Constant()),
|
op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
|
||||||
op::Shape("f32[320,7,16,128]"));
|
op::Shape("f32[320,7,16,128]"));
|
||||||
auto partial_output =
|
auto partial_output =
|
||||||
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
|
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
|
||||||
@ -3909,7 +3909,7 @@ ENTRY entry {
|
|||||||
auto ds =
|
auto ds =
|
||||||
AllOf(op::DynamicSlice(
|
AllOf(op::DynamicSlice(
|
||||||
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
|
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
|
||||||
op::Constant(), op::Multiply(), op::Constant()),
|
op::Constant(), op::Reshape(), op::Constant()),
|
||||||
op::Shape("f32[4096,17,128]"));
|
op::Shape("f32[4096,17,128]"));
|
||||||
auto partial_output =
|
auto partial_output =
|
||||||
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
|
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user