[XLA:SPMD] Make dot base case less aggressive

So that we can prioritize recursive partial matches.

PiperOrigin-RevId: 327269736
Change-Id: I2d498d82bc3cea3eceb74e5e60a3d3d46e387054
This commit is contained in:
Yuanzhong Xu 2020-08-18 11:27:31 -07:00 committed by TensorFlower Gardener
parent 7f2bc1e4b8
commit b7dbaf4f23
2 changed files with 83 additions and 25 deletions

View File

@ -100,7 +100,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
int64 output_rhs_non_contracting_partitions,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
windowed_dot_general_loops,
bool may_reshard_without_detecting_match) {
const HloSharding& lhs_sharding = lhs.sharding();
const HloSharding& rhs_sharding = rhs.sharding();
if (lhs_sharding.ReplicateOnLastTileDim() ||
@ -491,29 +492,36 @@ StatusOr<HloInstruction*> PartitionBaseCase(
return dot;
}
// Output is batch partitioned.
if (output_batch_partitions == num_partitions) {
auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
resharded_rhs.hlo(), b));
return dot;
}
// Output is partitioned along LHS non-contracting dimensions.
if (output_lhs_non_contracting_partitions == num_partitions) {
auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
replicated_rhs.hlo(), b));
return dot;
}
// Output is partitioned along RHS non-contracting dimensions.
if (output_rhs_non_contracting_partitions == num_partitions) {
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
resharded_rhs.hlo(), b));
return dot;
if (may_reshard_without_detecting_match) {
// Output is batch partitioned.
if (output_batch_partitions == num_partitions) {
auto resharded_lhs =
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto resharded_rhs =
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
resharded_rhs.hlo(), b));
return dot;
}
// Output is partitioned along LHS non-contracting dimensions.
if (output_lhs_non_contracting_partitions == num_partitions) {
auto resharded_lhs =
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
TF_ASSIGN_OR_RETURN(
auto dot,
create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b));
return dot;
}
// Output is partitioned along RHS non-contracting dimensions.
if (output_rhs_non_contracting_partitions == num_partitions) {
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
auto resharded_rhs =
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
resharded_rhs.hlo(), b));
return dot;
}
}
// Returns true if it is beneficial to reshard the operand at `operand_idx`
@ -1155,6 +1163,8 @@ StatusOr<HloInstruction*> PartitionDot(
output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
// Before we find partial matches along the dimensions, invoke base case again
// without may_reshard_without_detecting_match.
TF_ASSIGN_OR_RETURN(
auto try_partitioned_dot,
PartitionBaseCase(
@ -1165,7 +1175,8 @@ StatusOr<HloInstruction*> PartitionDot(
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops,
/*may_reshard_without_detecting_match=*/false));
if (try_partitioned_dot) {
return try_partitioned_dot;
}
@ -1350,6 +1361,24 @@ StatusOr<HloInstruction*> PartitionDot(
return dot;
}
}
// We failed to find partial matches, invoke base case again with
// may_reshard_without_detecting_match.
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionBaseCase(
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, module, original_hlo,
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
lhs_contracting_partitions, rhs_contracting_partitions,
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops,
/*may_reshard_without_detecting_match=*/true));
if (dot) {
return dot;
}
return nullptr;
}

View File

@ -4730,6 +4730,35 @@ ENTRY entry {
EXPECT_THAT(root, op::AllReduce(op::AllReduce(dot)));
}
TEST_F(SpmdPartitioningTest, DotNonContractingPartialMatchContractingMatch) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[24,8,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
%rhs = f32[100,50] parameter(1), sharding={devices=[2,2]0,2,1,3}
ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
lhs_batch_dims={}, rhs_batch_dims={},
lhs_contracting_dims={2}, rhs_contracting_dims={0},
sharding={devices=[2,2,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto lhs = AllOf(op::Shape("f32[12,8,50]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[50,25]"), op::Parameter(1));
auto dot = AllOf(
op::Shape("f32[12,8,50]"),
op::Dot(lhs, AllOf(op::Shape("f32[50,50]"),
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)))));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[12,4,50]"),
op::DynamicSlice(op::AllReduce(dot), _, _, _)))
<< module->ToString();
}
TEST_F(SpmdPartitioningTest, DotLHSMutiNonContractingRHSNotMatch) {
const char* const hlo_string = R"(
HloModule module