From 370dd56ff202a6b69d77fdbe4bf7f72ae9ddb3e2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 4 Jun 2020 21:34:23 -0700
Subject: [PATCH] tf.numpy: Add more APIs.

Also fix some lint errors.

PiperOrigin-RevId: 314862020
Change-Id: I247619fa738f8103e1d27b7e3076785726c9c52b
---
 .../python/ops/numpy_ops/np_array_ops.py      | 332 ++++++++++++++----
 .../python/ops/numpy_ops/np_array_ops_test.py |  10 +
 .../python/ops/numpy_ops/np_math_ops.py       |  40 ++-
 tensorflow/python/ops/numpy_ops/np_utils.py   |   2 +
 4 files changed, 300 insertions(+), 84 deletions(-)

diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py
index ee2141ab8de..aba7ce4f2a8 100644
--- a/tensorflow/python/ops/numpy_ops/np_array_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py
@@ -13,10 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 """Common array methods."""
+# pylint: disable=g-direct-tensorflow-import
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import math
 import numpy as np
 import six
@@ -25,6 +28,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
@@ -660,8 +664,8 @@ def _reduce(tf_fn,
     dtype: (optional) the dtype of the result.
     keepdims: (optional) whether to keep the reduced dimension(s).
     promote_int: how to promote integer and bool inputs. There are three
-      choices: (1) _TO_INT64: always promote them to int64 or uint64; (2)
-      _TO_FLOAT: always promote them to a float type (determined by
+      choices. (1) `_TO_INT64` always promotes them to int64 or uint64; (2)
+      `_TO_FLOAT` always promotes them to a float type (determined by
       dtypes.default_float_type); (3) None: don't promote.
     tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
       only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
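
For context on the promote_int choices documented above: `_TO_FLOAT` mirrors what numpy itself does for mean-like reductions, where integer and bool inputs come back as floats. A minimal numpy-only sketch of that behavior (illustrative; dtypes as printed on a typical 64-bit platform):

import numpy as np

a = np.array([1, 2, 3], dtype=np.int32)
print(np.sum(a).dtype)  # platform integer accumulator: no float promotion
print(np.var(a).dtype)  # float64: variance promotes integer inputs to float
print(np.var(np.array([True, False])).dtype)  # float64: bools promote too
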
@@ -766,48 +770,60 @@ def amin(a, axis=None, keepdims=None):
       preserve_bool=True)
 
 
+# TODO(wangpeng): Remove this workaround once b/157232284 is fixed
+def _reduce_variance_complex(input_tensor, axis, keepdims):
+  f = functools.partial(math_ops.reduce_variance, axis=axis, keepdims=keepdims)
+  return f(math_ops.real(input_tensor)) + f(math_ops.imag(input_tensor))
+
+
+# TODO(wangpeng): Remove this workaround once b/157232284 is fixed
+def _reduce_std_complex(input_tensor, axis, keepdims):
+  y = _reduce_variance_complex(
+      input_tensor=input_tensor, axis=axis, keepdims=keepdims)
+  return math_ops.sqrt(y)
+
+
 @np_utils.np_doc(np.var)
-def var(a, axis=None, keepdims=None):
+def var(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
+
+  def f(input_tensor, axis, keepdims):
+    if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
+      # A workaround for b/157232284
+      fn = _reduce_variance_complex
+    else:
+      fn = math_ops.reduce_variance
+    return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims)
+
   return _reduce(
-      math_ops.reduce_variance,
-      a,
-      axis=axis,
-      dtype=None,
-      keepdims=keepdims,
-      promote_int=_TO_FLOAT)
+      f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT)
 
 
 @np_utils.np_doc(np.std)
-def std(a, axis=None, keepdims=None):
+def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
+
+  def f(input_tensor, axis, keepdims):
+    if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
+      # A workaround for b/157232284
+      fn = _reduce_std_complex
+    else:
+      fn = math_ops.reduce_std
+    return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims)
+
   return _reduce(
-      math_ops.reduce_std,
-      a,
-      axis=axis,
-      dtype=None,
-      keepdims=keepdims,
-      promote_int=_TO_FLOAT)
-
-
-def ravel(a):
-  """Flattens `a` into a 1-d array.
-
-  If `a` is already a 1-d ndarray it is returned as is.
-
-  Uses `tf.reshape`.
-
-  Args:
-    a: array_like. Could be an ndarray, a Tensor or any object that can be
-      converted to a Tensor using `tf.convert_to_tensor`.
-
-  Returns:
-    A 1-d ndarray.
-  """
+      f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT)
+
+
+@np_utils.np_doc(np.ravel)
+def ravel(a):  # pylint: disable=missing-docstring
   a = asarray(a)
   if a.ndim == 1:
     return a
   return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
 
 
+setattr(np_arrays.ndarray, 'ravel', ravel)
+
+
 def real(val):
   """Returns real parts of all elements in `a`.
@@ -827,10 +843,32 @@ def real(val):
 
 
 @np_utils.np_doc(np.repeat)
-def repeat(a, repeats, axis=None):
+def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
   a = asarray(a).data
+  original_shape = a._shape_as_list()  # pylint: disable=protected-access
+  # Best effort recovery of the shape.
+  if original_shape is not None and None not in original_shape:
+    if not original_shape:
+      original_shape = (repeats,)
+    else:
+      repeats_np = np.ravel(np.array(repeats))
+      if repeats_np.size == 1:
+        repeats_np = repeats_np.item()
+        if axis is None:
+          original_shape = (repeats_np * np.prod(original_shape),)
+        else:
+          original_shape[axis] = repeats_np * original_shape[axis]
+      else:
+        if axis is None:
+          original_shape = (repeats_np.sum(),)
+        else:
+          original_shape[axis] = repeats_np.sum()
+
   repeats = asarray(repeats).data
-  return np_utils.tensor_to_ndarray(array_ops.repeat(a, repeats, axis))
+  result = array_ops.repeat(a, repeats, axis)
+  result.set_shape(original_shape)
+
+  return np_utils.tensor_to_ndarray(result)
 
 
 @np_utils.np_doc(np.around)
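
The complex-dtype workaround above rests on the identity Var(z) = Var(Re z) + Var(Im z), which follows from |z - E[z]|^2 = (Re z - E[Re z])^2 + (Im z - E[Im z])^2. A quick numpy check of the quantity `_reduce_variance_complex` computes (illustrative sketch):

import numpy as np

z = np.array([1 + 2j, 3 - 1j, -2 + 0.5j])
lhs = np.var(z)  # numpy averages |z - z.mean()|**2, returning a real value
rhs = np.var(z.real) + np.var(z.imag)
assert np.allclose(lhs, rhs)
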
@@ -838,11 +876,14 @@ def around(a, decimals=0):  # pylint: disable=missing-docstring
   a = asarray(a)
   dtype = a.dtype
   factor = math.pow(10, decimals)
-  # Use float as the working dtype instead of a.dtype, because a.dtype can be
-  # integer and `decimals` can be negative.
-  float_dtype = np_dtypes.default_float_type()
-  a = a.astype(float_dtype).data
-  factor = math_ops.cast(factor, float_dtype)
+  if np.issubdtype(dtype, np.inexact):
+    factor = math_ops.cast(factor, dtype)
+  else:
+    # Use float as the working dtype when a.dtype is exact (e.g. integer),
+    # because `decimals` can be negative.
+    float_dtype = np_dtypes.default_float_type()
+    a = a.astype(float_dtype).data
+    factor = math_ops.cast(factor, float_dtype)
   a = math_ops.multiply(a, factor)
   a = math_ops.round(a)
   a = math_ops.divide(a, factor)
@@ -853,21 +894,36 @@ round_ = around
 setattr(np_arrays.ndarray, '__round__', around)
 
 
-def reshape(a, newshape):
-  """Reshapes an array.
+@np_utils.np_doc(np.reshape)
+def reshape(a, newshape, order='C'):
+  """order argument can only be 'C' or 'F'."""
+  if order not in {'C', 'F'}:
+    raise ValueError('Unsupported order argument {}'.format(order))
 
-  Args:
-    a: array_like. Could be an ndarray, a Tensor or any object that can be
-      converted to a Tensor using `tf.convert_to_tensor`.
-    newshape: 0-d or 1-d array_like.
-
-  Returns:
-    An ndarray with the contents and dtype of `a` and shape `newshape`.
-  """
   a = asarray(a)
   if isinstance(newshape, np_arrays.ndarray):
     newshape = newshape.data
-  return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, newshape))
+  if isinstance(newshape, int):
+    newshape = [newshape]
+
+  if order == 'F':
+    r = array_ops.transpose(
+        array_ops.reshape(array_ops.transpose(a.data), newshape[::-1]))
+  else:
+    r = array_ops.reshape(a.data, newshape)
+
+  return np_utils.tensor_to_ndarray(r)
+
+
+def _reshape_method_wrapper(a, *newshape, **kwargs):
+  order = kwargs.pop('order', 'C')
+  if kwargs:
+    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))
+
+  if len(newshape) == 1 and not isinstance(newshape[0], int):
+    newshape = newshape[0]
+
+  return reshape(a, newshape, order=order)
 
 
 def expand_dims(a, axis):
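
The order='F' branch of the new `reshape` relies on the standard identity that a Fortran-order reshape equals transpose, C-order reshape to the reversed shape, transpose back. A numpy check of that equivalence (illustrative sketch):

import numpy as np

a = np.arange(12).reshape(3, 4)
expected = np.reshape(a, (2, 6), order='F')
# transpose -> C-order reshape to reversed shape -> transpose, as in the patch
via_transpose = np.transpose(np.reshape(np.transpose(a), (6, 2)))
assert np.array_equal(expected, via_transpose)
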
@@ -1096,29 +1152,31 @@ def pad(ary, pad_width, mode, constant_values=0):
           constant_values=constant_values))
 
 
-def take(a, indices, axis=None):
-  """Take elements from an array along an axis.
+@np_utils.np_doc(np.take)
+def take(a, indices, axis=None, out=None, mode='clip'):
+  """out argument is not supported, and default mode is clip."""
+  if out is not None:
+    raise ValueError('out argument is not supported in take.')
 
-  See https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html for
-  description.
+  if mode not in {'raise', 'clip', 'wrap'}:
+    raise ValueError("Invalid mode '{}' for take".format(mode))
 
-  Args:
-    a: array_like. The source array.
-    indices: array_like. The indices of the values to extract.
-    axis: int, optional. The axis over which to select values. By default, the
-      flattened input array is used.
+  a = asarray(a).data
+  indices = asarray(indices).data
 
-  Returns:
-    A ndarray. The returned array has the same type as `a`.
-  """
-  a = asarray(a)
-  indices = asarray(indices)
-  a = a.data
   if axis is None:
     a = array_ops.reshape(a, [-1])
     axis = 0
-  return np_utils.tensor_to_ndarray(
-      array_ops.gather(a, indices.data, axis=axis))
+
+  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
+  if mode == 'clip':
+    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
+  elif mode == 'wrap':
+    indices = math_ops.floormod(indices, axis_size)
+  else:
+    raise ValueError("The 'raise' mode to take is not supported.")
+
+  return np_utils.tensor_to_ndarray(array_ops.gather(a, indices, axis=axis))
 
 
 @np_utils.np_doc_only(np.where)
@@ -1134,6 +1192,23 @@ def where(condition, x=None, y=None):
   raise ValueError('Both x and y must be ndarrays, or both must be None.')
 
 
+@np_utils.np_doc(np.select)
+def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
+  if len(condlist) != len(choicelist):
+    msg = 'condlist must have length equal to choicelist ({} vs {})'
+    raise ValueError(msg.format(len(condlist), len(choicelist)))
+  if not condlist:
+    raise ValueError('condlist must be non-empty')
+  choices = _promote_dtype(default, *choicelist)
+  choicelist = choices[1:]
+  output = choices[0]
+  # The traversal is in reverse order so we can return the first value in
+  # choicelist where condlist is True.
+  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
+    output = where(cond, choice, output)
+  return output
+
+
 def shape(a):
   """Return the shape of an array.
 
@@ -1413,3 +1488,130 @@ def triu(m, k=0):  # pylint: disable=missing-docstring
   return np_utils.tensor_to_ndarray(
       array_ops.where_v2(
           array_ops.broadcast_to(mask, array_ops.shape(m)), z, m))
+
+
+@np_utils.np_doc(np.flip)
+def flip(m, axis=None):  # pylint: disable=missing-docstring
+  m = asarray(m).data
+
+  if axis is None:
+    return np_utils.tensor_to_ndarray(
+        array_ops.reverse(m, math_ops.range(array_ops.rank(m))))
+
+  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access
+
+  return np_utils.tensor_to_ndarray(array_ops.reverse(m, [axis]))
+
+
+@np_utils.np_doc(np.flipud)
+def flipud(m):  # pylint: disable=missing-docstring
+  return flip(m, 0)
+
+
+@np_utils.np_doc(np.fliplr)
+def fliplr(m):  # pylint: disable=missing-docstring
+  return flip(m, 1)
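
Note that numpy's `take` defaults to mode='raise', while the rewrite above defaults to 'clip' and implements 'wrap' with a floormod. The two supported modes behave as in numpy (illustrative sketch):

import numpy as np

a = np.array([10, 20, 30])
print(np.take(a, [1, 4, 5], mode='clip'))  # [20 30 30]: out-of-range indices clipped to n-1
print(np.take(a, [1, 4, 5], mode='wrap'))  # [20 20 30]: indices taken modulo n
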
+
+
+@np_utils.np_doc(np.roll)
+def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
+  a = asarray(a).data
+
+  if axis is not None:
+    return np_utils.tensor_to_ndarray(array_ops.roll(a, shift, axis))
+
+  # If axis is None, the roll happens as a 1-d tensor.
+  original_shape = array_ops.shape(a)
+  a = array_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
+  return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape))
+
+
+@np_utils.np_doc(np.rot90)
+def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
+  m_rank = array_ops.rank(m)
+  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access
+
+  k = k % 4
+  if k == 0:
+    return m
+  elif k == 2:
+    return flip(flip(m, ax1), ax2)
+  else:
+    perm = math_ops.range(m_rank)
+    perm = array_ops.tensor_scatter_nd_update(perm, [[ax1], [ax2]],
+                                              [ax2, ax1])
+
+    if k == 1:
+      return transpose(flip(m, ax2), perm)
+    else:
+      return flip(transpose(m, perm), ax2)
+
+
+@np_utils.np_doc(np.vander)
+def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
+  x = asarray(x).data
+
+  x_shape = array_ops.shape(x)
+  N = N or x_shape[0]
+
+  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
+  if N_temp is not None:
+    N = N_temp
+    if N < 0:
+      raise ValueError('N must be nonnegative')
+  else:
+    control_flow_ops.Assert(N >= 0, [N])
+
+  rank = array_ops.rank(x)
+  rank_temp = np_utils.get_static_value(rank)
+  if rank_temp is not None:
+    rank = rank_temp
+    if rank != 1:
+      raise ValueError('x must be a one-dimensional array')
+  else:
+    control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])
+
+  if increasing:
+    start = 0
+    limit = N
+    delta = 1
+  else:
+    start = N - 1
+    limit = -1
+    delta = -1
+
+  x = array_ops.expand_dims(x, -1)
+  return np_utils.tensor_to_ndarray(
+      math_ops.pow(
+          x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype)))
+
+
+@np_utils.np_doc(np.ix_)
+def ix_(*args):  # pylint: disable=missing-docstring
+  n = len(args)
+  output = []
+  for i, a in enumerate(args):
+    a = asarray(a).data
+    a_rank = array_ops.rank(a)
+    a_rank_temp = np_utils.get_static_value(a_rank)
+    if a_rank_temp is not None:
+      a_rank = a_rank_temp
+      if a_rank != 1:
+        raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
+            i, a_rank))
+    else:
+      control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])
+
+    new_shape = [1] * n
+    new_shape[i] = -1
+    dtype = a.dtype
+    if dtype == dtypes.bool:
+      output.append(
+          np_utils.tensor_to_ndarray(
+              array_ops.reshape(nonzero(a)[0].data, new_shape)))
+    elif dtype.is_integer:
+      output.append(
+          np_utils.tensor_to_ndarray(array_ops.reshape(a, new_shape)))
+    else:
+      raise ValueError(
+          'Only integer and bool dtypes are supported, got {}'.format(dtype))
+
+  return output
diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops_test.py b/tensorflow/python/ops/numpy_ops/np_array_ops_test.py
index 254cf97f52d..d69deda2d73 100644
--- a/tensorflow/python/ops/numpy_ops/np_array_ops_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_array_ops_test.py
@@ -1104,6 +1104,16 @@ class ArrayManipulationTest(test.TestCase):
     run_test([[1, 2]], (3, 2))
     run_test([[[1, 2]], [[3, 4]], [[5, 6]]], (3, 4, 2))
 
+  def testIx_(self):
+    possible_arys = [[True, True], [True, False], [False, False],
+                     list(range(5)), np_array_ops.empty(0, dtype=np.int64)]
+    for r in range(len(possible_arys)):
+      for arys in itertools.combinations_with_replacement(possible_arys, r):
+        tnp_ans = np_array_ops.ix_(*arys)
+        onp_ans = np.ix_(*arys)
+        for t, o in zip(tnp_ans, onp_ans):
+          self.match(t, o)
+
   def match_shape(self, actual, expected, msg=None):
     if msg:
       msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
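
`ix_`, which the new test above checks against numpy, builds an open (broadcastable) mesh from 1-d sequences: integer inputs index a cross-product, and bool inputs are first converted to the indices of their True entries, matching the `nonzero` branch in the implementation. For example (illustrative sketch):

import numpy as np

a = np.arange(16).reshape(4, 4)
rows, cols = np.ix_([0, 2], [1, 3])
print(rows.shape, cols.shape)  # (2, 1) (1, 2): shapes broadcast to a 2x2 mesh
print(a[rows, cols])           # [[ 1  3] [ 9 11]]: the row/column cross-product
print(np.ix_(np.array([True, False, True, False]))[0])  # [0 2]: bools -> indices
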
diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py
index 711c4d226bb..dac60b63bfe 100644
--- a/tensorflow/python/ops/numpy_ops/np_math_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 """Mathematical operations."""
+# pylint: disable=g-direct-tensorflow-import
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -95,7 +97,7 @@ def multiply(x1, x2):
 
 
 @np_utils.np_doc(np.true_divide)
-def true_divide(x1, x2):
+def true_divide(x1, x2):  # pylint: disable=missing-function-docstring
 
   def _avoid_float64(x1, x2):
     if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64):
@@ -122,7 +124,7 @@ divide = true_divide
 
 
 @np_utils.np_doc(np.floor_divide)
-def floor_divide(x1, x2):
+def floor_divide(x1, x2):  # pylint: disable=missing-function-docstring
 
   def f(x1, x2):
     if x1.dtype == dtypes.bool:
@@ -135,7 +137,7 @@ def floor_divide(x1, x2):
 
 
 @np_utils.np_doc(np.mod)
-def mod(x1, x2):
+def mod(x1, x2):  # pylint: disable=missing-function-docstring
 
   def f(x1, x2):
     if x1.dtype == dtypes.bool:
@@ -219,7 +221,7 @@ def tensordot(a, b, axes=2):
 
 
 @np_utils.np_doc_only(np.inner)
-def inner(a, b):
+def inner(a, b):  # pylint: disable=missing-function-docstring
 
   def f(a, b):
     return np_utils.cond(
@@ -328,7 +330,7 @@ def nextafter(x1, x2):
 
 
 @np_utils.np_doc(np.heaviside)
-def heaviside(x1, x2):
+def heaviside(x1, x2):  # pylint: disable=missing-function-docstring
 
   def f(x1, x2):
     return array_ops.where_v2(
@@ -347,7 +349,7 @@ def hypot(x1, x2):
 
 
 @np_utils.np_doc(np.kron)
-def kron(a, b):
+def kron(a, b):  # pylint: disable=missing-function-docstring
   # pylint: disable=protected-access,g-complex-comprehension
   a, b = np_array_ops._promote_dtype(a, b)
   ndim = max(a.ndim, b.ndim)
@@ -392,7 +394,7 @@ def logaddexp2(x1, x2):
 
 
 @np_utils.np_doc(np.polyval)
-def polyval(p, x):
+def polyval(p, x):  # pylint: disable=missing-function-docstring
 
   def f(p, x):
     if p.shape.rank == 0:
@@ -433,9 +435,9 @@ def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
       isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
 
 
-def _tf_gcd(x1, x2):
+def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring
 
-  def _gcd_cond_fn(x1, x2):
+  def _gcd_cond_fn(_, x2):
     return math_ops.reduce_any(x2 != 0)
 
   def _gcd_body_fn(x1, x2):
@@ -455,20 +457,20 @@ def _tf_gcd(x1, x2):
     shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
     x1 = array_ops.broadcast_to(x1, shape)
     x2 = array_ops.broadcast_to(x2, shape)
-  gcd, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
-                                       (math_ops.abs(x1), math_ops.abs(x2)))
-  return gcd
+  value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
+                                         (math_ops.abs(x1), math_ops.abs(x2)))
+  return value
 
 
 # Note that np.gcd may not be present in some supported versions of numpy.
-@np_utils.np_doc(None, np_fun_name="gcd")
+@np_utils.np_doc(None, np_fun_name='gcd')
 def gcd(x1, x2):
   return _bin_op(_tf_gcd, x1, x2)
 
 
 # Note that np.lcm may not be present in some supported versions of numpy.
-@np_utils.np_doc(None, np_fun_name="lcm")
-def lcm(x1, x2):
+@np_utils.np_doc(None, np_fun_name='lcm')
+def lcm(x1, x2):  # pylint: disable=missing-function-docstring
 
   def f(x1, x2):
     d = _tf_gcd(x1, x2)
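
`_tf_gcd` above runs Euclid's algorithm elementwise inside a `while_loop`: `_gcd_cond_fn` keeps iterating while any element of x2 is nonzero. The loop body lies outside this hunk, so as an assumption it performs the usual (x1, x2) -> (x2, x1 mod x2) step; a plain-Python scalar sketch of that iteration:

def gcd_loop(x1, x2):
  # Iterate until the second operand reaches zero; the survivor is the gcd.
  x1, x2 = abs(x1), abs(x2)
  while x2 != 0:
    x1, x2 = x2, x1 % x2
  return x1

assert gcd_loop(12, 18) == 6
assert gcd_loop(7, 0) == 7
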
@@ -482,7 +484,7 @@ def lcm(x1, x2):
   return _bin_op(f, x1, x2)
 
 
-def _bitwise_binary_op(tf_fn, x1, x2):
+def _bitwise_binary_op(tf_fn, x1, x2):  # pylint: disable=missing-function-docstring
 
   def f(x1, x2):
     is_bool = (x1.dtype == dtypes.bool)
@@ -691,7 +693,7 @@ _tf_float_types = [
 
 
 @np_utils.np_doc(np.angle)
-def angle(z, deg=False):
+def angle(z, deg=False):  # pylint: disable=missing-function-docstring
 
   def f(x):
     if x.dtype in _tf_float_types:
@@ -861,7 +863,7 @@ def square(x):
 
 
 @np_utils.np_doc(np.diff)
-def diff(a, n=1, axis=-1):
+def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring
 
   def f(a):
     nd = a.shape.rank
@@ -1059,7 +1061,7 @@ def concatenate(arys, axis=0):
 
 
 @np_utils.np_doc_only(np.tile)
-def tile(a, reps):
+def tile(a, reps):  # pylint: disable=missing-function-docstring
   a = np_array_ops.array(a).data
   reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]).data
diff --git a/tensorflow/python/ops/numpy_ops/np_utils.py b/tensorflow/python/ops/numpy_ops/np_utils.py
index f276e35e640..598a8147980 100644
--- a/tensorflow/python/ops/numpy_ops/np_utils.py
+++ b/tensorflow/python/ops/numpy_ops/np_utils.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 """Utility functions for internal use."""
+# pylint: disable=g-direct-tensorflow-import
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
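
As a cross-check on the `vander` addition above, the decreasing/increasing exponent ranges it builds with `math_ops.range` correspond directly to numpy's Vandermonde layout (illustrative sketch):

import numpy as np

x = np.array([1, 2, 3])
N = 3
# Default layout: column j holds x ** (N - 1 - j), i.e. decreasing powers.
assert np.array_equal(x[:, None] ** np.arange(N - 1, -1, -1), np.vander(x, N))
# increasing=True reverses the exponent order, matching start/limit/delta above.
assert np.array_equal(x[:, None] ** np.arange(N), np.vander(x, N, increasing=True))
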