reorder 'size_splits' in array_ops.split to prevent creating redundant tensor in graph mode

This commit is contained in:
Taebum Kim 2021-02-05 13:35:40 +09:00
parent 8db7298d14
commit bdeedf139d

View File

@ -2123,12 +2123,13 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
Raises:
ValueError: If `num` is unspecified and cannot be inferred.
"""
size_splits = ops.convert_to_tensor(num_or_size_splits)
if isinstance(num_or_size_splits,
(numbers.Integral, tensor_shape.Dimension)):
return gen_array_ops.split(
axis=axis, num_split=num_or_size_splits, value=value, name=name)
size_splits = ops.convert_to_tensor(num_or_size_splits)
if size_splits._rank() == 0:
raise ValueError(
"Rank-0 tensors are not supported as the num_or_size_splits argument "