[XLA] Fix trivial partial sharding
PiperOrigin-RevId: 334505181 Change-Id: I1c29b5cfc25aebc3ec85e945ae1f19d7d3382886
This commit is contained in:
parent
92a328a473
commit
140275975d
@ -56,6 +56,13 @@ HloSharding HloSharding::PartialTile(
|
||||
|
||||
HloSharding HloSharding::PartialTile(
|
||||
const Array<int64>& tile_assignment_last_dim_replicate) {
|
||||
if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
|
||||
auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions();
|
||||
new_tile_dims.pop_back();
|
||||
auto fully_tiled = tile_assignment_last_dim_replicate;
|
||||
fully_tiled.Reshape(new_tile_dims);
|
||||
return HloSharding(fully_tiled);
|
||||
}
|
||||
std::vector<std::set<int64>> sorted_groups(
|
||||
tile_assignment_last_dim_replicate.num_elements() /
|
||||
tile_assignment_last_dim_replicate.dimensions().back());
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user