reorder 'size_splits' in array_ops.split to prevent creating redundant tensor in graph mode
This commit is contained in:
parent
8db7298d14
commit
bdeedf139d
@ -2123,12 +2123,13 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If `num` is unspecified and cannot be inferred.
|
ValueError: If `num` is unspecified and cannot be inferred.
|
||||||
"""
|
"""
|
||||||
size_splits = ops.convert_to_tensor(num_or_size_splits)
|
|
||||||
if isinstance(num_or_size_splits,
|
if isinstance(num_or_size_splits,
|
||||||
(numbers.Integral, tensor_shape.Dimension)):
|
(numbers.Integral, tensor_shape.Dimension)):
|
||||||
return gen_array_ops.split(
|
return gen_array_ops.split(
|
||||||
axis=axis, num_split=num_or_size_splits, value=value, name=name)
|
axis=axis, num_split=num_or_size_splits, value=value, name=name)
|
||||||
|
|
||||||
|
size_splits = ops.convert_to_tensor(num_or_size_splits)
|
||||||
|
|
||||||
if size_splits._rank() == 0:
|
if size_splits._rank() == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Rank-0 tensors are not supported as the num_or_size_splits argument "
|
"Rank-0 tensors are not supported as the num_or_size_splits argument "
|
||||||
|
Loading…
Reference in New Issue
Block a user