From 42a734170dae2942fcf553ccf5480fd48840795a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 15 Jun 2020 23:00:10 -0700
Subject: [PATCH] tf.numpy: Change a bunch of ops to handle unknown shapes.

Fix logic in sort_ops to handle 1D values.

PiperOrigin-RevId: 316620637
Change-Id: Iedc2ba8aad7673bbe210661bb741bf0660f047aa
---
 .../python/ops/numpy_ops/np_array_ops.py     | 36 ++++++--
 tensorflow/python/ops/numpy_ops/np_arrays.py | 51 ++++++++---
 .../python/ops/numpy_ops/np_math_ops.py      | 90 +++++++++++++------
 tensorflow/python/ops/numpy_ops/np_utils.py  | 32 ++++---
 tensorflow/python/ops/sort_ops.py            | 16 +---
 5 files changed, 153 insertions(+), 72 deletions(-)

diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py
index fbf67a46e31..e97bb61613b 100644
--- a/tensorflow/python/ops/numpy_ops/np_array_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py
@@ -223,9 +223,10 @@ def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
   Raises:
     ValueError: if `fill_value` can not be broadcast to shape `shape`.
   """
+  if not isinstance(shape, np_arrays.ndarray):
+    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
+  shape = atleast_1d(shape).data
   fill_value = asarray(fill_value, dtype=dtype)
-  if np_utils.isscalar(shape):
-    shape = array_ops.reshape(shape, [1])
   return np_arrays.tensor_to_ndarray(
       array_ops.broadcast_to(fill_value.data, shape))
@@ -808,16 +809,21 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: d
 @np_utils.np_doc(np.std)
 def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
   return _reduce(
-      math_ops.reduce_std, a, axis=axis, dtype=None, keepdims=keepdims,
+      math_ops.reduce_std,
+      a,
+      axis=axis,
+      dtype=None,
+      keepdims=keepdims,
       promote_int=_TO_FLOAT)
 
 
 @np_utils.np_doc(np.ravel)
 def ravel(a):  # pylint: disable=missing-docstring
   a = asarray(a)
-  if a.ndim == 1:
-    return a
-  return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
+  out = np_utils.cond(
+      math_ops.equal(a.ndim, 1), lambda: a.data,
+      lambda: array_ops.reshape(a.data, [-1]))
+  return np_utils.tensor_to_ndarray(out)
 
 
 setattr(np_arrays.ndarray, 'ravel', ravel)
@@ -846,7 +852,8 @@ def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
   a = asarray(a).data
   original_shape = a._shape_as_list()  # pylint: disable=protected-access
   # Best effort recovery of the shape.
-  if original_shape is not None and None not in original_shape:
+  known_shape = original_shape is not None and None not in original_shape
+  if known_shape:
     if not original_shape:
       original_shape = (repeats,)
     else:
@@ -865,7 +872,8 @@ def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
   repeats = asarray(repeats).data
   result = array_ops.repeat(a, repeats, axis)
-  result.set_shape(original_shape)
+  if known_shape:
+    result.set_shape(original_shape)
 
   return np_utils.tensor_to_ndarray(result)
@@ -1287,7 +1295,13 @@ def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
 
 
 @np_utils.np_doc(np.stack)
-def stack(arrays, axis=0):
+def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
+  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
+    arrays = asarray(arrays)
+    if axis == 0:
+      return arrays
+    else:
+      return swapaxes(arrays, 0, axis)
   arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
   unwrapped_arrays = [
       a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
@@ -1450,6 +1464,8 @@ def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-doc
 @np_utils.np_doc(np.tril)
 def tril(m, k=0):  # pylint: disable=missing-docstring
   m = asarray(m).data
+  if m.shape.ndims is None:
+    raise ValueError('Argument to tril should have known rank')
   m_shape = m.shape.as_list()
 
   if len(m_shape) < 2:
@@ -1470,6 +1486,8 @@ def tril(m, k=0):  # pylint: disable=missing-docstring
 @np_utils.np_doc(np.triu)
 def triu(m, k=0):  # pylint: disable=missing-docstring
   m = asarray(m).data
+  if m.shape.ndims is None:
+    raise ValueError('Argument to triu should have known rank')
   m_shape = m.shape.as_list()
 
   if len(m_shape) < 2:
diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py
index a7696ad31c2..e2f73100909 100644
--- a/tensorflow/python/ops/numpy_ops/np_arrays.py
+++ b/tensorflow/python/ops/numpy_ops/np_arrays.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 """ndarray class."""
+
+# pylint: disable=g-direct-tensorflow-import
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -151,13 +154,16 @@ def _slice_helper(tensor, slice_spec, var=None):
 
 
 def convert_to_tensor(value, dtype=None, dtype_hint=None):
   """Wrapper over `tf.convert_to_tensor`.
 
-  Args:
-    value: value to convert
-    dtype: (optional) the type we would like it to be converted to.
-    dtype_hint: (optional) soft preference for the type we would like it to
-      be converted to. `tf.convert_to_tensor` will attempt to convert value
-      to this type first, but will not fail if conversion is not possible
-      falling back to inferring the type instead.
+  Args:
+    value: value to convert
+    dtype: (optional) the type we would like it to be converted to.
+    dtype_hint: (optional) soft preference for the type we would like it to be
+      converted to. `tf.convert_to_tensor` will attempt to convert value to this
+      type first, but will not fail if conversion is not possible, falling back
+      to inferring the type instead.
+
+  Returns:
+    Value converted to tf.Tensor.
   """
   # A safer version of `tf.convert_to_tensor` to work around b/149876037.
   # TODO(wangpeng): Remove this function once the bug is fixed.
@@ -250,8 +256,12 @@ class ndarray(object):  # pylint: disable=invalid-name
 
   @property
   def shape(self):
-    """Returns a tuple of array dimensions."""
-    return self.data._shape_tuple()  # pylint: disable=protected-access
+    """Returns a tuple or tf.Tensor of array dimensions."""
+    shape = self.data.shape
+    if shape.is_fully_defined():
+      return tuple(shape.as_list())
+    else:
+      return array_ops.shape(self.data)
 
   @property
   def dtype(self):
@@ -259,19 +269,30 @@ class ndarray(object):  # pylint: disable=invalid-name
 
   @property
   def ndim(self):
-    return self.data.shape.ndims
+    ndims = self.data.shape.ndims
+    if ndims is None:
+      return array_ops.rank(self.data)
+    else:
+      return ndims
 
   @property
   def size(self):
     """Returns the number of elements in the array."""
-    return np.prod(self.shape)
+    shape = self.shape
+    if isinstance(shape, ops.Tensor):
+      return array_ops.size(self.data)
+    else:
+      return np.prod(self.shape)
 
   @property
   def T(self):  # pylint: disable=invalid-name
     return self.transpose()
 
   def __len__(self):
-    if self.shape:
+    shape = self.shape
+    if isinstance(shape, ops.Tensor):
+      raise TypeError('len() of symbolic tensor undefined')
+    elif shape:
       return self.shape[0]
     else:
       raise TypeError('len() of unsized object.')
@@ -320,6 +341,8 @@ class ndarray(object):  # pylint: disable=invalid-name
     return tensor_to_ndarray(result_t)
 
   def __iter__(self):
+    if not isinstance(self.data, ops.EagerTensor):
+      raise TypeError('Iteration over symbolic tensor is not allowed')
     for i in range(self.shape[0]):
       result_t = self.data[i]
       yield tensor_to_ndarray(result_t)
@@ -356,6 +379,8 @@ class ndarray(object):  # pylint: disable=invalid-name
       ValueError: If the array does not have size 1.
     """
     # TODO(wangpeng): Handle graph mode
+    if not isinstance(self.data, ops.EagerTensor):
+      raise TypeError('Indexing using symbolic tensor is not allowed')
     return np.asscalar(self.data.numpy())
 
   def tolist(self):
@@ -384,5 +409,3 @@ def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
 
 
 ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
-
-
diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py
index 02d37b3a3a4..b32f78bee5a 100644
--- a/tensorflow/python/ops/numpy_ops/np_math_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py
@@ -50,9 +50,9 @@ def dot(a, b):  # pylint: disable=missing-docstring
             math_ops.equal(array_ops.rank(b), 0)),
         lambda: a * b,
         lambda: np_utils.cond(  # pylint: disable=g-long-lambda
-            math_ops.equal(array_ops.rank(b), 1), lambda: math_ops.tensordot(
-                a, b, axes=[[-1], [-1]]), lambda: math_ops.tensordot(
-                    a, b, axes=[[-1], [-2]])))
+            math_ops.equal(array_ops.rank(b), 1),
+            lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]),
+            lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]])))
   return _bin_op(f, a, b)
@@ -204,8 +204,8 @@ def matmul(x1, x2):  # pylint: disable=missing-docstring
     return np_utils.cond(
         math_ops.equal(array_ops.rank(x2), 1),
         lambda: math_ops.tensordot(x1, x2, axes=1),
-        lambda: np_utils.cond(
-            math_ops.equal(array_ops.rank(x1), 1),  # pylint: disable=g-long-lambda
+        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
+            math_ops.equal(array_ops.rank(x1), 1),
             lambda: math_ops.tensordot(  # pylint: disable=g-long-lambda
                 x1, x2, axes=[[0], [-2]]),
             lambda: math_ops.matmul(x1, x2)))
@@ -352,14 +352,30 @@ def hypot(x1, x2):
 def kron(a, b):  # pylint: disable=missing-function-docstring
   # pylint: disable=protected-access,g-complex-comprehension
   a, b = np_array_ops._promote_dtype(a, b)
-  ndim = max(a.ndim, b.ndim)
-  if a.ndim < ndim:
-    a = np_array_ops.reshape(a, np_array_ops._pad_left_to(ndim, a.shape))
-  if b.ndim < ndim:
-    b = np_array_ops.reshape(b, np_array_ops._pad_left_to(ndim, b.shape))
-  a_reshaped = np_array_ops.reshape(a, [i for d in a.shape for i in (d, 1)])
-  b_reshaped = np_array_ops.reshape(b, [i for d in b.shape for i in (1, d)])
-  out_shape = tuple(np.multiply(a.shape, b.shape))
+  t_a = np_utils.cond(
+      a.ndim < b.ndim,
+      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
+          a.data, np_array_ops._pad_left_to(b.ndim, a.shape)),
+      lambda: a.data)
+  t_b = np_utils.cond(
+      b.ndim < a.ndim,
+      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
+          b.data, np_array_ops._pad_left_to(a.ndim, b.shape)),
+      lambda: b.data)
+
+  def _make_shape(shape, prepend):
+    ones = array_ops.ones_like(shape)
+    if prepend:
+      shapes = [ones, shape]
+    else:
+      shapes = [shape, ones]
+    return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])
+
+  a_shape = array_ops.shape(t_a)
+  b_shape = array_ops.shape(t_b)
+  a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
+  b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
+  out_shape = a_shape * b_shape
   return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
@@ -454,7 +470,8 @@ def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring
   if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
       not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
     raise ValueError('Arguments to gcd must be integers.')
-  shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
+  shape = array_ops.broadcast_dynamic_shape(
+      array_ops.shape(x1), array_ops.shape(x2))
   x1 = array_ops.broadcast_to(x1, shape)
   x2 = array_ops.broadcast_to(x2, shape)
   value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
@@ -607,7 +624,7 @@ def signbit(x):
 
   def f(x):
     if x.dtype == dtypes.bool:
-      return array_ops.fill(x.shape, False)
+      return array_ops.fill(array_ops.shape(x), False)
     return x < 0
 
   return _scalar(f, x)
@@ -866,7 +883,11 @@ def square(x):
 def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring
 
   def f(a):
+    # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
+    # TODO(agarwal): avoid depending on static rank.
     nd = a.shape.rank
+    if nd is None:
+      raise ValueError('diff currently requires known rank for input `a`')
     if (axis + nd if axis < 0 else axis) >= nd:
       raise ValueError('axis %s is out of bounds for array of dimension %s' %
                        (axis, nd))
@@ -887,8 +908,10 @@ def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring
 
 
 def _flip_args(f):
+
   def _f(a, b):
     return f(b, a)
+
   return _f
@@ -910,6 +933,7 @@ setattr(np_arrays.ndarray, '__rtruediv__', _flip_args(true_divide))
 
 
 def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
+  """Helper function for comparison."""
   dtype = np_utils.result_type(x1, x2)
   # Cast x1 and x2 to the result_type if needed.
   x1 = np_array_ops.array(x1, dtype=dtype)
@@ -953,12 +977,18 @@ def less_equal(x1, x2):
 
 
 @np_utils.np_doc(np.array_equal)
-def array_equal(a1, a2):
+def array_equal(a1, a2):  # pylint: disable=missing-function-docstring
 
-  def f(a1, a2):
-    if a1.shape != a2.shape:
-      return constant_op.constant(False)
-    return math_ops.reduce_all(math_ops.equal(a1, a2))
+  def f(x1, x2):
+    return np_utils.cond(
+        math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)),
+        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
+            np_utils.reduce_all(
+                math_ops.equal(array_ops.shape(x1), array_ops.shape(x2))
+            ),
+            lambda: math_ops.reduce_all(math_ops.equal(x1, x2)),
+            lambda: constant_op.constant(False)),
+        lambda: constant_op.constant(False))
 
   return _comparison(f, a1, a2)
@@ -1001,7 +1031,13 @@ setattr(np_arrays.ndarray, '__ne__', not_equal)
 
 
 @np_utils.np_doc(np.linspace)
 def linspace(  # pylint: disable=missing-docstring
-    start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0):
+    start,
+    stop,
+    num=50,
+    endpoint=True,
+    retstep=False,
+    dtype=float,
+    axis=0):
   if dtype:
     dtype = np_utils.result_type(dtype)
   start = np_array_ops.array(start, dtype=dtype).data
@@ -1054,10 +1090,14 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint
   start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
   stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
   signflip = 1 - start_sign * stop_sign // 2
-  res = signflip * logspace(log10(signflip * start),
-                            log10(signflip * stop), num,
-                            endpoint=endpoint, base=10.0,
-                            dtype=computation_dtype, axis=0)
+  res = signflip * logspace(
+      log10(signflip * start),
+      log10(signflip * stop),
+      num,
+      endpoint=endpoint,
+      base=10.0,
+      dtype=computation_dtype,
+      axis=0)
   if axis != 0:
     res = np_array_ops.moveaxis(res, 0, axis)
   return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))
diff --git a/tensorflow/python/ops/numpy_ops/np_utils.py b/tensorflow/python/ops/numpy_ops/np_utils.py
index 47b45b171fb..186e56816fe 100644
--- a/tensorflow/python/ops/numpy_ops/np_utils.py
+++ b/tensorflow/python/ops/numpy_ops/np_utils.py
@@ -47,7 +47,7 @@ def _canonicalize_axes(axes, rank):
     canonicalizer = (
         lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
   else:
-    canonicalizer = lambda axis: axis+rank if axis < 0 else axis
+    canonicalizer = lambda axis: axis + rank if axis < 0 else axis
 
   return [canonicalizer(axis) for axis in axes]
@@ -100,9 +100,16 @@ def finfo(dtype):
 
 
 def isscalar(val):
   """Returns whether `val` is a scalar value or scalar Tensor."""
-  if isinstance(val, (np.ndarray, np_arrays.ndarray, ops.Tensor)):
-    return len(val.shape) == 0  # pylint: disable=g-explicit-length-test
-  return np.isscalar(val)
+  if isinstance(val, np_arrays.ndarray):
+    val = val.data
+  if isinstance(val, ops.Tensor):
+    ndims = val.shape.ndims
+    if ndims is not None:
+      return ndims == 0
+    else:
+      return math_ops.equal(array_ops.rank(val), 0)
+  else:
+    return np.isscalar(val)
 
 
 # Can't use np_doc because np.result_type is a builtin function.
@@ -119,8 +126,8 @@ def result_type(*arrays_and_dtypes):
 
   def maybe_get_dtype(x):
     # Don't put np.ndarray in this list, because np.result_type looks at the
     # value (not just dtype) of np.ndarray to decide the result type.
-    if isinstance(x, (np_arrays.ndarray, ops.Tensor,
-                      indexed_slices.IndexedSlices)):
+    if isinstance(
+        x, (np_arrays.ndarray, ops.Tensor, indexed_slices.IndexedSlices)):
       return _to_numpy_type(x.dtype)
     elif isinstance(x, dtypes.DType):
       return _to_numpy_type(x)
@@ -277,8 +284,11 @@ def np_doc(np_fun, np_fun_name=None):
     # for name in np_sig.parameters:
     #   if name not in sig.parameters:
     #     unsupported_params.append(name)
-    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name,
-                               unsupported_params=unsupported_params)
+    f.__doc__ = _np_doc_helper(
+        f,
+        np_fun,
+        np_fun_name=np_fun_name,
+        unsupported_params=unsupported_params)
     return f
 
   return decorator
@@ -287,9 +297,9 @@ def np_doc(np_fun, np_fun_name=None):
 def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
   """Helper to get docs."""
   if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f):
-  # TODO(wangpeng): It looks like code snippets in numpy doc don't work
-  # correctly with doctest. Fix that and remove the reformatting of the np_f
-  # comment, here and below.
+    # TODO(wangpeng): It looks like code snippets in numpy doc don't work
+    # correctly with doctest. Fix that and remove the reformatting of the np_f
+    # comment, here and below.
     return np_f.__doc__.replace('>>>', '>')
   assert np_f or np_fun_name
   if not np_fun_name:
diff --git a/tensorflow/python/ops/sort_ops.py b/tensorflow/python/ops/sort_ops.py
index 4e66a80bc01..d711516cb86 100644
--- a/tensorflow/python/ops/sort_ops.py
+++ b/tensorflow/python/ops/sort_ops.py
@@ -134,7 +134,7 @@ def _sort_or_argsort(values, axis, direction, return_argsort):
   # Axis must be an integer, not a Tensor.
   axis = framework_ops.convert_to_tensor(axis, name='axis')
   axis_static = tensor_util.constant_value(axis)
-  if axis.shape.ndims != 0 or axis_static is None:
+  if axis.shape.ndims not in (None, 0) or axis_static is None:
     raise ValueError('axis must be a constant scalar')
   axis_static = int(axis_static)  # Avoids NumPy casting error
 
@@ -184,18 +184,8 @@ def _descending_sort(values, axis, return_argsort=False):
         name='transposition')
   else:
     # Generate the transposition array from the tensors.
-    transposition = array_ops.concat(
-        [
-            # Axes up to axis are unchanged.
-            math_ops.range(axis),
-            # Swap axis and rank - 1.
-            [rank - 1],
-            # Axes in [axis + 1, rank - 1) are unchanged.
-            math_ops.range(axis + 1, rank - 1),
-            # Swap axis and rank - 1.
-            [axis]
-        ],
-        axis=0)
+    transposition = array_ops.tensor_scatter_update(
+        math_ops.range(rank), [[axis], [rank - 1]], [rank - 1, axis])
   top_k_input = array_ops.transpose(values, transposition)
   values, indices = nn_ops.top_k(top_k_input, k)
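--

Why the ops above must tolerate unknown shapes: inside a `tf.function` traced with an unspecified `TensorSpec`, `ndarray.shape` and `ndarray.ndim` are symbolic tensors rather than Python values, which is why `ravel` now branches with `np_utils.cond` instead of a Python `if`. A minimal sketch, assuming a TF source build where the internal `numpy_ops` modules are importable (the `flatten` wrapper below is hypothetical, not part of the patch):

import tensorflow as tf
from tensorflow.python.ops.numpy_ops import np_array_ops

# shape=None leaves both rank and shape unknown at trace time.
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def flatten(t):
  # Pre-patch, ravel() branched on the Python value of `a.ndim`, which does
  # not exist here; post-patch it uses np_utils.cond on a symbolic rank.
  return np_array_ops.ravel(t).data

print(flatten(tf.constant([[1., 2.], [3., 4.]])))  # -> [1. 2. 3. 4.]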
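The `_make_shape` helper in the new `kron` interleaves a shape vector with ones so that broadcasting of `a_reshaped * b_reshaped` produces the Kronecker block pattern without needing static shapes. A small worked sketch of just the interleaving (the names here are illustrative):

import tensorflow as tf

a_shape = tf.constant([2, 3])
ones = tf.ones_like(a_shape)
# stack(..., axis=1) pairs each dim with 1; reshape(-1) flattens the pairs:
# [2, 3] -> [[2, 1], [3, 1]] -> [2, 1, 3, 1]
print(tf.reshape(tf.stack([a_shape, ones], axis=1), [-1]).numpy())  # [2 1 3 1]

With `prepend=True` the ones come first ([1, 2, 1, 3]), so for two 2x3 operands the element-wise product broadcasts to shape [2, 2, 3, 3] and is reshaped to out_shape = a_shape * b_shape = [4, 9], as expected for a Kronecker product.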
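The rewritten `_descending_sort` builds the permutation that swaps `axis` with the last axis by scattering two updates into the identity permutation, replacing the four-way concat. A quick check using the public `tf.tensor_scatter_nd_update`, which wraps the same op as `array_ops.tensor_scatter_update` above:

import tensorflow as tf

rank, axis = 4, 1
perm = tf.tensor_scatter_nd_update(
    tf.range(rank),                # identity permutation [0, 1, 2, 3]
    indices=[[axis], [rank - 1]],  # positions to overwrite
    updates=[rank - 1, axis])      # swapped values
print(perm.numpy())  # [0 3 2 1]

Since a single swap is its own inverse, the same permutation both moves `axis` to the end for `top_k` (which sorts along the last axis) and restores the original layout afterwards.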