Modify to return an error specifying an axis argument must be given instead of an error due to being unable to convert a null tensor.
PiperOrigin-RevId: 220129411
This commit is contained in:
parent
7761c55477
commit
f74a043642
@ -120,7 +120,7 @@ def expand_dims(input, axis=None, name=None, dim=None):
|
||||
axis: 0-D (scalar). Specifies the dimension index at which to
|
||||
expand the shape of `input`. Must be in the range
|
||||
`[-rank(input) - 1, rank(input)]`.
|
||||
name: The name of the output `Tensor`.
|
||||
name: The name of the output `Tensor` (optional).
|
||||
dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.
|
||||
|
||||
Returns:
|
||||
@ -128,9 +128,11 @@ def expand_dims(input, axis=None, name=None, dim=None):
|
||||
dimension of size 1 added.
|
||||
|
||||
Raises:
|
||||
ValueError: if both `dim` and `axis` are specified.
|
||||
ValueError: if either both or neither of `dim` and `axis` are specified.
|
||||
"""
|
||||
axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
|
||||
if axis is None:
|
||||
raise ValueError("Must specify an axis argument to tf.expand_dims()")
|
||||
return gen_array_ops.expand_dims(input, axis, name)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user