Mask operands for unevenly partitioned contracting dims.
PiperOrigin-RevId: 340997388
Change-Id: I43ed530c3c3ef077e190dddcc832632bf024a71c
parent 18ec783484
commit 63446fb3b5
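
When a contracting dimension does not divide evenly across partitions, the last shard carries padded elements that would otherwise feed garbage into the dot's reduction; masking them with zero (the identity of the add-reduce) keeps the partial sums exact. A minimal standalone sketch of that argument in plain C++ (not XLA code; the sizes match the test below):

#include <cstdio>
#include <vector>

// A contracting dimension of size 23 split over 2 partitions is padded to
// 2 * 12 = 24 elements. Each pad slot contributes lhs[i] * rhs[i] to the
// reduction, so it must hold zero or the result is corrupted.
int main() {
  const int valid = 23, padded = 24;
  std::vector<float> lhs(padded, 1.0f), rhs(padded, 2.0f);
  lhs[valid] = 42.0f;  // simulate garbage in the pad slot
  rhs[valid] = -7.0f;
  float corrupted = 0.0f;
  for (int i = 0; i < padded; ++i) corrupted += lhs[i] * rhs[i];
  for (int i = valid; i < padded; ++i) lhs[i] = rhs[i] = 0.0f;  // the mask
  float masked = 0.0f;
  for (int i = 0; i < padded; ++i) masked += lhs[i] * rhs[i];
  // corrupted = -248, masked = 46, exact = 23 * 1 * 2 = 46.
  std::printf("corrupted=%g masked=%g exact=%g\n", corrupted, masked, 23 * 2.0f);
  return 0;
}
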
@@ -1148,6 +1148,27 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
     }
     lhs = lhs.Reshard(lhs_sharding);
   }
+  // Mask out invalid data.
+  std::vector<int64> lhs_skipped_dims;
+  for (int64 i = 0; i < lhs.base_shape().rank(); ++i) {
+    if (absl::c_linear_search(lhs_dims, i)) {
+      continue;
+    }
+    lhs_skipped_dims.push_back(i);
+  }
+  lhs = lhs.PadWithValue(
+      CreateZero(ShapeUtil::MakeShape(lhs.base_shape().element_type(), {}), b),
+      /*left_padded_dims=*/{}, lhs_skipped_dims);
+  std::vector<int64> rhs_skipped_dims;
+  for (int64 i = 0; i < rhs.base_shape().rank(); ++i) {
+    if (absl::c_linear_search(rhs_dims, i)) {
+      continue;
+    }
+    rhs_skipped_dims.push_back(i);
+  }
+  rhs = rhs.PadWithValue(
+      CreateZero(ShapeUtil::MakeShape(rhs.base_shape().element_type(), {}), b),
+      /*left_padded_dims=*/{}, rhs_skipped_dims);
   top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
   lhs.hlo()->set_sharding(lhs_grouped.sharding);
   top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
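
In the hunk above, skipped_dims is the complement of the grouped contracting dims, so PadWithValue masks only the contracting dimensions and leaves the rest untouched. A sketch of that complement computation outside XLA (std::find standing in for absl::c_linear_search):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t rank = 2;                  // e.g. lhs = f32[23,24]
  const std::vector<int64_t> lhs_dims{0};  // grouped contracting dim(s)
  std::vector<int64_t> skipped;
  for (int64_t i = 0; i < rank; ++i) {
    // Keep contracting dims out of skipped, so only they get masked.
    if (std::find(lhs_dims.begin(), lhs_dims.end(), i) == lhs_dims.end()) {
      skipped.push_back(i);
    }
  }
  for (int64_t d : skipped) std::printf("skipped dim %lld\n", (long long)d);
  return 0;  // prints "skipped dim 1"
}
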
@@ -463,7 +463,8 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
 }
 
 PartitionedHlo PartitionedHlo::PadWithValue(
-    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims) const {
+    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
+    absl::Span<const int64> skipped_dims) const {
   const HloSharding& sharding = hlo_->sharding();
   const Shape& shape = hlo_->shape();
   CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
@@ -502,7 +503,8 @@ PartitionedHlo PartitionedHlo::PadWithValue(
   auto offsets = MakePartitionOffsets(base_shape_, sharding,
                                       state_.partition_id, state_.b);
   for (int64 i = 0; i < shape.rank(); ++i) {
-    if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0) {
+    if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
+        absl::c_linear_search(skipped_dims, i)) {
       continue;
     }
     if (mask == nullptr) {
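
The loop above builds a validity mask for every unevenly partitioned, non-skipped dim. A plain-C++ sketch of the test that mask encodes, assuming evenly strided shard offsets (which is what MakePartitionOffsets supplies): an element is valid iff its global index, offset plus local index, is below the unpadded dimension size.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t base = 23, tiles = 2;
  const int64_t per_shard = (base + tiles - 1) / tiles;  // ceil(23/2) = 12
  for (int64_t p = 0; p < tiles; ++p) {
    const int64_t offset = p * per_shard;  // shard start along this dim
    for (int64_t j = 0; j < per_shard; ++j) {
      if (offset + j >= base) {
        // Only partition 1, local element 11 is masked for base = 23.
        std::printf("partition %lld: mask local element %lld\n",
                    (long long)p, (long long)j);
      }
    }
  }
  return 0;
}
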
@@ -283,9 +283,9 @@ class PartitionedHlo {
   // unevenly partitioned dimensions are padded on the right, but this function
   // allows specifying left-padded dimensions, which can be used during the
   // handling of kReverse, etc.
-  PartitionedHlo PadWithValue(
-      HloInstruction* pad_value,
-      absl::Span<const int64> left_padded_dims = {}) const;
+  PartitionedHlo PadWithValue(HloInstruction* pad_value,
+                              absl::Span<const int64> left_padded_dims = {},
+                              absl::Span<const int64> skipped_dims = {}) const;
 
   // Returns the SPMD instruction.
   HloInstruction* hlo() const { return hlo_; }
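
The left_padded_dims comment in this hunk can be made concrete: shards are normally right-padded, but reversing a dimension moves the pad region to the left, which is why kReverse handling must name its left-padded dims. A standalone plain-C++ sketch:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shard = {1, 2, 3, 0};  // right-padded shard; 0 is the pad
  std::reverse(shard.begin(), shard.end());
  for (int v : shard) std::printf("%d ", v);  // 0 3 2 1: pad is now leftmost
  std::printf("\n");
  return 0;
}
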
@@ -5003,6 +5003,37 @@ ENTRY entry {
       op::Dot(lhs_slice, partial_replicated_rhs)));
 }
 
+TEST_F(SpmdPartitioningTest, Dot2DPartitionedNoncontractingAndContracting3) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %lhs = f32[23,24] parameter(0), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
+  %rhs = f32[23,32] parameter(1), sharding={devices=[2,2]0,1,2,3}
+  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
+    lhs_contracting_dims={0}, rhs_contracting_dims={0},
+    sharding={devices=[2,2]1,0,3,2}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/4));
+  VLOG(1) << module->ToString();
+
+  auto lhs = AllOf(op::Shape("f32[12,24]"), op::Parameter(0));
+  auto masked_lhs = op::Select(_, lhs, op::Broadcast(op::Constant()));
+  auto rhs = AllOf(op::Shape("f32[12,16]"), op::Parameter(1));
+  auto masked_rhs = op::Select(_, rhs, op::Broadcast(op::Constant()));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      AllOf(op::Shape("f32[12,16]"),
+            op::DynamicSlice(
+                AllOf(op::Shape("f32[24,16]"),
+                      op::AllReduce(op::Dot(
+                          masked_lhs, op::CollectivePermute(masked_rhs)))),
+                _, _)));
+}
+
 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
   const char* const hlo_string = R"(
 HloModule module
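
The shapes this test expects follow from ceil-division of each tiled dimension, and the op::Select(_, ..., op::Broadcast(op::Constant())) matchers are exactly the zero masks this change inserts: 23 does not divide by 2, so each operand's contracting shard of 12 carries one padded row. A quick check of the arithmetic (plain C++):

#include <cstdio>

int main() {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  // lhs f32[23,24], tiled [2,1]: contracting dim 23 -> 12, one padded row.
  std::printf("lhs shard: f32[%d,%d]\n", ceil_div(23, 2), 24);
  // rhs f32[23,32], tiled [2,2]: f32[12,16], also one padded row.
  std::printf("rhs shard: f32[%d,%d]\n", ceil_div(23, 2), ceil_div(32, 2));
  // dot f32[24,32], tiled [2,2]: f32[12,16] after the dynamic-slice.
  std::printf("dot shard: f32[%d,%d]\n", ceil_div(24, 2), ceil_div(32, 2));
  return 0;
}
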
@@ -82,6 +82,9 @@ HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
   }
   auto zero = b->AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
+  if (shape.rank() == 0) {
+    return zero;
+  }
   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
 }
 
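
This guard ties into the masking above: the dot-partitioning hunk asks CreateZero for a scalar via ShapeUtil::MakeShape(element_type, {}), and for a rank-0 target there is nothing to broadcast, so the constant itself is returned. A sketch of the shape condition using hypothetical plain-C++ stand-ins (not XLA types):

#include <cstdio>
#include <vector>

// Hypothetical stand-in for xla::Shape, just enough to show the rank test.
struct Shape {
  std::vector<int> dims;
  int rank() const { return static_cast<int>(dims.size()); }
};

// How many HLO instructions a zero of this shape costs: the constant alone
// for rank 0, constant + broadcast otherwise.
int zero_cost(const Shape& s) { return s.rank() == 0 ? 1 : 2; }

int main() {
  std::printf("scalar zero: %d instruction(s)\n", zero_cost(Shape{{}}));
  std::printf("f32[23,24] zero: %d instruction(s)\n", zero_cost(Shape{{23, 24}}));
  return 0;
}
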