diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 9701c343288..b0a03707efb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -199,10 +199,12 @@ std::vector HloSharding::TileLimitForDevice(const Shape& shape, } int64 HloSharding::RequiredLeaves(const Shape& shape) { - // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are - // concerned, but they do have a single tuple_elements_ entry since we want - // to allow empty tuple results to have sharding. - return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape); + // Empty tuples (with arbitrary nesting) have no leaf nodes as far as + // ShapeUtil and ShapeTree are concerned, but they do have a single + // tuple_elements_ entry since we want to allow empty tuple results to + // have sharding. + const int64 leaf_count = ShapeUtil::GetLeafCount(shape); + return (leaf_count == 0) ? 1 : leaf_count; } Status HloSharding::CheckLeafCount(const Shape& shape) const {