diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index f765ee5ecc2..a346d8778d6 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -289,6 +289,12 @@ StatusOr<HloInstruction*> PartitionBaseCase(
           to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
               LiteralUtil::Zero(output_base_shape.element_type()))));
     }
+    if (operands_sharded_at_contracting_dims) {
+      auto zero = b->AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::Zero(output_base_shape.element_type())));
+      lhs = lhs.PadWithValue(zero);
+      rhs = rhs.PadWithValue(zero);
+    }
     auto result_buffer = CreateZero(padded_result_buffer_shape, b);
     auto iteration = b->AddInstruction(
         HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
@@ -333,57 +339,28 @@ StatusOr<HloInstruction*> PartitionBaseCase(
     if (windowed_at_contracting_dims || windowed_at_batch_dims ||
         operands_sharded_at_contracting_dims) {
       // Slice the matching operand according to the partitioned dimensions on
-      // the windowed operand.
+      // the windowed operand or the output.
       auto slice_operand = matching_operand == 0 ? l : r;
-      HloInstruction* slice;
+      // We do this by treating the matching operand as replicated, and
+      // resharding it to match the windowed operand or the output.
+      slice_operand->set_sharding(HloSharding::Replicate());
+      auto state = lhs.state();
+      state.b = &body_b;
+      state.partition_id = data_partition_id;
+      const HloSharding* slice_sharding;
       if (operands_sharded_at_contracting_dims) {
-        CHECK_NE(output_sharding_dim, -1);
-        int64 output_sharding_dim_size =
-            o->shape().dimensions(output_sharding_dim);
-        int64 slice_dim = matching_operand == 0
-                              ? output_to_lhs_indices[output_sharding_dim]
-                              : output_to_rhs_indices[output_sharding_dim];
-        auto slice_shape = slice_operand->shape();
-        slice_shape.set_dimensions(slice_dim, output_sharding_dim_size);
-        std::vector<HloInstruction*> slice_offsets(slice_shape.rank());
-        for (int64 i = 0; i < slice_offsets.size(); ++i) {
-          if (i != slice_dim) {
-            slice_offsets[i] =
-                body_b.AddInstruction(HloInstruction::CreateConstant(
-                    LiteralUtil::CreateR0<uint32>(0)));
-          } else {
-            auto stride = body_b.AddInstruction(HloInstruction::CreateConstant(
-                LiteralUtil::CreateR0<uint32>(output_sharding_dim_size)));
-            slice_offsets[i] =
-                body_b.AddInstruction(HloInstruction::CreateBinary(
-                    data_partition_id->shape(), HloOpcode::kMultiply,
-                    data_partition_id, stride));
-          }
-        }
-        auto padded_shape = slice_operand->shape();
-        padded_shape.set_dimensions(
-            slice_dim,
-            o->shape().dimensions(output_sharding_dim) * num_partitions);
-        auto padded_slice_operand =
-            PadToShape(slice_operand, padded_shape, &body_b);
-        slice = body_b.AddInstruction(HloInstruction::CreateDynamicSlice(
-            slice_shape, padded_slice_operand, slice_offsets,
-            slice_shape.dimensions()));
+        slice_sharding = windowing_operand == 0
+                             ? &*output_sharding_transposed_to_match_rhs
+                             : &*output_sharding_transposed_to_match_lhs;
       } else {
-        // For windowed operand that partitioned along contracting dimensions,
-        // we do this by treating the matching operand as replicated, and
-        // resharding it to match the windowed operand.
-        slice_operand->set_sharding(HloSharding::Replicate());
-        auto state = lhs.state();
-        state.b = &body_b;
-        state.partition_id = data_partition_id;
-        slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
-                    .Reshard(windowing_operand == 0
-                                 ? *lhs_sharding_transposed_to_match_rhs
-                                 : *rhs_sharding_transposed_to_match_lhs)
-                    .hlo();
-        slice_operand->clear_sharding();
+        slice_sharding = windowing_operand == 0
+                             ? &*lhs_sharding_transposed_to_match_rhs
+                             : &*rhs_sharding_transposed_to_match_lhs;
       }
+      auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
+                       .Reshard(*slice_sharding)
+                       .hlo();
+      slice_operand->clear_sharding();
       if (matching_operand == 0) {
         dot_lhs = slice;
       } else {
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index 91a0c44b51a..e4bd272e361 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -3818,7 +3818,7 @@ ENTRY entry {
   auto ds =
       AllOf(op::DynamicSlice(
                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
-                op::Constant(), op::Multiply(), op::Constant(), op::Constant()),
+                op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
             op::Shape("f32[320,7,16,128]"));
   auto partial_output =
       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
@@ -3909,7 +3909,7 @@ ENTRY entry {
   auto ds =
       AllOf(op::DynamicSlice(
                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
-                op::Constant(), op::Multiply(), op::Constant()),
+                op::Constant(), op::Reshape(), op::Constant()),
             op::Shape("f32[4096,17,128]"));
   auto partial_output =
       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),