[XLA:SPMD] Make dot base case less aggressive
So that we can prioritize recursive partial matches.

PiperOrigin-RevId: 327269736
Change-Id: I2d498d82bc3cea3eceb74e5e60a3d3d46e387054
This commit is contained in:
parent
7f2bc1e4b8
commit
b7dbaf4f23
@ -100,7 +100,8 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
int64 output_rhs_non_contracting_partitions,
|
||||
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
|
||||
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
|
||||
windowed_dot_general_loops) {
|
||||
windowed_dot_general_loops,
|
||||
bool may_reshard_without_detecting_match) {
|
||||
const HloSharding& lhs_sharding = lhs.sharding();
|
||||
const HloSharding& rhs_sharding = rhs.sharding();
|
||||
if (lhs_sharding.ReplicateOnLastTileDim() ||
|
||||
@ -491,29 +492,36 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
return dot;
|
||||
}
|
||||
|
||||
// Output is batch partitioned.
|
||||
if (output_batch_partitions == num_partitions) {
|
||||
auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
|
||||
auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
|
||||
resharded_rhs.hlo(), b));
|
||||
return dot;
|
||||
}
|
||||
// Output is partitioned along LHS non-contracting dimensions.
|
||||
if (output_lhs_non_contracting_partitions == num_partitions) {
|
||||
auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
|
||||
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
|
||||
replicated_rhs.hlo(), b));
|
||||
return dot;
|
||||
}
|
||||
// Output is partitioned along RHS non-contracting dimensions.
|
||||
if (output_rhs_non_contracting_partitions == num_partitions) {
|
||||
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
|
||||
auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
|
||||
resharded_rhs.hlo(), b));
|
||||
return dot;
|
||||
if (may_reshard_without_detecting_match) {
|
||||
// Output is batch partitioned.
|
||||
if (output_batch_partitions == num_partitions) {
|
||||
auto resharded_lhs =
|
||||
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
|
||||
auto resharded_rhs =
|
||||
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
|
||||
resharded_rhs.hlo(), b));
|
||||
return dot;
|
||||
}
|
||||
// Output is partitioned along LHS non-contracting dimensions.
|
||||
if (output_lhs_non_contracting_partitions == num_partitions) {
|
||||
auto resharded_lhs =
|
||||
lhs.Reshard(*output_sharding_transposed_to_match_lhs);
|
||||
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot,
|
||||
create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b));
|
||||
return dot;
|
||||
}
|
||||
// Output is partitioned along RHS non-contracting dimensions.
|
||||
if (output_rhs_non_contracting_partitions == num_partitions) {
|
||||
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
|
||||
auto resharded_rhs =
|
||||
rhs.Reshard(*output_sharding_transposed_to_match_rhs);
|
||||
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
|
||||
resharded_rhs.hlo(), b));
|
||||
return dot;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if it is beneficial to reshard the operand at `operand_idx`
|
||||
@ -1155,6 +1163,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
|
||||
const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
|
||||
output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
|
||||
// Before we find partial matches along the dimensions, invoke base case again
|
||||
// without may_reshard_without_detecting_match.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto try_partitioned_dot,
|
||||
PartitionBaseCase(
|
||||
@ -1165,7 +1175,8 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions,
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops,
|
||||
/*may_reshard_without_detecting_match=*/false));
|
||||
if (try_partitioned_dot) {
|
||||
return try_partitioned_dot;
|
||||
}
|
||||
@ -1350,6 +1361,24 @@ StatusOr<HloInstruction*> PartitionDot(
|
||||
return dot;
|
||||
}
|
||||
}
|
||||
|
||||
// We failed to find partial matches, invoke base case again with
|
||||
// may_reshard_without_detecting_match.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto dot,
|
||||
PartitionBaseCase(
|
||||
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
|
||||
num_partitions, create_sharded_dot, module, original_hlo,
|
||||
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
|
||||
lhs_contracting_partitions, rhs_contracting_partitions,
|
||||
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
|
||||
output_lhs_non_contracting_partitions,
|
||||
output_rhs_non_contracting_partitions,
|
||||
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops,
|
||||
/*may_reshard_without_detecting_match=*/true));
|
||||
if (dot) {
|
||||
return dot;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -4730,6 +4730,35 @@ ENTRY entry {
|
||||
EXPECT_THAT(root, op::AllReduce(op::AllReduce(dot)));
|
||||
}
|
||||
|
||||
// Checks the partitioned graph for a dot with no batch dimensions that
// contracts LHS dim 2 with RHS dim 0, partitioned over 4 devices. In the HLO
// below the LHS is tiled [2,1,2], the RHS is tiled [2,2] with device order
// 0,2,1,3, and the output is tiled [2,2,1].
TEST_F(SpmdPartitioningTest, DotNonContractingPartialMatchContractingMatch) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
%lhs = f32[24,8,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
|
||||
%rhs = f32[100,50] parameter(1), sharding={devices=[2,2]0,2,1,3}
|
||||
ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
|
||||
lhs_batch_dims={}, rhs_batch_dims={},
|
||||
lhs_contracting_dims={2}, rhs_contracting_dims={0},
|
||||
sharding={devices=[2,2,1]0,1,2,3}
|
||||
})";
|
||||
|
||||
// Run the partitioner on the module above; the test aborts if it fails.
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/4));
|
||||
VLOG(1) << module->ToString();
|
||||
|
||||
// Matchers for the per-partition parameters: f32[12,8,50] for parameter 0
// and f32[50,25] for parameter 1.
auto lhs = AllOf(op::Shape("f32[12,8,50]"), op::Parameter(0));
|
||||
auto rhs = AllOf(op::Shape("f32[50,25]"), op::Parameter(1));
|
||||
// The dot is expected to take the LHS parameter directly and an f32[50,50]
// RHS operand produced by an AllReduce over a DynamicUpdateSlice of the RHS
// parameter.
// NOTE(review): presumably this AllReduce(DynamicUpdateSlice) pattern is how
// the partitioner gathers the RHS along its non-contracting dim -- confirm.
auto dot = AllOf(
|
||||
op::Shape("f32[12,8,50]"),
|
||||
op::Dot(lhs, AllOf(op::Shape("f32[50,50]"),
|
||||
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)))));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
// The root must be an f32[12,4,50] DynamicSlice of an AllReduce of the dot;
// the module text is appended to the failure message.
EXPECT_THAT(root, AllOf(op::Shape("f32[12,4,50]"),
|
||||
op::DynamicSlice(op::AllReduce(dot), _, _, _)))
|
||||
<< module->ToString();
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, DotLHSMutiNonContractingRHSNotMatch) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
Loading…
x
Reference in New Issue
Block a user