tf.numpy: Add more APIs.

Also fix some lint errors.

PiperOrigin-RevId: 314862020
Change-Id: I247619fa738f8103e1d27b7e3076785726c9c52b
This commit is contained in:
A. Unique TensorFlower 2020-06-04 21:34:23 -07:00 committed by TensorFlower Gardener
parent ca8c3462f9
commit 370dd56ff2
4 changed files with 300 additions and 84 deletions

View File

@ -13,10 +13,13 @@
# limitations under the License.
# ==============================================================================
"""Common array methods."""
# pylint: disable=g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import math
import numpy as np
import six
@ -25,6 +28,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@ -660,8 +664,8 @@ def _reduce(tf_fn,
dtype: (optional) the dtype of the result.
keepdims: (optional) whether to keep the reduced dimension(s).
promote_int: how to promote integer and bool inputs. There are three
choices: (1) _TO_INT64: always promote them to int64 or uint64; (2)
_TO_FLOAT: always promote them to a float type (determined by
choices. (1) `_TO_INT64` always promotes them to int64 or uint64; (2)
`_TO_FLOAT` always promotes them to a float type (determined by
dtypes.default_float_type); (3) None: don't promote.
tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
@ -766,48 +770,60 @@ def amin(a, axis=None, keepdims=None):
preserve_bool=True)
# TODO(wangpeng): Remove this workaround once b/157232284 is fixed
def _reduce_variance_complex(input_tensor, axis, keepdims):
f = functools.partial(math_ops.reduce_variance, axis=axis, keepdims=keepdims)
return f(math_ops.real(input_tensor)) + f(math_ops.imag(input_tensor))
# TODO(wangpeng): Remove this workaround once b/157232284 is fixed
def _reduce_std_complex(input_tensor, axis, keepdims):
y = _reduce_variance_complex(
input_tensor=input_tensor, axis=axis, keepdims=keepdims)
return math_ops.sqrt(y)
@np_utils.np_doc(np.var)
def var(a, axis=None, keepdims=None):
def var(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
def f(input_tensor, axis, keepdims):
if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
# A workaround for b/157232284
fn = _reduce_variance_complex
else:
fn = math_ops.reduce_variance
return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims)
return _reduce(
math_ops.reduce_variance,
a,
axis=axis,
dtype=None,
keepdims=keepdims,
promote_int=_TO_FLOAT)
f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT)
@np_utils.np_doc(np.std)
def std(a, axis=None, keepdims=None):
def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
def f(input_tensor, axis, keepdims):
if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
# A workaround for b/157232284
fn = _reduce_std_complex
else:
fn = math_ops.reduce_std
return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims)
return _reduce(
math_ops.reduce_std,
a,
axis=axis,
dtype=None,
keepdims=keepdims,
promote_int=_TO_FLOAT)
f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT)
def ravel(a):
"""Flattens `a` into a 1-d array.
If `a` is already a 1-d ndarray it is returned as is.
Uses `tf.reshape`.
Args:
a: array_like. Could be an ndarray, a Tensor or any object that can be
converted to a Tensor using `tf.convert_to_tensor`.
Returns:
A 1-d ndarray.
"""
@np_utils.np_doc(np.ravel)
def ravel(a): # pylint: disable=missing-docstring
a = asarray(a)
if a.ndim == 1:
return a
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
setattr(np_arrays.ndarray, 'ravel', ravel)
def real(val):
"""Returns real parts of all elements in `a`.
@ -827,10 +843,32 @@ def real(val):
@np_utils.np_doc(np.repeat)
def repeat(a, repeats, axis=None):
def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
a = asarray(a).data
original_shape = a._shape_as_list() # pylint: disable=protected-access
# Best effort recovery of the shape.
if original_shape is not None and None not in original_shape:
if not original_shape:
original_shape = (repeats,)
else:
repeats_np = np.ravel(np.array(repeats))
if repeats_np.size == 1:
repeats_np = repeats_np.item()
if axis is None:
original_shape = (repeats_np * np.prod(original_shape),)
else:
original_shape[axis] = repeats_np * original_shape[axis]
else:
if axis is None:
original_shape = (repeats_np.sum(),)
else:
original_shape[axis] = repeats_np.sum()
repeats = asarray(repeats).data
return np_utils.tensor_to_ndarray(array_ops.repeat(a, repeats, axis))
result = array_ops.repeat(a, repeats, axis)
result.set_shape(original_shape)
return np_utils.tensor_to_ndarray(result)
@np_utils.np_doc(np.around)
@ -838,9 +876,12 @@ def around(a, decimals=0): # pylint: disable=missing-docstring
a = asarray(a)
dtype = a.dtype
factor = math.pow(10, decimals)
# Use float as the working dtype instead of a.dtype, because a.dtype can be
# integer and `decimals` can be negative.
float_dtype = np_dtypes.default_float_type()
if np.issubdtype(dtype, np.inexact):
factor = math_ops.cast(factor, dtype)
else:
# Use float as the working dtype when a.dtype is exact (e.g. integer),
# because `decimals` can be negative.
float_dtype = dtypes.default_float_type()
a = a.astype(float_dtype).data
factor = math_ops.cast(factor, float_dtype)
a = math_ops.multiply(a, factor)
@ -853,21 +894,36 @@ round_ = around
setattr(np_arrays.ndarray, '__round__', around)
def reshape(a, newshape):
"""Reshapes an array.
@np_utils.np_doc(np.reshape)
def reshape(a, newshape, order='C'):
"""order argument can only b 'C' or 'F'."""
if order not in {'C', 'F'}:
raise ValueError('Unsupported order argument {}'.format(order))
Args:
a: array_like. Could be an ndarray, a Tensor or any object that can be
converted to a Tensor using `tf.convert_to_tensor`.
newshape: 0-d or 1-d array_like.
Returns:
An ndarray with the contents and dtype of `a` and shape `newshape`.
"""
a = asarray(a)
if isinstance(newshape, np_arrays.ndarray):
newshape = newshape.data
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, newshape))
if isinstance(newshape, int):
newshape = [newshape]
if order == 'F':
r = array_ops.transpose(
array_ops.reshape(array_ops.transpose(a.data), newshape[::-1]))
else:
r = array_ops.reshape(a.data, newshape)
return np_utils.tensor_to_ndarray(r)
def _reshape_method_wrapper(a, *newshape, **kwargs):
order = kwargs.pop('order', 'C')
if kwargs:
raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))
if len(newshape) == 1 and not isinstance(newshape[0], int):
newshape = newshape[0]
return reshape(a, newshape, order=order)
def expand_dims(a, axis):
@ -1096,29 +1152,31 @@ def pad(ary, pad_width, mode, constant_values=0):
constant_values=constant_values))
def take(a, indices, axis=None):
"""Take elements from an array along an axis.
@np_utils.np_doc(np.take)
def take(a, indices, axis=None, out=None, mode='clip'):
"""out argument is not supported, and default mode is clip."""
if out is not None:
raise ValueError('out argument is not supported in take.')
See https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html for
description.
if mode not in {'raise', 'clip', 'wrap'}:
raise ValueError("Invalid mode '{}' for take".format(mode))
Args:
a: array_like. The source array.
indices: array_like. The indices of the values to extract.
axis: int, optional. The axis over which to select values. By default, the
flattened input array is used.
a = asarray(a).data
indices = asarray(indices).data
Returns:
A ndarray. The returned array has the same type as `a`.
"""
a = asarray(a)
indices = asarray(indices)
a = a.data
if axis is None:
a = array_ops.reshape(a, [-1])
axis = 0
return np_utils.tensor_to_ndarray(
array_ops.gather(a, indices.data, axis=axis))
axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
if mode == 'clip':
indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
elif mode == 'wrap':
indices = math_ops.floormod(indices, axis_size)
else:
raise ValueError("The 'raise' mode to take is not supported.")
return np_utils.tensor_to_ndarray(array_ops.gather(a, indices, axis=axis))
@np_utils.np_doc_only(np.where)
@ -1134,6 +1192,23 @@ def where(condition, x=None, y=None):
raise ValueError('Both x and y must be ndarrays, or both must be None.')
@np_utils.np_doc(np.select)
def select(condlist, choicelist, default=0): # pylint: disable=missing-docstring
if len(condlist) != len(choicelist):
msg = 'condlist must have length equal to choicelist ({} vs {})'
raise ValueError(msg.format(len(condlist), len(choicelist)))
if not condlist:
raise ValueError('condlist must be non-empty')
choices = _promote_dtype(default, *choicelist)
choicelist = choices[1:]
output = choices[0]
# The traversal is in reverse order so we can return the first value in
# choicelist where condlist is True.
for cond, choice in zip(condlist[::-1], choicelist[::-1]):
output = where(cond, choice, output)
return output
def shape(a):
"""Return the shape of an array.
@ -1413,3 +1488,130 @@ def triu(m, k=0): # pylint: disable=missing-docstring
return np_utils.tensor_to_ndarray(
array_ops.where_v2(
array_ops.broadcast_to(mask, array_ops.shape(m)), z, m))
@np_utils.np_doc(np.flip)
def flip(m, axis=None): # pylint: disable=missing-docstring
m = asarray(m).data
if axis is None:
return np_utils.tensor_to_ndarray(
array_ops.reverse(m, math_ops.range(array_ops.rank(m))))
axis = np_utils._canonicalize_axis(axis, array_ops.rank(m)) # pylint: disable=protected-access
return np_utils.tensor_to_ndarray(array_ops.reverse(m, [axis]))
@np_utils.np_doc(np.flipud)
def flipud(m): # pylint: disable=missing-docstring
return flip(m, 0)
@np_utils.np_doc(np.fliplr)
def fliplr(m): # pylint: disable=missing-docstring
return flip(m, 1)
@np_utils.np_doc(np.roll)
def roll(a, shift, axis=None): # pylint: disable=missing-docstring
a = asarray(a).data
if axis is not None:
return np_utils.tensor_to_ndarray(array_ops.roll(a, shift, axis))
# If axis is None, the roll happens as a 1-d tensor.
original_shape = array_ops.shape(a)
a = array_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape))
@np_utils.np_doc(np.rot90)
def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring
m_rank = array_ops.rank(m)
ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank) # pylint: disable=protected-access
k = k % 4
if k == 0:
return m
elif k == 2:
return flip(flip(m, ax1), ax2)
else:
perm = math_ops.range(m_rank)
perm = array_ops.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1])
if k == 1:
return transpose(flip(m, ax2), perm)
else:
return flip(transpose(m, perm), ax2)
@np_utils.np_doc(np.vander)
def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name
x = asarray(x).data
x_shape = array_ops.shape(x)
N = N or x_shape[0]
N_temp = np_utils.get_static_value(N) # pylint: disable=invalid-name
if N_temp is not None:
N = N_temp
if N < 0:
raise ValueError('N must be nonnegative')
else:
control_flow_ops.Assert(N >= 0, [N])
rank = array_ops.rank(x)
rank_temp = np_utils.get_static_value(rank)
if rank_temp is not None:
rank = rank_temp
if rank != 1:
raise ValueError('x must be a one-dimensional array')
else:
control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])
if increasing:
start = 0
limit = N
delta = 1
else:
start = N - 1
limit = -1
delta = -1
x = array_ops.expand_dims(x, -1)
return np_utils.tensor_to_ndarray(
math_ops.pow(
x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype)))
@np_utils.np_doc(np.ix_)
def ix_(*args): # pylint: disable=missing-docstring
n = len(args)
output = []
for i, a in enumerate(args):
a = asarray(a).data
a_rank = array_ops.rank(a)
a_rank_temp = np_utils.get_static_value(a_rank)
if a_rank_temp is not None:
a_rank = a_rank_temp
if a_rank != 1:
raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
i, a_rank))
else:
control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])
new_shape = [1] * n
new_shape[i] = -1
dtype = a.dtype
if dtype == dtypes.bool:
output.append(
np_utils.tensor_to_ndarray(
array_ops.reshape(nonzero(a)[0].data, new_shape)))
elif dtype.is_integer:
output.append(np_utils.tensor_to_ndarray(array_ops.reshape(a, new_shape)))
else:
raise ValueError(
'Only integer and bool dtypes are supported, got {}'.format(dtype))
return output

View File

@ -1104,6 +1104,16 @@ class ArrayManipulationTest(test.TestCase):
run_test([[1, 2]], (3, 2))
run_test([[[1, 2]], [[3, 4]], [[5, 6]]], (3, 4, 2))
def testIx_(self):
possible_arys = [[True, True], [True, False], [False, False],
list(range(5)), np_array_ops.empty(0, dtype=np.int64)]
for r in range(len(possible_arys)):
for arys in itertools.combinations_with_replacement(possible_arys, r):
tnp_ans = np_array_ops.ix_(*arys)
onp_ans = np.ix_(*arys)
for t, o in zip(tnp_ans, onp_ans):
self.match(t, o)
def match_shape(self, actual, expected, msg=None):
if msg:
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(

View File

@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Mathematical operations."""
# pylint: disable=g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -95,7 +97,7 @@ def multiply(x1, x2):
@np_utils.np_doc(np.true_divide)
def true_divide(x1, x2):
def true_divide(x1, x2): # pylint: disable=missing-function-docstring
def _avoid_float64(x1, x2):
if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64):
@ -122,7 +124,7 @@ divide = true_divide
@np_utils.np_doc(np.floor_divide)
def floor_divide(x1, x2):
def floor_divide(x1, x2): # pylint: disable=missing-function-docstring
def f(x1, x2):
if x1.dtype == dtypes.bool:
@ -135,7 +137,7 @@ def floor_divide(x1, x2):
@np_utils.np_doc(np.mod)
def mod(x1, x2):
def mod(x1, x2): # pylint: disable=missing-function-docstring
def f(x1, x2):
if x1.dtype == dtypes.bool:
@ -219,7 +221,7 @@ def tensordot(a, b, axes=2):
@np_utils.np_doc_only(np.inner)
def inner(a, b):
def inner(a, b): # pylint: disable=missing-function-docstring
def f(a, b):
return np_utils.cond(
@ -328,7 +330,7 @@ def nextafter(x1, x2):
@np_utils.np_doc(np.heaviside)
def heaviside(x1, x2):
def heaviside(x1, x2): # pylint: disable=missing-function-docstring
def f(x1, x2):
return array_ops.where_v2(
@ -347,7 +349,7 @@ def hypot(x1, x2):
@np_utils.np_doc(np.kron)
def kron(a, b):
def kron(a, b): # pylint: disable=missing-function-docstring
# pylint: disable=protected-access,g-complex-comprehension
a, b = np_array_ops._promote_dtype(a, b)
ndim = max(a.ndim, b.ndim)
@ -392,7 +394,7 @@ def logaddexp2(x1, x2):
@np_utils.np_doc(np.polyval)
def polyval(p, x):
def polyval(p, x): # pylint: disable=missing-function-docstring
def f(p, x):
if p.shape.rank == 0:
@ -433,9 +435,9 @@ def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
def _tf_gcd(x1, x2):
def _tf_gcd(x1, x2): # pylint: disable=missing-function-docstring
def _gcd_cond_fn(x1, x2):
def _gcd_cond_fn(_, x2):
return math_ops.reduce_any(x2 != 0)
def _gcd_body_fn(x1, x2):
@ -455,20 +457,20 @@ def _tf_gcd(x1, x2):
shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
x1 = array_ops.broadcast_to(x1, shape)
x2 = array_ops.broadcast_to(x2, shape)
gcd, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
(math_ops.abs(x1), math_ops.abs(x2)))
return gcd
return value
# Note that np.gcd may not be present in some supported versions of numpy.
@np_utils.np_doc(None, np_fun_name="gcd")
@np_utils.np_doc(None, np_fun_name='gcd')
def gcd(x1, x2):
return _bin_op(_tf_gcd, x1, x2)
# Note that np.lcm may not be present in some supported versions of numpy.
@np_utils.np_doc(None, np_fun_name="lcm")
def lcm(x1, x2):
@np_utils.np_doc(None, np_fun_name='lcm')
def lcm(x1, x2): # pylint: disable=missing-function-docstring
def f(x1, x2):
d = _tf_gcd(x1, x2)
@ -482,7 +484,7 @@ def lcm(x1, x2):
return _bin_op(f, x1, x2)
def _bitwise_binary_op(tf_fn, x1, x2):
def _bitwise_binary_op(tf_fn, x1, x2): # pylint: disable=missing-function-docstring
def f(x1, x2):
is_bool = (x1.dtype == dtypes.bool)
@ -691,7 +693,7 @@ _tf_float_types = [
@np_utils.np_doc(np.angle)
def angle(z, deg=False):
def angle(z, deg=False): # pylint: disable=missing-function-docstring
def f(x):
if x.dtype in _tf_float_types:
@ -861,7 +863,7 @@ def square(x):
@np_utils.np_doc(np.diff)
def diff(a, n=1, axis=-1):
def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
def f(a):
nd = a.shape.rank
@ -1059,7 +1061,7 @@ def concatenate(arys, axis=0):
@np_utils.np_doc_only(np.tile)
def tile(a, reps):
def tile(a, reps): # pylint: disable=missing-function-docstring
a = np_array_ops.array(a).data
reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]).data

View File

@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Utility functions for internal use."""
# pylint: disable=g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function