[XLA] Correct typo in PartitionDotGroupOnContracting. NFC

PiperOrigin-RevId: 359214134
Change-Id: I013fef9a443b6d1db2108d0d37b03567296a251e
This commit is contained in:
Marcello Maggioni 2021-02-23 22:43:18 -08:00 committed by TensorFlower Gardener
parent 089d2e1ba9
commit dcdc6b2f90

View File

@ -2075,7 +2075,7 @@ StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
PartitionedHlo lhs, PartitionedHlo rhs,
absl::Span<const DotConvDimsMapping::DimsMapping>
partitioned_contractin_dims,
partitioned_contracting_dims,
int64 output_batch_partitions, int64 output_lhs_non_contracting_partitions,
int64 output_rhs_non_contracting_partitions, const Shape& output_base_shape,
const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
@ -2102,14 +2102,14 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
std::vector<int64> lhs_dims;
std::vector<int64> rhs_dims;
int64 group_count = 1;
for (const auto& dim : partitioned_contractin_dims) {
for (const auto& dim : partitioned_contracting_dims) {
lhs_dims.push_back(dim.lhs);
rhs_dims.push_back(dim.rhs);
group_count *= lhs_sharding.tile_assignment().dim(dim.lhs);
}
if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
for (const auto& dim : partitioned_contractin_dims) {
for (const auto& dim : partitioned_contracting_dims) {
rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs];
}
auto new_tile = rhs.sharding().tile_assignment();
@ -2118,7 +2118,7 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
? HloSharding::PartialTile(new_tile)
: HloSharding::Tile(new_tile);
} else {
for (const auto& dim : partitioned_contractin_dims) {
for (const auto& dim : partitioned_contracting_dims) {
lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs];
}
auto new_tile = lhs.sharding().tile_assignment();