tf.numpy: Add more APIs.
Also fix some lint errors. PiperOrigin-RevId: 314862020 Change-Id: I247619fa738f8103e1d27b7e3076785726c9c52b
This commit is contained in:
parent
ca8c3462f9
commit
370dd56ff2
|
@ -13,10 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Common array methods."""
|
"""Common array methods."""
|
||||||
|
# pylint: disable=g-direct-tensorflow-import
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
@ -25,6 +28,7 @@ from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import clip_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import linalg_ops
|
from tensorflow.python.ops import linalg_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
@ -660,8 +664,8 @@ def _reduce(tf_fn,
|
||||||
dtype: (optional) the dtype of the result.
|
dtype: (optional) the dtype of the result.
|
||||||
keepdims: (optional) whether to keep the reduced dimension(s).
|
keepdims: (optional) whether to keep the reduced dimension(s).
|
||||||
promote_int: how to promote integer and bool inputs. There are three
|
promote_int: how to promote integer and bool inputs. There are three
|
||||||
choices: (1) _TO_INT64: always promote them to int64 or uint64; (2)
|
choices. (1) `_TO_INT64` always promotes them to int64 or uint64; (2)
|
||||||
_TO_FLOAT: always promote them to a float type (determined by
|
`_TO_FLOAT` always promotes them to a float type (determined by
|
||||||
dtypes.default_float_type); (3) None: don't promote.
|
dtypes.default_float_type); (3) None: don't promote.
|
||||||
tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
|
tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
|
||||||
only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
|
only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
|
||||||
|
@ -766,48 +770,60 @@ def amin(a, axis=None, keepdims=None):
|
||||||
preserve_bool=True)
|
preserve_bool=True)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(wangpeng): Remove this workaround once b/157232284 is fixed
|
||||||
|
def _reduce_variance_complex(input_tensor, axis, keepdims):
|
||||||
|
f = functools.partial(math_ops.reduce_variance, axis=axis, keepdims=keepdims)
|
||||||
|
return f(math_ops.real(input_tensor)) + f(math_ops.imag(input_tensor))
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(wangpeng): Remove this workaround once b/157232284 is fixed
|
||||||
|
def _reduce_std_complex(input_tensor, axis, keepdims):
|
||||||
|
y = _reduce_variance_complex(
|
||||||
|
input_tensor=input_tensor, axis=axis, keepdims=keepdims)
|
||||||
|
return math_ops.sqrt(y)
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.var)
|
@np_utils.np_doc(np.var)
|
||||||
def var(a, axis=None, keepdims=None):
|
def var(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
|
def f(input_tensor, axis, keepdims):
|
||||||
|
if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
|
||||||
|
# A workaround for b/157232284
|
||||||
|
fn = _reduce_variance_complex
|
||||||
|
else:
|
||||||
|
fn = math_ops.reduce_variance
|
||||||
|
return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims)
|
||||||
|
|
||||||
return _reduce(
|
return _reduce(
|
||||||
math_ops.reduce_variance,
|
f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT)
|
||||||
a,
|
|
||||||
axis=axis,
|
|
||||||
dtype=None,
|
|
||||||
keepdims=keepdims,
|
|
||||||
promote_int=_TO_FLOAT)
|
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.std)
|
@np_utils.np_doc(np.std)
|
||||||
def std(a, axis=None, keepdims=None):
|
def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
|
def f(input_tensor, axis, keepdims):
|
||||||
|
if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
|
||||||
|
# A workaround for b/157232284
|
||||||
|
fn = _reduce_std_complex
|
||||||
|
else:
|
||||||
|
fn = math_ops.reduce_std
|
||||||
|
return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims)
|
||||||
|
|
||||||
return _reduce(
|
return _reduce(
|
||||||
math_ops.reduce_std,
|
f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT)
|
||||||
a,
|
|
||||||
axis=axis,
|
|
||||||
dtype=None,
|
|
||||||
keepdims=keepdims,
|
|
||||||
promote_int=_TO_FLOAT)
|
|
||||||
|
|
||||||
|
|
||||||
def ravel(a):
|
@np_utils.np_doc(np.ravel)
|
||||||
"""Flattens `a` into a 1-d array.
|
def ravel(a): # pylint: disable=missing-docstring
|
||||||
|
|
||||||
If `a` is already a 1-d ndarray it is returned as is.
|
|
||||||
|
|
||||||
Uses `tf.reshape`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
a: array_like. Could be an ndarray, a Tensor or any object that can be
|
|
||||||
converted to a Tensor using `tf.convert_to_tensor`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A 1-d ndarray.
|
|
||||||
"""
|
|
||||||
a = asarray(a)
|
a = asarray(a)
|
||||||
if a.ndim == 1:
|
if a.ndim == 1:
|
||||||
return a
|
return a
|
||||||
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
|
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
|
||||||
|
|
||||||
|
|
||||||
|
setattr(np_arrays.ndarray, 'ravel', ravel)
|
||||||
|
|
||||||
|
|
||||||
def real(val):
|
def real(val):
|
||||||
"""Returns real parts of all elements in `a`.
|
"""Returns real parts of all elements in `a`.
|
||||||
|
|
||||||
|
@ -827,10 +843,32 @@ def real(val):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.repeat)
|
@np_utils.np_doc(np.repeat)
|
||||||
def repeat(a, repeats, axis=None):
|
def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
|
||||||
a = asarray(a).data
|
a = asarray(a).data
|
||||||
|
original_shape = a._shape_as_list() # pylint: disable=protected-access
|
||||||
|
# Best effort recovery of the shape.
|
||||||
|
if original_shape is not None and None not in original_shape:
|
||||||
|
if not original_shape:
|
||||||
|
original_shape = (repeats,)
|
||||||
|
else:
|
||||||
|
repeats_np = np.ravel(np.array(repeats))
|
||||||
|
if repeats_np.size == 1:
|
||||||
|
repeats_np = repeats_np.item()
|
||||||
|
if axis is None:
|
||||||
|
original_shape = (repeats_np * np.prod(original_shape),)
|
||||||
|
else:
|
||||||
|
original_shape[axis] = repeats_np * original_shape[axis]
|
||||||
|
else:
|
||||||
|
if axis is None:
|
||||||
|
original_shape = (repeats_np.sum(),)
|
||||||
|
else:
|
||||||
|
original_shape[axis] = repeats_np.sum()
|
||||||
|
|
||||||
repeats = asarray(repeats).data
|
repeats = asarray(repeats).data
|
||||||
return np_utils.tensor_to_ndarray(array_ops.repeat(a, repeats, axis))
|
result = array_ops.repeat(a, repeats, axis)
|
||||||
|
result.set_shape(original_shape)
|
||||||
|
|
||||||
|
return np_utils.tensor_to_ndarray(result)
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.around)
|
@np_utils.np_doc(np.around)
|
||||||
|
@ -838,9 +876,12 @@ def around(a, decimals=0): # pylint: disable=missing-docstring
|
||||||
a = asarray(a)
|
a = asarray(a)
|
||||||
dtype = a.dtype
|
dtype = a.dtype
|
||||||
factor = math.pow(10, decimals)
|
factor = math.pow(10, decimals)
|
||||||
# Use float as the working dtype instead of a.dtype, because a.dtype can be
|
if np.issubdtype(dtype, np.inexact):
|
||||||
# integer and `decimals` can be negative.
|
factor = math_ops.cast(factor, dtype)
|
||||||
float_dtype = np_dtypes.default_float_type()
|
else:
|
||||||
|
# Use float as the working dtype when a.dtype is exact (e.g. integer),
|
||||||
|
# because `decimals` can be negative.
|
||||||
|
float_dtype = dtypes.default_float_type()
|
||||||
a = a.astype(float_dtype).data
|
a = a.astype(float_dtype).data
|
||||||
factor = math_ops.cast(factor, float_dtype)
|
factor = math_ops.cast(factor, float_dtype)
|
||||||
a = math_ops.multiply(a, factor)
|
a = math_ops.multiply(a, factor)
|
||||||
|
@ -853,21 +894,36 @@ round_ = around
|
||||||
setattr(np_arrays.ndarray, '__round__', around)
|
setattr(np_arrays.ndarray, '__round__', around)
|
||||||
|
|
||||||
|
|
||||||
def reshape(a, newshape):
|
@np_utils.np_doc(np.reshape)
|
||||||
"""Reshapes an array.
|
def reshape(a, newshape, order='C'):
|
||||||
|
"""order argument can only b 'C' or 'F'."""
|
||||||
|
if order not in {'C', 'F'}:
|
||||||
|
raise ValueError('Unsupported order argument {}'.format(order))
|
||||||
|
|
||||||
Args:
|
|
||||||
a: array_like. Could be an ndarray, a Tensor or any object that can be
|
|
||||||
converted to a Tensor using `tf.convert_to_tensor`.
|
|
||||||
newshape: 0-d or 1-d array_like.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An ndarray with the contents and dtype of `a` and shape `newshape`.
|
|
||||||
"""
|
|
||||||
a = asarray(a)
|
a = asarray(a)
|
||||||
if isinstance(newshape, np_arrays.ndarray):
|
if isinstance(newshape, np_arrays.ndarray):
|
||||||
newshape = newshape.data
|
newshape = newshape.data
|
||||||
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, newshape))
|
if isinstance(newshape, int):
|
||||||
|
newshape = [newshape]
|
||||||
|
|
||||||
|
if order == 'F':
|
||||||
|
r = array_ops.transpose(
|
||||||
|
array_ops.reshape(array_ops.transpose(a.data), newshape[::-1]))
|
||||||
|
else:
|
||||||
|
r = array_ops.reshape(a.data, newshape)
|
||||||
|
|
||||||
|
return np_utils.tensor_to_ndarray(r)
|
||||||
|
|
||||||
|
|
||||||
|
def _reshape_method_wrapper(a, *newshape, **kwargs):
|
||||||
|
order = kwargs.pop('order', 'C')
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))
|
||||||
|
|
||||||
|
if len(newshape) == 1 and not isinstance(newshape[0], int):
|
||||||
|
newshape = newshape[0]
|
||||||
|
|
||||||
|
return reshape(a, newshape, order=order)
|
||||||
|
|
||||||
|
|
||||||
def expand_dims(a, axis):
|
def expand_dims(a, axis):
|
||||||
|
@ -1096,29 +1152,31 @@ def pad(ary, pad_width, mode, constant_values=0):
|
||||||
constant_values=constant_values))
|
constant_values=constant_values))
|
||||||
|
|
||||||
|
|
||||||
def take(a, indices, axis=None):
|
@np_utils.np_doc(np.take)
|
||||||
"""Take elements from an array along an axis.
|
def take(a, indices, axis=None, out=None, mode='clip'):
|
||||||
|
"""out argument is not supported, and default mode is clip."""
|
||||||
|
if out is not None:
|
||||||
|
raise ValueError('out argument is not supported in take.')
|
||||||
|
|
||||||
See https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html for
|
if mode not in {'raise', 'clip', 'wrap'}:
|
||||||
description.
|
raise ValueError("Invalid mode '{}' for take".format(mode))
|
||||||
|
|
||||||
Args:
|
a = asarray(a).data
|
||||||
a: array_like. The source array.
|
indices = asarray(indices).data
|
||||||
indices: array_like. The indices of the values to extract.
|
|
||||||
axis: int, optional. The axis over which to select values. By default, the
|
|
||||||
flattened input array is used.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A ndarray. The returned array has the same type as `a`.
|
|
||||||
"""
|
|
||||||
a = asarray(a)
|
|
||||||
indices = asarray(indices)
|
|
||||||
a = a.data
|
|
||||||
if axis is None:
|
if axis is None:
|
||||||
a = array_ops.reshape(a, [-1])
|
a = array_ops.reshape(a, [-1])
|
||||||
axis = 0
|
axis = 0
|
||||||
return np_utils.tensor_to_ndarray(
|
|
||||||
array_ops.gather(a, indices.data, axis=axis))
|
axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
|
||||||
|
if mode == 'clip':
|
||||||
|
indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
|
||||||
|
elif mode == 'wrap':
|
||||||
|
indices = math_ops.floormod(indices, axis_size)
|
||||||
|
else:
|
||||||
|
raise ValueError("The 'raise' mode to take is not supported.")
|
||||||
|
|
||||||
|
return np_utils.tensor_to_ndarray(array_ops.gather(a, indices, axis=axis))
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc_only(np.where)
|
@np_utils.np_doc_only(np.where)
|
||||||
|
@ -1134,6 +1192,23 @@ def where(condition, x=None, y=None):
|
||||||
raise ValueError('Both x and y must be ndarrays, or both must be None.')
|
raise ValueError('Both x and y must be ndarrays, or both must be None.')
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc(np.select)
|
||||||
|
def select(condlist, choicelist, default=0): # pylint: disable=missing-docstring
|
||||||
|
if len(condlist) != len(choicelist):
|
||||||
|
msg = 'condlist must have length equal to choicelist ({} vs {})'
|
||||||
|
raise ValueError(msg.format(len(condlist), len(choicelist)))
|
||||||
|
if not condlist:
|
||||||
|
raise ValueError('condlist must be non-empty')
|
||||||
|
choices = _promote_dtype(default, *choicelist)
|
||||||
|
choicelist = choices[1:]
|
||||||
|
output = choices[0]
|
||||||
|
# The traversal is in reverse order so we can return the first value in
|
||||||
|
# choicelist where condlist is True.
|
||||||
|
for cond, choice in zip(condlist[::-1], choicelist[::-1]):
|
||||||
|
output = where(cond, choice, output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def shape(a):
|
def shape(a):
|
||||||
"""Return the shape of an array.
|
"""Return the shape of an array.
|
||||||
|
|
||||||
|
@ -1413,3 +1488,130 @@ def triu(m, k=0): # pylint: disable=missing-docstring
|
||||||
return np_utils.tensor_to_ndarray(
|
return np_utils.tensor_to_ndarray(
|
||||||
array_ops.where_v2(
|
array_ops.where_v2(
|
||||||
array_ops.broadcast_to(mask, array_ops.shape(m)), z, m))
|
array_ops.broadcast_to(mask, array_ops.shape(m)), z, m))
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc(np.flip)
|
||||||
|
def flip(m, axis=None): # pylint: disable=missing-docstring
|
||||||
|
m = asarray(m).data
|
||||||
|
|
||||||
|
if axis is None:
|
||||||
|
return np_utils.tensor_to_ndarray(
|
||||||
|
array_ops.reverse(m, math_ops.range(array_ops.rank(m))))
|
||||||
|
|
||||||
|
axis = np_utils._canonicalize_axis(axis, array_ops.rank(m)) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
return np_utils.tensor_to_ndarray(array_ops.reverse(m, [axis]))
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc(np.flipud)
|
||||||
|
def flipud(m): # pylint: disable=missing-docstring
|
||||||
|
return flip(m, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc(np.fliplr)
|
||||||
|
def fliplr(m): # pylint: disable=missing-docstring
|
||||||
|
return flip(m, 1)
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc(np.roll)
|
||||||
|
def roll(a, shift, axis=None): # pylint: disable=missing-docstring
|
||||||
|
a = asarray(a).data
|
||||||
|
|
||||||
|
if axis is not None:
|
||||||
|
return np_utils.tensor_to_ndarray(array_ops.roll(a, shift, axis))
|
||||||
|
|
||||||
|
# If axis is None, the roll happens as a 1-d tensor.
|
||||||
|
original_shape = array_ops.shape(a)
|
||||||
|
a = array_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
|
||||||
|
return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape))
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc(np.rot90)
|
||||||
|
def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring
|
||||||
|
m_rank = array_ops.rank(m)
|
||||||
|
ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
k = k % 4
|
||||||
|
if k == 0:
|
||||||
|
return m
|
||||||
|
elif k == 2:
|
||||||
|
return flip(flip(m, ax1), ax2)
|
||||||
|
else:
|
||||||
|
perm = math_ops.range(m_rank)
|
||||||
|
perm = array_ops.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1])
|
||||||
|
|
||||||
|
if k == 1:
|
||||||
|
return transpose(flip(m, ax2), perm)
|
||||||
|
else:
|
||||||
|
return flip(transpose(m, perm), ax2)
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc(np.vander)
|
||||||
|
def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name
|
||||||
|
x = asarray(x).data
|
||||||
|
|
||||||
|
x_shape = array_ops.shape(x)
|
||||||
|
N = N or x_shape[0]
|
||||||
|
|
||||||
|
N_temp = np_utils.get_static_value(N) # pylint: disable=invalid-name
|
||||||
|
if N_temp is not None:
|
||||||
|
N = N_temp
|
||||||
|
if N < 0:
|
||||||
|
raise ValueError('N must be nonnegative')
|
||||||
|
else:
|
||||||
|
control_flow_ops.Assert(N >= 0, [N])
|
||||||
|
|
||||||
|
rank = array_ops.rank(x)
|
||||||
|
rank_temp = np_utils.get_static_value(rank)
|
||||||
|
if rank_temp is not None:
|
||||||
|
rank = rank_temp
|
||||||
|
if rank != 1:
|
||||||
|
raise ValueError('x must be a one-dimensional array')
|
||||||
|
else:
|
||||||
|
control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])
|
||||||
|
|
||||||
|
if increasing:
|
||||||
|
start = 0
|
||||||
|
limit = N
|
||||||
|
delta = 1
|
||||||
|
else:
|
||||||
|
start = N - 1
|
||||||
|
limit = -1
|
||||||
|
delta = -1
|
||||||
|
|
||||||
|
x = array_ops.expand_dims(x, -1)
|
||||||
|
return np_utils.tensor_to_ndarray(
|
||||||
|
math_ops.pow(
|
||||||
|
x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype)))
|
||||||
|
|
||||||
|
|
||||||
|
@np_utils.np_doc(np.ix_)
|
||||||
|
def ix_(*args): # pylint: disable=missing-docstring
|
||||||
|
n = len(args)
|
||||||
|
output = []
|
||||||
|
for i, a in enumerate(args):
|
||||||
|
a = asarray(a).data
|
||||||
|
a_rank = array_ops.rank(a)
|
||||||
|
a_rank_temp = np_utils.get_static_value(a_rank)
|
||||||
|
if a_rank_temp is not None:
|
||||||
|
a_rank = a_rank_temp
|
||||||
|
if a_rank != 1:
|
||||||
|
raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
|
||||||
|
i, a_rank))
|
||||||
|
else:
|
||||||
|
control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])
|
||||||
|
|
||||||
|
new_shape = [1] * n
|
||||||
|
new_shape[i] = -1
|
||||||
|
dtype = a.dtype
|
||||||
|
if dtype == dtypes.bool:
|
||||||
|
output.append(
|
||||||
|
np_utils.tensor_to_ndarray(
|
||||||
|
array_ops.reshape(nonzero(a)[0].data, new_shape)))
|
||||||
|
elif dtype.is_integer:
|
||||||
|
output.append(np_utils.tensor_to_ndarray(array_ops.reshape(a, new_shape)))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Only integer and bool dtypes are supported, got {}'.format(dtype))
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
|
@ -1104,6 +1104,16 @@ class ArrayManipulationTest(test.TestCase):
|
||||||
run_test([[1, 2]], (3, 2))
|
run_test([[1, 2]], (3, 2))
|
||||||
run_test([[[1, 2]], [[3, 4]], [[5, 6]]], (3, 4, 2))
|
run_test([[[1, 2]], [[3, 4]], [[5, 6]]], (3, 4, 2))
|
||||||
|
|
||||||
|
def testIx_(self):
|
||||||
|
possible_arys = [[True, True], [True, False], [False, False],
|
||||||
|
list(range(5)), np_array_ops.empty(0, dtype=np.int64)]
|
||||||
|
for r in range(len(possible_arys)):
|
||||||
|
for arys in itertools.combinations_with_replacement(possible_arys, r):
|
||||||
|
tnp_ans = np_array_ops.ix_(*arys)
|
||||||
|
onp_ans = np.ix_(*arys)
|
||||||
|
for t, o in zip(tnp_ans, onp_ans):
|
||||||
|
self.match(t, o)
|
||||||
|
|
||||||
def match_shape(self, actual, expected, msg=None):
|
def match_shape(self, actual, expected, msg=None):
|
||||||
if msg:
|
if msg:
|
||||||
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
|
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Mathematical operations."""
|
"""Mathematical operations."""
|
||||||
|
# pylint: disable=g-direct-tensorflow-import
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
@ -95,7 +97,7 @@ def multiply(x1, x2):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.true_divide)
|
@np_utils.np_doc(np.true_divide)
|
||||||
def true_divide(x1, x2):
|
def true_divide(x1, x2): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def _avoid_float64(x1, x2):
|
def _avoid_float64(x1, x2):
|
||||||
if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64):
|
if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64):
|
||||||
|
@ -122,7 +124,7 @@ divide = true_divide
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.floor_divide)
|
@np_utils.np_doc(np.floor_divide)
|
||||||
def floor_divide(x1, x2):
|
def floor_divide(x1, x2): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(x1, x2):
|
def f(x1, x2):
|
||||||
if x1.dtype == dtypes.bool:
|
if x1.dtype == dtypes.bool:
|
||||||
|
@ -135,7 +137,7 @@ def floor_divide(x1, x2):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.mod)
|
@np_utils.np_doc(np.mod)
|
||||||
def mod(x1, x2):
|
def mod(x1, x2): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(x1, x2):
|
def f(x1, x2):
|
||||||
if x1.dtype == dtypes.bool:
|
if x1.dtype == dtypes.bool:
|
||||||
|
@ -219,7 +221,7 @@ def tensordot(a, b, axes=2):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc_only(np.inner)
|
@np_utils.np_doc_only(np.inner)
|
||||||
def inner(a, b):
|
def inner(a, b): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(a, b):
|
def f(a, b):
|
||||||
return np_utils.cond(
|
return np_utils.cond(
|
||||||
|
@ -328,7 +330,7 @@ def nextafter(x1, x2):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.heaviside)
|
@np_utils.np_doc(np.heaviside)
|
||||||
def heaviside(x1, x2):
|
def heaviside(x1, x2): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(x1, x2):
|
def f(x1, x2):
|
||||||
return array_ops.where_v2(
|
return array_ops.where_v2(
|
||||||
|
@ -347,7 +349,7 @@ def hypot(x1, x2):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.kron)
|
@np_utils.np_doc(np.kron)
|
||||||
def kron(a, b):
|
def kron(a, b): # pylint: disable=missing-function-docstring
|
||||||
# pylint: disable=protected-access,g-complex-comprehension
|
# pylint: disable=protected-access,g-complex-comprehension
|
||||||
a, b = np_array_ops._promote_dtype(a, b)
|
a, b = np_array_ops._promote_dtype(a, b)
|
||||||
ndim = max(a.ndim, b.ndim)
|
ndim = max(a.ndim, b.ndim)
|
||||||
|
@ -392,7 +394,7 @@ def logaddexp2(x1, x2):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.polyval)
|
@np_utils.np_doc(np.polyval)
|
||||||
def polyval(p, x):
|
def polyval(p, x): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(p, x):
|
def f(p, x):
|
||||||
if p.shape.rank == 0:
|
if p.shape.rank == 0:
|
||||||
|
@ -433,9 +435,9 @@ def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
|
||||||
isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
|
isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
|
||||||
|
|
||||||
|
|
||||||
def _tf_gcd(x1, x2):
|
def _tf_gcd(x1, x2): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def _gcd_cond_fn(x1, x2):
|
def _gcd_cond_fn(_, x2):
|
||||||
return math_ops.reduce_any(x2 != 0)
|
return math_ops.reduce_any(x2 != 0)
|
||||||
|
|
||||||
def _gcd_body_fn(x1, x2):
|
def _gcd_body_fn(x1, x2):
|
||||||
|
@ -455,20 +457,20 @@ def _tf_gcd(x1, x2):
|
||||||
shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
|
shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
|
||||||
x1 = array_ops.broadcast_to(x1, shape)
|
x1 = array_ops.broadcast_to(x1, shape)
|
||||||
x2 = array_ops.broadcast_to(x2, shape)
|
x2 = array_ops.broadcast_to(x2, shape)
|
||||||
gcd, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
|
value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
|
||||||
(math_ops.abs(x1), math_ops.abs(x2)))
|
(math_ops.abs(x1), math_ops.abs(x2)))
|
||||||
return gcd
|
return value
|
||||||
|
|
||||||
|
|
||||||
# Note that np.gcd may not be present in some supported versions of numpy.
|
# Note that np.gcd may not be present in some supported versions of numpy.
|
||||||
@np_utils.np_doc(None, np_fun_name="gcd")
|
@np_utils.np_doc(None, np_fun_name='gcd')
|
||||||
def gcd(x1, x2):
|
def gcd(x1, x2):
|
||||||
return _bin_op(_tf_gcd, x1, x2)
|
return _bin_op(_tf_gcd, x1, x2)
|
||||||
|
|
||||||
|
|
||||||
# Note that np.lcm may not be present in some supported versions of numpy.
|
# Note that np.lcm may not be present in some supported versions of numpy.
|
||||||
@np_utils.np_doc(None, np_fun_name="lcm")
|
@np_utils.np_doc(None, np_fun_name='lcm')
|
||||||
def lcm(x1, x2):
|
def lcm(x1, x2): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(x1, x2):
|
def f(x1, x2):
|
||||||
d = _tf_gcd(x1, x2)
|
d = _tf_gcd(x1, x2)
|
||||||
|
@ -482,7 +484,7 @@ def lcm(x1, x2):
|
||||||
return _bin_op(f, x1, x2)
|
return _bin_op(f, x1, x2)
|
||||||
|
|
||||||
|
|
||||||
def _bitwise_binary_op(tf_fn, x1, x2):
|
def _bitwise_binary_op(tf_fn, x1, x2): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(x1, x2):
|
def f(x1, x2):
|
||||||
is_bool = (x1.dtype == dtypes.bool)
|
is_bool = (x1.dtype == dtypes.bool)
|
||||||
|
@ -691,7 +693,7 @@ _tf_float_types = [
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.angle)
|
@np_utils.np_doc(np.angle)
|
||||||
def angle(z, deg=False):
|
def angle(z, deg=False): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(x):
|
def f(x):
|
||||||
if x.dtype in _tf_float_types:
|
if x.dtype in _tf_float_types:
|
||||||
|
@ -861,7 +863,7 @@ def square(x):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.diff)
|
@np_utils.np_doc(np.diff)
|
||||||
def diff(a, n=1, axis=-1):
|
def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(a):
|
def f(a):
|
||||||
nd = a.shape.rank
|
nd = a.shape.rank
|
||||||
|
@ -1059,7 +1061,7 @@ def concatenate(arys, axis=0):
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc_only(np.tile)
|
@np_utils.np_doc_only(np.tile)
|
||||||
def tile(a, reps):
|
def tile(a, reps): # pylint: disable=missing-function-docstring
|
||||||
a = np_array_ops.array(a).data
|
a = np_array_ops.array(a).data
|
||||||
reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]).data
|
reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]).data
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Utility functions for internal use."""
|
"""Utility functions for internal use."""
|
||||||
|
# pylint: disable=g-direct-tensorflow-import
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
Loading…
Reference in New Issue