From bdeedf139de7d4e30f2d5f3fa0a7fa813419e7d3 Mon Sep 17 00:00:00 2001 From: Taebum Kim <k.taebum@snu.ac.kr> Date: Fri, 5 Feb 2021 13:35:40 +0900 Subject: [PATCH] reorder 'size_splits' in array_ops.split to prevent creating redundant tensor in graph mode --- tensorflow/python/ops/array_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index f0f97f6a054..f7fc3239714 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2123,12 +2123,13 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"): Raises: ValueError: If `num` is unspecified and cannot be inferred. """ - size_splits = ops.convert_to_tensor(num_or_size_splits) if isinstance(num_or_size_splits, (numbers.Integral, tensor_shape.Dimension)): return gen_array_ops.split( axis=axis, num_split=num_or_size_splits, value=value, name=name) + size_splits = ops.convert_to_tensor(num_or_size_splits) + if size_splits._rank() == 0: raise ValueError( "Rank-0 tensors are not supported as the num_or_size_splits argument "