diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 232a77c8888..6fdc50733a1 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -120,7 +120,7 @@ def expand_dims(input, axis=None, name=None, dim=None): axis: 0-D (scalar). Specifies the dimension index at which to expand the shape of `input`. Must be in the range `[-rank(input) - 1, rank(input)]`. - name: The name of the output `Tensor`. + name: The name of the output `Tensor` (optional). dim: 0-D (scalar). Equivalent to `axis`, to be deprecated. Returns: @@ -128,9 +128,11 @@ def expand_dims(input, axis=None, name=None, dim=None): dimension of size 1 added. Raises: - ValueError: if both `dim` and `axis` are specified. + ValueError: if either both or neither of `dim` and `axis` are specified. """ axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim) + if axis is None: + raise ValueError("Must specify an axis argument to tf.expand_dims()") return gen_array_ops.expand_dims(input, axis, name)