From f74a0436426ee50805a9a204d0c1376db30707bf Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Mon, 5 Nov 2018 10:36:54 -0800 Subject: [PATCH] Modify to return an error specifying an axis argument must be given instead of an error due to being unable to convert a null tensor. PiperOrigin-RevId: 220129411 --- tensorflow/python/ops/array_ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 232a77c8888..6fdc50733a1 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -120,7 +120,7 @@ def expand_dims(input, axis=None, name=None, dim=None): axis: 0-D (scalar). Specifies the dimension index at which to expand the shape of `input`. Must be in the range `[-rank(input) - 1, rank(input)]`. - name: The name of the output `Tensor`. + name: The name of the output `Tensor` (optional). dim: 0-D (scalar). Equivalent to `axis`, to be deprecated. Returns: @@ -128,9 +128,11 @@ def expand_dims(input, axis=None, name=None, dim=None): dimension of size 1 added. Raises: - ValueError: if both `dim` and `axis` are specified. + ValueError: if either both or neither of `dim` and `axis` are specified. """ axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim) + if axis is None: + raise ValueError("Must specify an axis argument to tf.expand_dims()") return gen_array_ops.expand_dims(input, axis, name)