Mask operands for unevenly partitioned contracting dims.

PiperOrigin-RevId: 340997388
Change-Id: I43ed530c3c3ef077e190dddcc832632bf024a71c
This commit is contained in:
Yuanzhong Xu 2020-11-05 23:40:07 -08:00 committed by TensorFlower Gardener
parent 18ec783484
commit 63446fb3b5
5 changed files with 62 additions and 5 deletions

View File

@ -1148,6 +1148,27 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
}
lhs = lhs.Reshard(lhs_sharding);
}
// Mask out invalid data.
std::vector<int64> lhs_skipped_dims;
for (int64 i = 0; i < lhs.base_shape().rank(); ++i) {
if (absl::c_linear_search(lhs_dims, i)) {
continue;
}
lhs_skipped_dims.push_back(i);
}
lhs = lhs.PadWithValue(
CreateZero(ShapeUtil::MakeShape(lhs.base_shape().element_type(), {}), b),
/*left_padded_dims=*/{}, lhs_skipped_dims);
std::vector<int64> rhs_skipped_dims;
for (int64 i = 0; i < rhs.base_shape().rank(); ++i) {
if (absl::c_linear_search(rhs_dims, i)) {
continue;
}
rhs_skipped_dims.push_back(i);
}
rhs = rhs.PadWithValue(
CreateZero(ShapeUtil::MakeShape(rhs.base_shape().element_type(), {}), b),
/*left_padded_dims=*/{}, rhs_skipped_dims);
top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
lhs.hlo()->set_sharding(lhs_grouped.sharding);
top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);

View File

@ -463,7 +463,8 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
}
PartitionedHlo PartitionedHlo::PadWithValue(
HloInstruction* pad_value, absl::Span<const int64> left_padded_dims) const {
HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
absl::Span<const int64> skipped_dims) const {
const HloSharding& sharding = hlo_->sharding();
const Shape& shape = hlo_->shape();
CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
@ -502,7 +503,8 @@ PartitionedHlo PartitionedHlo::PadWithValue(
auto offsets = MakePartitionOffsets(base_shape_, sharding,
state_.partition_id, state_.b);
for (int64 i = 0; i < shape.rank(); ++i) {
if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0) {
if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
absl::c_linear_search(skipped_dims, i)) {
continue;
}
if (mask == nullptr) {

View File

@ -283,9 +283,9 @@ class PartitionedHlo {
// unevenly partitioned dimensions are padded on the right, but this function
// allows specifying left-padded dimensions, which can be used during the
// handling of kReverse, etc. Dimensions listed in skipped_dims are excluded
// from the masking.
PartitionedHlo PadWithValue(
HloInstruction* pad_value,
absl::Span<const int64> left_padded_dims = {}) const;
PartitionedHlo PadWithValue(HloInstruction* pad_value,
absl::Span<const int64> left_padded_dims = {},
absl::Span<const int64> skipped_dims = {}) const;
// Returns the SPMD instruction.
HloInstruction* hlo() const { return hlo_; }

View File

@ -5003,6 +5003,37 @@ ENTRY entry {
op::Dot(lhs_slice, partial_replicated_rhs)));
}
// Checks dot partitioning when the contracting dimension (size 23) does not
// divide evenly by the 2 partitions it is sharded across: both operands are
// expected to be wrapped in a Select against a broadcast constant before they
// feed the partial Dot.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNoncontractingAndContracting3) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[23,24] parameter(0), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
%rhs = f32[23,32] parameter(1), sharding={devices=[2,2]0,1,2,3}
ROOT %dot = f32[24,32] dot(%lhs, %rhs),
lhs_contracting_dims={0}, rhs_contracting_dims={0},
sharding={devices=[2,2]1,0,3,2}
})";
// Partition the module for 4 devices and dump the result at VLOG level 1.
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
// Per-partition parameter shards: dimension 0 (23) becomes 12 on each of the
// 2 partitions, and rhs dimension 1 (32) becomes 16. Each shard is matched as
// the "true" operand of a Select whose "false" operand is a broadcast
// constant; the predicate itself is not inspected.
auto lhs = AllOf(op::Shape("f32[12,24]"), op::Parameter(0));
auto masked_lhs = op::Select(_, lhs, op::Broadcast(op::Constant()));
auto rhs = AllOf(op::Shape("f32[12,16]"), op::Parameter(1));
auto masked_rhs = op::Select(_, rhs, op::Broadcast(op::Constant()));
auto root = module->entry_computation()->root_instruction();
// Expected root: a f32[12,16] DynamicSlice of the f32[24,16] AllReduce over
// the partial Dot, where the masked rhs passes through a CollectivePermute
// before entering the Dot.
EXPECT_THAT(
root,
AllOf(op::Shape("f32[12,16]"),
op::DynamicSlice(
AllOf(op::Shape("f32[24,16]"),
op::AllReduce(op::Dot(
masked_lhs, op::CollectivePermute(masked_rhs)))),
_, _)));
}
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
const char* const hlo_string = R"(
HloModule module

View File

@ -82,6 +82,9 @@ HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
}
auto zero = b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
if (shape.rank() == 0) {
return zero;
}
return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
}