Merge pull request #46942 from ktaebum:shape-fix

PiperOrigin-RevId: 356388605
Change-Id: If0de8b1c1ada5781b57bb7098797eb04bce37b52
This commit is contained in:
TensorFlower Gardener 2021-02-08 17:24:42 -08:00
commit 063eb2465f

View File

@ -2123,12 +2123,13 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
Raises:
ValueError: If `num` is unspecified and cannot be inferred.
"""
size_splits = ops.convert_to_tensor(num_or_size_splits)
if isinstance(num_or_size_splits,
(numbers.Integral, tensor_shape.Dimension)):
return gen_array_ops.split(
axis=axis, num_split=num_or_size_splits, value=value, name=name)
size_splits = ops.convert_to_tensor(num_or_size_splits)
if size_splits._rank() == 0:
raise ValueError(
"Rank-0 tensors are not supported as the num_or_size_splits argument "