[XLA] Fix trivial partial sharding

PiperOrigin-RevId: 334505181
Change-Id: I1c29b5cfc25aebc3ec85e945ae1f19d7d3382886
This commit is contained in:
Yuanzhong Xu 2020-09-29 18:29:51 -07:00 committed by TensorFlower Gardener
parent 92a328a473
commit 140275975d

View File

@ -56,6 +56,13 @@ HloSharding HloSharding::PartialTile(
HloSharding HloSharding::PartialTile(
const Array<int64>& tile_assignment_last_dim_replicate) {
if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions();
new_tile_dims.pop_back();
auto fully_tiled = tile_assignment_last_dim_replicate;
fully_tiled.Reshape(new_tile_dims);
return HloSharding(fully_tiled);
}
std::vector<std::set<int64>> sorted_groups(
tile_assignment_last_dim_replicate.num_elements() /
tile_assignment_last_dim_replicate.dimensions().back());