reorder 'size_splits' in array_ops.split to prevent creating redundant tensor in graph mode
This commit is contained in:
parent
8db7298d14
commit
bdeedf139d
@ -2123,12 +2123,13 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
|
||||
Raises:
|
||||
ValueError: If `num` is unspecified and cannot be inferred.
|
||||
"""
|
||||
size_splits = ops.convert_to_tensor(num_or_size_splits)
|
||||
if isinstance(num_or_size_splits,
|
||||
(numbers.Integral, tensor_shape.Dimension)):
|
||||
return gen_array_ops.split(
|
||||
axis=axis, num_split=num_or_size_splits, value=value, name=name)
|
||||
|
||||
size_splits = ops.convert_to_tensor(num_or_size_splits)
|
||||
|
||||
if size_splits._rank() == 0:
|
||||
raise ValueError(
|
||||
"Rank-0 tensors are not supported as the num_or_size_splits argument "
|
||||
|
Loading…
Reference in New Issue
Block a user