Merge pull request #46942 from ktaebum:shape-fix
PiperOrigin-RevId: 356388605 Change-Id: If0de8b1c1ada5781b57bb7098797eb04bce37b52
This commit is contained in:
commit
063eb2465f
@ -2123,12 +2123,13 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
|
||||
Raises:
|
||||
ValueError: If `num` is unspecified and cannot be inferred.
|
||||
"""
|
||||
size_splits = ops.convert_to_tensor(num_or_size_splits)
|
||||
if isinstance(num_or_size_splits,
|
||||
(numbers.Integral, tensor_shape.Dimension)):
|
||||
return gen_array_ops.split(
|
||||
axis=axis, num_split=num_or_size_splits, value=value, name=name)
|
||||
|
||||
size_splits = ops.convert_to_tensor(num_or_size_splits)
|
||||
|
||||
if size_splits._rank() == 0:
|
||||
raise ValueError(
|
||||
"Rank-0 tensors are not supported as the num_or_size_splits argument "
|
||||
|
Loading…
Reference in New Issue
Block a user