Modify to return an error specifying an axis argument must be given instead of an error due to being unable to convert a null tensor.

PiperOrigin-RevId: 220129411
This commit is contained in:
Tamara Norman 2018-11-05 10:36:54 -08:00 committed by TensorFlower Gardener
parent 7761c55477
commit f74a043642

View File

@ -120,7 +120,7 @@ def expand_dims(input, axis=None, name=None, dim=None):
axis: 0-D (scalar). Specifies the dimension index at which to
expand the shape of `input`. Must be in the range
`[-rank(input) - 1, rank(input)]`.
name: The name of the output `Tensor`.
name: The name of the output `Tensor` (optional).
dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.
Returns:
@ -128,9 +128,11 @@ def expand_dims(input, axis=None, name=None, dim=None):
dimension of size 1 added.
Raises:
ValueError: if both `dim` and `axis` are specified.
ValueError: if either both or neither of `dim` and `axis` are specified.
"""
axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
if axis is None:
raise ValueError("Must specify an axis argument to tf.expand_dims()")
return gen_array_ops.expand_dims(input, axis, name)