tf.numpy: Add more APIs.
Also fix some lint errors. PiperOrigin-RevId: 314862020 Change-Id: I247619fa738f8103e1d27b7e3076785726c9c52b
This commit is contained in:
parent
ca8c3462f9
commit
370dd56ff2
|
@ -13,10 +13,13 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Common array methods."""
|
||||
# pylint: disable=g-direct-tensorflow-import
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import math
|
||||
import numpy as np
|
||||
import six
|
||||
|
@ -25,6 +28,7 @@ from tensorflow.python.framework import constant_op
|
|||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
@ -660,8 +664,8 @@ def _reduce(tf_fn,
|
|||
dtype: (optional) the dtype of the result.
|
||||
keepdims: (optional) whether to keep the reduced dimension(s).
|
||||
promote_int: how to promote integer and bool inputs. There are three
|
||||
choices: (1) _TO_INT64: always promote them to int64 or uint64; (2)
|
||||
_TO_FLOAT: always promote them to a float type (determined by
|
||||
choices. (1) `_TO_INT64` always promotes them to int64 or uint64; (2)
|
||||
`_TO_FLOAT` always promotes them to a float type (determined by
|
||||
dtypes.default_float_type); (3) None: don't promote.
|
||||
tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
|
||||
only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
|
||||
|
@ -766,48 +770,60 @@ def amin(a, axis=None, keepdims=None):
|
|||
preserve_bool=True)
|
||||
|
||||
|
||||
# TODO(wangpeng): Remove this workaround once b/157232284 is fixed
|
||||
def _reduce_variance_complex(input_tensor, axis, keepdims):
|
||||
f = functools.partial(math_ops.reduce_variance, axis=axis, keepdims=keepdims)
|
||||
return f(math_ops.real(input_tensor)) + f(math_ops.imag(input_tensor))
|
||||
|
||||
|
||||
# TODO(wangpeng): Remove this workaround once b/157232284 is fixed
|
||||
def _reduce_std_complex(input_tensor, axis, keepdims):
|
||||
y = _reduce_variance_complex(
|
||||
input_tensor=input_tensor, axis=axis, keepdims=keepdims)
|
||||
return math_ops.sqrt(y)
|
||||
|
||||
|
||||
@np_utils.np_doc(np.var)
|
||||
def var(a, axis=None, keepdims=None):
|
||||
def var(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(input_tensor, axis, keepdims):
|
||||
if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
|
||||
# A workaround for b/157232284
|
||||
fn = _reduce_variance_complex
|
||||
else:
|
||||
fn = math_ops.reduce_variance
|
||||
return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims)
|
||||
|
||||
return _reduce(
|
||||
math_ops.reduce_variance,
|
||||
a,
|
||||
axis=axis,
|
||||
dtype=None,
|
||||
keepdims=keepdims,
|
||||
promote_int=_TO_FLOAT)
|
||||
f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT)
|
||||
|
||||
|
||||
@np_utils.np_doc(np.std)
|
||||
def std(a, axis=None, keepdims=None):
|
||||
def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(input_tensor, axis, keepdims):
|
||||
if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
|
||||
# A workaround for b/157232284
|
||||
fn = _reduce_std_complex
|
||||
else:
|
||||
fn = math_ops.reduce_std
|
||||
return fn(input_tensor=input_tensor, axis=axis, keepdims=keepdims)
|
||||
|
||||
return _reduce(
|
||||
math_ops.reduce_std,
|
||||
a,
|
||||
axis=axis,
|
||||
dtype=None,
|
||||
keepdims=keepdims,
|
||||
promote_int=_TO_FLOAT)
|
||||
f, a, axis=axis, dtype=None, keepdims=keepdims, promote_int=_TO_FLOAT)
|
||||
|
||||
|
||||
def ravel(a):
|
||||
"""Flattens `a` into a 1-d array.
|
||||
|
||||
If `a` is already a 1-d ndarray it is returned as is.
|
||||
|
||||
Uses `tf.reshape`.
|
||||
|
||||
Args:
|
||||
a: array_like. Could be an ndarray, a Tensor or any object that can be
|
||||
converted to a Tensor using `tf.convert_to_tensor`.
|
||||
|
||||
Returns:
|
||||
A 1-d ndarray.
|
||||
"""
|
||||
@np_utils.np_doc(np.ravel)
|
||||
def ravel(a): # pylint: disable=missing-docstring
|
||||
a = asarray(a)
|
||||
if a.ndim == 1:
|
||||
return a
|
||||
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
|
||||
|
||||
|
||||
setattr(np_arrays.ndarray, 'ravel', ravel)
|
||||
|
||||
|
||||
def real(val):
|
||||
"""Returns real parts of all elements in `a`.
|
||||
|
||||
|
@ -827,10 +843,32 @@ def real(val):
|
|||
|
||||
|
||||
@np_utils.np_doc(np.repeat)
|
||||
def repeat(a, repeats, axis=None):
|
||||
def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
|
||||
a = asarray(a).data
|
||||
original_shape = a._shape_as_list() # pylint: disable=protected-access
|
||||
# Best effort recovery of the shape.
|
||||
if original_shape is not None and None not in original_shape:
|
||||
if not original_shape:
|
||||
original_shape = (repeats,)
|
||||
else:
|
||||
repeats_np = np.ravel(np.array(repeats))
|
||||
if repeats_np.size == 1:
|
||||
repeats_np = repeats_np.item()
|
||||
if axis is None:
|
||||
original_shape = (repeats_np * np.prod(original_shape),)
|
||||
else:
|
||||
original_shape[axis] = repeats_np * original_shape[axis]
|
||||
else:
|
||||
if axis is None:
|
||||
original_shape = (repeats_np.sum(),)
|
||||
else:
|
||||
original_shape[axis] = repeats_np.sum()
|
||||
|
||||
repeats = asarray(repeats).data
|
||||
return np_utils.tensor_to_ndarray(array_ops.repeat(a, repeats, axis))
|
||||
result = array_ops.repeat(a, repeats, axis)
|
||||
result.set_shape(original_shape)
|
||||
|
||||
return np_utils.tensor_to_ndarray(result)
|
||||
|
||||
|
||||
@np_utils.np_doc(np.around)
|
||||
|
@ -838,11 +876,14 @@ def around(a, decimals=0): # pylint: disable=missing-docstring
|
|||
a = asarray(a)
|
||||
dtype = a.dtype
|
||||
factor = math.pow(10, decimals)
|
||||
# Use float as the working dtype instead of a.dtype, because a.dtype can be
|
||||
# integer and `decimals` can be negative.
|
||||
float_dtype = np_dtypes.default_float_type()
|
||||
a = a.astype(float_dtype).data
|
||||
factor = math_ops.cast(factor, float_dtype)
|
||||
if np.issubdtype(dtype, np.inexact):
|
||||
factor = math_ops.cast(factor, dtype)
|
||||
else:
|
||||
# Use float as the working dtype when a.dtype is exact (e.g. integer),
|
||||
# because `decimals` can be negative.
|
||||
float_dtype = dtypes.default_float_type()
|
||||
a = a.astype(float_dtype).data
|
||||
factor = math_ops.cast(factor, float_dtype)
|
||||
a = math_ops.multiply(a, factor)
|
||||
a = math_ops.round(a)
|
||||
a = math_ops.divide(a, factor)
|
||||
|
@ -853,21 +894,36 @@ round_ = around
|
|||
setattr(np_arrays.ndarray, '__round__', around)
|
||||
|
||||
|
||||
def reshape(a, newshape):
|
||||
"""Reshapes an array.
|
||||
@np_utils.np_doc(np.reshape)
|
||||
def reshape(a, newshape, order='C'):
|
||||
"""order argument can only b 'C' or 'F'."""
|
||||
if order not in {'C', 'F'}:
|
||||
raise ValueError('Unsupported order argument {}'.format(order))
|
||||
|
||||
Args:
|
||||
a: array_like. Could be an ndarray, a Tensor or any object that can be
|
||||
converted to a Tensor using `tf.convert_to_tensor`.
|
||||
newshape: 0-d or 1-d array_like.
|
||||
|
||||
Returns:
|
||||
An ndarray with the contents and dtype of `a` and shape `newshape`.
|
||||
"""
|
||||
a = asarray(a)
|
||||
if isinstance(newshape, np_arrays.ndarray):
|
||||
newshape = newshape.data
|
||||
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, newshape))
|
||||
if isinstance(newshape, int):
|
||||
newshape = [newshape]
|
||||
|
||||
if order == 'F':
|
||||
r = array_ops.transpose(
|
||||
array_ops.reshape(array_ops.transpose(a.data), newshape[::-1]))
|
||||
else:
|
||||
r = array_ops.reshape(a.data, newshape)
|
||||
|
||||
return np_utils.tensor_to_ndarray(r)
|
||||
|
||||
|
||||
def _reshape_method_wrapper(a, *newshape, **kwargs):
|
||||
order = kwargs.pop('order', 'C')
|
||||
if kwargs:
|
||||
raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))
|
||||
|
||||
if len(newshape) == 1 and not isinstance(newshape[0], int):
|
||||
newshape = newshape[0]
|
||||
|
||||
return reshape(a, newshape, order=order)
|
||||
|
||||
|
||||
def expand_dims(a, axis):
|
||||
|
@ -1096,29 +1152,31 @@ def pad(ary, pad_width, mode, constant_values=0):
|
|||
constant_values=constant_values))
|
||||
|
||||
|
||||
def take(a, indices, axis=None):
|
||||
"""Take elements from an array along an axis.
|
||||
@np_utils.np_doc(np.take)
|
||||
def take(a, indices, axis=None, out=None, mode='clip'):
|
||||
"""out argument is not supported, and default mode is clip."""
|
||||
if out is not None:
|
||||
raise ValueError('out argument is not supported in take.')
|
||||
|
||||
See https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html for
|
||||
description.
|
||||
if mode not in {'raise', 'clip', 'wrap'}:
|
||||
raise ValueError("Invalid mode '{}' for take".format(mode))
|
||||
|
||||
Args:
|
||||
a: array_like. The source array.
|
||||
indices: array_like. The indices of the values to extract.
|
||||
axis: int, optional. The axis over which to select values. By default, the
|
||||
flattened input array is used.
|
||||
a = asarray(a).data
|
||||
indices = asarray(indices).data
|
||||
|
||||
Returns:
|
||||
A ndarray. The returned array has the same type as `a`.
|
||||
"""
|
||||
a = asarray(a)
|
||||
indices = asarray(indices)
|
||||
a = a.data
|
||||
if axis is None:
|
||||
a = array_ops.reshape(a, [-1])
|
||||
axis = 0
|
||||
return np_utils.tensor_to_ndarray(
|
||||
array_ops.gather(a, indices.data, axis=axis))
|
||||
|
||||
axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
|
||||
if mode == 'clip':
|
||||
indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
|
||||
elif mode == 'wrap':
|
||||
indices = math_ops.floormod(indices, axis_size)
|
||||
else:
|
||||
raise ValueError("The 'raise' mode to take is not supported.")
|
||||
|
||||
return np_utils.tensor_to_ndarray(array_ops.gather(a, indices, axis=axis))
|
||||
|
||||
|
||||
@np_utils.np_doc_only(np.where)
|
||||
|
@ -1134,6 +1192,23 @@ def where(condition, x=None, y=None):
|
|||
raise ValueError('Both x and y must be ndarrays, or both must be None.')
|
||||
|
||||
|
||||
@np_utils.np_doc(np.select)
|
||||
def select(condlist, choicelist, default=0): # pylint: disable=missing-docstring
|
||||
if len(condlist) != len(choicelist):
|
||||
msg = 'condlist must have length equal to choicelist ({} vs {})'
|
||||
raise ValueError(msg.format(len(condlist), len(choicelist)))
|
||||
if not condlist:
|
||||
raise ValueError('condlist must be non-empty')
|
||||
choices = _promote_dtype(default, *choicelist)
|
||||
choicelist = choices[1:]
|
||||
output = choices[0]
|
||||
# The traversal is in reverse order so we can return the first value in
|
||||
# choicelist where condlist is True.
|
||||
for cond, choice in zip(condlist[::-1], choicelist[::-1]):
|
||||
output = where(cond, choice, output)
|
||||
return output
|
||||
|
||||
|
||||
def shape(a):
|
||||
"""Return the shape of an array.
|
||||
|
||||
|
@ -1413,3 +1488,130 @@ def triu(m, k=0): # pylint: disable=missing-docstring
|
|||
return np_utils.tensor_to_ndarray(
|
||||
array_ops.where_v2(
|
||||
array_ops.broadcast_to(mask, array_ops.shape(m)), z, m))
|
||||
|
||||
|
||||
@np_utils.np_doc(np.flip)
|
||||
def flip(m, axis=None): # pylint: disable=missing-docstring
|
||||
m = asarray(m).data
|
||||
|
||||
if axis is None:
|
||||
return np_utils.tensor_to_ndarray(
|
||||
array_ops.reverse(m, math_ops.range(array_ops.rank(m))))
|
||||
|
||||
axis = np_utils._canonicalize_axis(axis, array_ops.rank(m)) # pylint: disable=protected-access
|
||||
|
||||
return np_utils.tensor_to_ndarray(array_ops.reverse(m, [axis]))
|
||||
|
||||
|
||||
@np_utils.np_doc(np.flipud)
|
||||
def flipud(m): # pylint: disable=missing-docstring
|
||||
return flip(m, 0)
|
||||
|
||||
|
||||
@np_utils.np_doc(np.fliplr)
|
||||
def fliplr(m): # pylint: disable=missing-docstring
|
||||
return flip(m, 1)
|
||||
|
||||
|
||||
@np_utils.np_doc(np.roll)
|
||||
def roll(a, shift, axis=None): # pylint: disable=missing-docstring
|
||||
a = asarray(a).data
|
||||
|
||||
if axis is not None:
|
||||
return np_utils.tensor_to_ndarray(array_ops.roll(a, shift, axis))
|
||||
|
||||
# If axis is None, the roll happens as a 1-d tensor.
|
||||
original_shape = array_ops.shape(a)
|
||||
a = array_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
|
||||
return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape))
|
||||
|
||||
|
||||
@np_utils.np_doc(np.rot90)
|
||||
def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring
|
||||
m_rank = array_ops.rank(m)
|
||||
ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank) # pylint: disable=protected-access
|
||||
|
||||
k = k % 4
|
||||
if k == 0:
|
||||
return m
|
||||
elif k == 2:
|
||||
return flip(flip(m, ax1), ax2)
|
||||
else:
|
||||
perm = math_ops.range(m_rank)
|
||||
perm = array_ops.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1])
|
||||
|
||||
if k == 1:
|
||||
return transpose(flip(m, ax2), perm)
|
||||
else:
|
||||
return flip(transpose(m, perm), ax2)
|
||||
|
||||
|
||||
@np_utils.np_doc(np.vander)
|
||||
def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name
|
||||
x = asarray(x).data
|
||||
|
||||
x_shape = array_ops.shape(x)
|
||||
N = N or x_shape[0]
|
||||
|
||||
N_temp = np_utils.get_static_value(N) # pylint: disable=invalid-name
|
||||
if N_temp is not None:
|
||||
N = N_temp
|
||||
if N < 0:
|
||||
raise ValueError('N must be nonnegative')
|
||||
else:
|
||||
control_flow_ops.Assert(N >= 0, [N])
|
||||
|
||||
rank = array_ops.rank(x)
|
||||
rank_temp = np_utils.get_static_value(rank)
|
||||
if rank_temp is not None:
|
||||
rank = rank_temp
|
||||
if rank != 1:
|
||||
raise ValueError('x must be a one-dimensional array')
|
||||
else:
|
||||
control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])
|
||||
|
||||
if increasing:
|
||||
start = 0
|
||||
limit = N
|
||||
delta = 1
|
||||
else:
|
||||
start = N - 1
|
||||
limit = -1
|
||||
delta = -1
|
||||
|
||||
x = array_ops.expand_dims(x, -1)
|
||||
return np_utils.tensor_to_ndarray(
|
||||
math_ops.pow(
|
||||
x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype)))
|
||||
|
||||
|
||||
@np_utils.np_doc(np.ix_)
|
||||
def ix_(*args): # pylint: disable=missing-docstring
|
||||
n = len(args)
|
||||
output = []
|
||||
for i, a in enumerate(args):
|
||||
a = asarray(a).data
|
||||
a_rank = array_ops.rank(a)
|
||||
a_rank_temp = np_utils.get_static_value(a_rank)
|
||||
if a_rank_temp is not None:
|
||||
a_rank = a_rank_temp
|
||||
if a_rank != 1:
|
||||
raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
|
||||
i, a_rank))
|
||||
else:
|
||||
control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])
|
||||
|
||||
new_shape = [1] * n
|
||||
new_shape[i] = -1
|
||||
dtype = a.dtype
|
||||
if dtype == dtypes.bool:
|
||||
output.append(
|
||||
np_utils.tensor_to_ndarray(
|
||||
array_ops.reshape(nonzero(a)[0].data, new_shape)))
|
||||
elif dtype.is_integer:
|
||||
output.append(np_utils.tensor_to_ndarray(array_ops.reshape(a, new_shape)))
|
||||
else:
|
||||
raise ValueError(
|
||||
'Only integer and bool dtypes are supported, got {}'.format(dtype))
|
||||
|
||||
return output
|
||||
|
|
|
@ -1104,6 +1104,16 @@ class ArrayManipulationTest(test.TestCase):
|
|||
run_test([[1, 2]], (3, 2))
|
||||
run_test([[[1, 2]], [[3, 4]], [[5, 6]]], (3, 4, 2))
|
||||
|
||||
def testIx_(self):
|
||||
possible_arys = [[True, True], [True, False], [False, False],
|
||||
list(range(5)), np_array_ops.empty(0, dtype=np.int64)]
|
||||
for r in range(len(possible_arys)):
|
||||
for arys in itertools.combinations_with_replacement(possible_arys, r):
|
||||
tnp_ans = np_array_ops.ix_(*arys)
|
||||
onp_ans = np.ix_(*arys)
|
||||
for t, o in zip(tnp_ans, onp_ans):
|
||||
self.match(t, o)
|
||||
|
||||
def match_shape(self, actual, expected, msg=None):
|
||||
if msg:
|
||||
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Mathematical operations."""
|
||||
# pylint: disable=g-direct-tensorflow-import
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
@ -95,7 +97,7 @@ def multiply(x1, x2):
|
|||
|
||||
|
||||
@np_utils.np_doc(np.true_divide)
|
||||
def true_divide(x1, x2):
|
||||
def true_divide(x1, x2): # pylint: disable=missing-function-docstring
|
||||
|
||||
def _avoid_float64(x1, x2):
|
||||
if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64):
|
||||
|
@ -122,7 +124,7 @@ divide = true_divide
|
|||
|
||||
|
||||
@np_utils.np_doc(np.floor_divide)
|
||||
def floor_divide(x1, x2):
|
||||
def floor_divide(x1, x2): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(x1, x2):
|
||||
if x1.dtype == dtypes.bool:
|
||||
|
@ -135,7 +137,7 @@ def floor_divide(x1, x2):
|
|||
|
||||
|
||||
@np_utils.np_doc(np.mod)
|
||||
def mod(x1, x2):
|
||||
def mod(x1, x2): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(x1, x2):
|
||||
if x1.dtype == dtypes.bool:
|
||||
|
@ -219,7 +221,7 @@ def tensordot(a, b, axes=2):
|
|||
|
||||
|
||||
@np_utils.np_doc_only(np.inner)
|
||||
def inner(a, b):
|
||||
def inner(a, b): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(a, b):
|
||||
return np_utils.cond(
|
||||
|
@ -328,7 +330,7 @@ def nextafter(x1, x2):
|
|||
|
||||
|
||||
@np_utils.np_doc(np.heaviside)
|
||||
def heaviside(x1, x2):
|
||||
def heaviside(x1, x2): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(x1, x2):
|
||||
return array_ops.where_v2(
|
||||
|
@ -347,7 +349,7 @@ def hypot(x1, x2):
|
|||
|
||||
|
||||
@np_utils.np_doc(np.kron)
|
||||
def kron(a, b):
|
||||
def kron(a, b): # pylint: disable=missing-function-docstring
|
||||
# pylint: disable=protected-access,g-complex-comprehension
|
||||
a, b = np_array_ops._promote_dtype(a, b)
|
||||
ndim = max(a.ndim, b.ndim)
|
||||
|
@ -392,7 +394,7 @@ def logaddexp2(x1, x2):
|
|||
|
||||
|
||||
@np_utils.np_doc(np.polyval)
|
||||
def polyval(p, x):
|
||||
def polyval(p, x): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(p, x):
|
||||
if p.shape.rank == 0:
|
||||
|
@ -433,9 +435,9 @@ def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
|
|||
isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
|
||||
|
||||
|
||||
def _tf_gcd(x1, x2):
|
||||
def _tf_gcd(x1, x2): # pylint: disable=missing-function-docstring
|
||||
|
||||
def _gcd_cond_fn(x1, x2):
|
||||
def _gcd_cond_fn(_, x2):
|
||||
return math_ops.reduce_any(x2 != 0)
|
||||
|
||||
def _gcd_body_fn(x1, x2):
|
||||
|
@ -455,20 +457,20 @@ def _tf_gcd(x1, x2):
|
|||
shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
|
||||
x1 = array_ops.broadcast_to(x1, shape)
|
||||
x2 = array_ops.broadcast_to(x2, shape)
|
||||
gcd, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
|
||||
(math_ops.abs(x1), math_ops.abs(x2)))
|
||||
return gcd
|
||||
value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
|
||||
(math_ops.abs(x1), math_ops.abs(x2)))
|
||||
return value
|
||||
|
||||
|
||||
# Note that np.gcd may not be present in some supported versions of numpy.
|
||||
@np_utils.np_doc(None, np_fun_name="gcd")
|
||||
@np_utils.np_doc(None, np_fun_name='gcd')
|
||||
def gcd(x1, x2):
|
||||
return _bin_op(_tf_gcd, x1, x2)
|
||||
|
||||
|
||||
# Note that np.lcm may not be present in some supported versions of numpy.
|
||||
@np_utils.np_doc(None, np_fun_name="lcm")
|
||||
def lcm(x1, x2):
|
||||
@np_utils.np_doc(None, np_fun_name='lcm')
|
||||
def lcm(x1, x2): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(x1, x2):
|
||||
d = _tf_gcd(x1, x2)
|
||||
|
@ -482,7 +484,7 @@ def lcm(x1, x2):
|
|||
return _bin_op(f, x1, x2)
|
||||
|
||||
|
||||
def _bitwise_binary_op(tf_fn, x1, x2):
|
||||
def _bitwise_binary_op(tf_fn, x1, x2): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(x1, x2):
|
||||
is_bool = (x1.dtype == dtypes.bool)
|
||||
|
@ -691,7 +693,7 @@ _tf_float_types = [
|
|||
|
||||
|
||||
@np_utils.np_doc(np.angle)
|
||||
def angle(z, deg=False):
|
||||
def angle(z, deg=False): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(x):
|
||||
if x.dtype in _tf_float_types:
|
||||
|
@ -861,7 +863,7 @@ def square(x):
|
|||
|
||||
|
||||
@np_utils.np_doc(np.diff)
|
||||
def diff(a, n=1, axis=-1):
|
||||
def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
|
||||
|
||||
def f(a):
|
||||
nd = a.shape.rank
|
||||
|
@ -1059,7 +1061,7 @@ def concatenate(arys, axis=0):
|
|||
|
||||
|
||||
@np_utils.np_doc_only(np.tile)
|
||||
def tile(a, reps):
|
||||
def tile(a, reps): # pylint: disable=missing-function-docstring
|
||||
a = np_array_ops.array(a).data
|
||||
reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]).data
|
||||
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility functions for internal use."""
|
||||
# pylint: disable=g-direct-tensorflow-import
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
|
Loading…
Reference in New Issue