tf.numpy: Change a bunch of ops to handle unknown shapes.
Fix logic in sort_ops to handle 1D values. PiperOrigin-RevId: 316620637 Change-Id: Iedc2ba8aad7673bbe210661bb741bf0660f047aa
This commit is contained in:
parent
e74010b4e8
commit
42a734170d
@ -223,9 +223,10 @@ def full(shape, fill_value, dtype=None): # pylint: disable=redefined-outer-name
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if `fill_value` can not be broadcast to shape `shape`.
|
ValueError: if `fill_value` can not be broadcast to shape `shape`.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(shape, np_arrays.ndarray):
|
||||||
|
shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
|
||||||
|
shape = atleast_1d(shape).data
|
||||||
fill_value = asarray(fill_value, dtype=dtype)
|
fill_value = asarray(fill_value, dtype=dtype)
|
||||||
if np_utils.isscalar(shape):
|
|
||||||
shape = array_ops.reshape(shape, [1])
|
|
||||||
return np_arrays.tensor_to_ndarray(
|
return np_arrays.tensor_to_ndarray(
|
||||||
array_ops.broadcast_to(fill_value.data, shape))
|
array_ops.broadcast_to(fill_value.data, shape))
|
||||||
|
|
||||||
@ -808,16 +809,21 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint: d
|
|||||||
@np_utils.np_doc(np.std)
|
@np_utils.np_doc(np.std)
|
||||||
def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
|
def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
|
||||||
return _reduce(
|
return _reduce(
|
||||||
math_ops.reduce_std, a, axis=axis, dtype=None, keepdims=keepdims,
|
math_ops.reduce_std,
|
||||||
|
a,
|
||||||
|
axis=axis,
|
||||||
|
dtype=None,
|
||||||
|
keepdims=keepdims,
|
||||||
promote_int=_TO_FLOAT)
|
promote_int=_TO_FLOAT)
|
||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.ravel)
|
@np_utils.np_doc(np.ravel)
|
||||||
def ravel(a): # pylint: disable=missing-docstring
|
def ravel(a): # pylint: disable=missing-docstring
|
||||||
a = asarray(a)
|
a = asarray(a)
|
||||||
if a.ndim == 1:
|
out = np_utils.cond(
|
||||||
return a
|
math_ops.equal(a.ndim, 1), lambda: a.data,
|
||||||
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
|
lambda: array_ops.reshape(a.data, [-1]))
|
||||||
|
return np_utils.tensor_to_ndarray(out)
|
||||||
|
|
||||||
|
|
||||||
setattr(np_arrays.ndarray, 'ravel', ravel)
|
setattr(np_arrays.ndarray, 'ravel', ravel)
|
||||||
@ -846,7 +852,8 @@ def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
|
|||||||
a = asarray(a).data
|
a = asarray(a).data
|
||||||
original_shape = a._shape_as_list() # pylint: disable=protected-access
|
original_shape = a._shape_as_list() # pylint: disable=protected-access
|
||||||
# Best effort recovery of the shape.
|
# Best effort recovery of the shape.
|
||||||
if original_shape is not None and None not in original_shape:
|
known_shape = original_shape is not None and None not in original_shape
|
||||||
|
if known_shape:
|
||||||
if not original_shape:
|
if not original_shape:
|
||||||
original_shape = (repeats,)
|
original_shape = (repeats,)
|
||||||
else:
|
else:
|
||||||
@ -865,7 +872,8 @@ def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
|
|||||||
|
|
||||||
repeats = asarray(repeats).data
|
repeats = asarray(repeats).data
|
||||||
result = array_ops.repeat(a, repeats, axis)
|
result = array_ops.repeat(a, repeats, axis)
|
||||||
result.set_shape(original_shape)
|
if known_shape:
|
||||||
|
result.set_shape(original_shape)
|
||||||
|
|
||||||
return np_utils.tensor_to_ndarray(result)
|
return np_utils.tensor_to_ndarray(result)
|
||||||
|
|
||||||
@ -1287,7 +1295,13 @@ def broadcast_to(array, shape): # pylint: disable=redefined-outer-name
|
|||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.stack)
|
@np_utils.np_doc(np.stack)
|
||||||
def stack(arrays, axis=0):
|
def stack(arrays, axis=0): # pylint: disable=missing-function-docstring
|
||||||
|
if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
|
||||||
|
arrays = asarray(arrays)
|
||||||
|
if axis == 0:
|
||||||
|
return arrays
|
||||||
|
else:
|
||||||
|
return swapaxes(arrays, 0, axis)
|
||||||
arrays = _promote_dtype(*arrays) # pylint: disable=protected-access
|
arrays = _promote_dtype(*arrays) # pylint: disable=protected-access
|
||||||
unwrapped_arrays = [
|
unwrapped_arrays = [
|
||||||
a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
|
a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
|
||||||
@ -1450,6 +1464,8 @@ def tri(N, M=None, k=0, dtype=None): # pylint: disable=invalid-name,missing-doc
|
|||||||
@np_utils.np_doc(np.tril)
|
@np_utils.np_doc(np.tril)
|
||||||
def tril(m, k=0): # pylint: disable=missing-docstring
|
def tril(m, k=0): # pylint: disable=missing-docstring
|
||||||
m = asarray(m).data
|
m = asarray(m).data
|
||||||
|
if m.shape.ndims is None:
|
||||||
|
raise ValueError('Argument to tril should have known rank')
|
||||||
m_shape = m.shape.as_list()
|
m_shape = m.shape.as_list()
|
||||||
|
|
||||||
if len(m_shape) < 2:
|
if len(m_shape) < 2:
|
||||||
@ -1470,6 +1486,8 @@ def tril(m, k=0): # pylint: disable=missing-docstring
|
|||||||
@np_utils.np_doc(np.triu)
|
@np_utils.np_doc(np.triu)
|
||||||
def triu(m, k=0): # pylint: disable=missing-docstring
|
def triu(m, k=0): # pylint: disable=missing-docstring
|
||||||
m = asarray(m).data
|
m = asarray(m).data
|
||||||
|
if m.shape.ndims is None:
|
||||||
|
raise ValueError('Argument to triu should have known rank')
|
||||||
m_shape = m.shape.as_list()
|
m_shape = m.shape.as_list()
|
||||||
|
|
||||||
if len(m_shape) < 2:
|
if len(m_shape) < 2:
|
||||||
|
@ -13,6 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""ndarray class."""
|
"""ndarray class."""
|
||||||
|
|
||||||
|
# pylint: disable=g-direct-tensorflow-import
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -151,13 +154,16 @@ def _slice_helper(tensor, slice_spec, var=None):
|
|||||||
def convert_to_tensor(value, dtype=None, dtype_hint=None):
|
def convert_to_tensor(value, dtype=None, dtype_hint=None):
|
||||||
"""Wrapper over `tf.convert_to_tensor`.
|
"""Wrapper over `tf.convert_to_tensor`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
value: value to convert
|
value: value to convert
|
||||||
dtype: (optional) the type we would like it to be converted to.
|
dtype: (optional) the type we would like it to be converted to.
|
||||||
dtype_hint: (optional) soft preference for the type we would like it to
|
dtype_hint: (optional) soft preference for the type we would like it to be
|
||||||
be converted to. `tf.convert_to_tensor` will attempt to convert value
|
converted to. `tf.convert_to_tensor` will attempt to convert value to this
|
||||||
to this type first, but will not fail if conversion is not possible
|
type first, but will not fail if conversion is not possible falling back
|
||||||
falling back to inferring the type instead.
|
to inferring the type instead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Value converted to tf.Tensor.
|
||||||
"""
|
"""
|
||||||
# A safer version of `tf.convert_to_tensor` to work around b/149876037.
|
# A safer version of `tf.convert_to_tensor` to work around b/149876037.
|
||||||
# TODO(wangpeng): Remove this function once the bug is fixed.
|
# TODO(wangpeng): Remove this function once the bug is fixed.
|
||||||
@ -250,8 +256,12 @@ class ndarray(object): # pylint: disable=invalid-name
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
"""Returns a tuple of array dimensions."""
|
"""Returns a tuple or tf.Tensor of array dimensions."""
|
||||||
return self.data._shape_tuple() # pylint: disable=protected-access
|
shape = self.data.shape
|
||||||
|
if shape.is_fully_defined():
|
||||||
|
return tuple(shape.as_list())
|
||||||
|
else:
|
||||||
|
return array_ops.shape(self.data)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
@ -259,19 +269,30 @@ class ndarray(object): # pylint: disable=invalid-name
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def ndim(self):
|
def ndim(self):
|
||||||
return self.data.shape.ndims
|
ndims = self.data.shape.ndims
|
||||||
|
if ndims is None:
|
||||||
|
return array_ops.rank(self.data)
|
||||||
|
else:
|
||||||
|
return ndims
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def size(self):
|
def size(self):
|
||||||
"""Returns the number of elements in the array."""
|
"""Returns the number of elements in the array."""
|
||||||
return np.prod(self.shape)
|
shape = self.shape
|
||||||
|
if isinstance(shape, ops.Tensor):
|
||||||
|
return array_ops.size(self.data)
|
||||||
|
else:
|
||||||
|
return np.prod(self.shape)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def T(self): # pylint: disable=invalid-name
|
def T(self): # pylint: disable=invalid-name
|
||||||
return self.transpose()
|
return self.transpose()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if self.shape:
|
shape = self.shape
|
||||||
|
if isinstance(shape, ops.Tensor):
|
||||||
|
raise TypeError('len() of symbolic tensor undefined')
|
||||||
|
elif shape:
|
||||||
return self.shape[0]
|
return self.shape[0]
|
||||||
else:
|
else:
|
||||||
raise TypeError('len() of unsized object.')
|
raise TypeError('len() of unsized object.')
|
||||||
@ -320,6 +341,8 @@ class ndarray(object): # pylint: disable=invalid-name
|
|||||||
return tensor_to_ndarray(result_t)
|
return tensor_to_ndarray(result_t)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
if not isinstance(self.data, ops.EagerTensor):
|
||||||
|
raise TypeError('Iteration over symbolic tensor is not allowed')
|
||||||
for i in range(self.shape[0]):
|
for i in range(self.shape[0]):
|
||||||
result_t = self.data[i]
|
result_t = self.data[i]
|
||||||
yield tensor_to_ndarray(result_t)
|
yield tensor_to_ndarray(result_t)
|
||||||
@ -356,6 +379,8 @@ class ndarray(object): # pylint: disable=invalid-name
|
|||||||
ValueError: If the array does not have size 1.
|
ValueError: If the array does not have size 1.
|
||||||
"""
|
"""
|
||||||
# TODO(wangpeng): Handle graph mode
|
# TODO(wangpeng): Handle graph mode
|
||||||
|
if not isinstance(self.data, ops.EagerTensor):
|
||||||
|
raise TypeError('Indexing using symbolic tensor is not allowed')
|
||||||
return np.asscalar(self.data.numpy())
|
return np.asscalar(self.data.numpy())
|
||||||
|
|
||||||
def tolist(self):
|
def tolist(self):
|
||||||
@ -384,5 +409,3 @@ def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
|
|||||||
|
|
||||||
|
|
||||||
ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
|
ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,9 +50,9 @@ def dot(a, b): # pylint: disable=missing-docstring
|
|||||||
math_ops.equal(array_ops.rank(b), 0)),
|
math_ops.equal(array_ops.rank(b), 0)),
|
||||||
lambda: a * b,
|
lambda: a * b,
|
||||||
lambda: np_utils.cond( # pylint: disable=g-long-lambda
|
lambda: np_utils.cond( # pylint: disable=g-long-lambda
|
||||||
math_ops.equal(array_ops.rank(b), 1), lambda: math_ops.tensordot(
|
math_ops.equal(array_ops.rank(b), 1),
|
||||||
a, b, axes=[[-1], [-1]]), lambda: math_ops.tensordot(
|
lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]),
|
||||||
a, b, axes=[[-1], [-2]])))
|
lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]])))
|
||||||
|
|
||||||
return _bin_op(f, a, b)
|
return _bin_op(f, a, b)
|
||||||
|
|
||||||
@ -204,8 +204,8 @@ def matmul(x1, x2): # pylint: disable=missing-docstring
|
|||||||
return np_utils.cond(
|
return np_utils.cond(
|
||||||
math_ops.equal(array_ops.rank(x2), 1),
|
math_ops.equal(array_ops.rank(x2), 1),
|
||||||
lambda: math_ops.tensordot(x1, x2, axes=1),
|
lambda: math_ops.tensordot(x1, x2, axes=1),
|
||||||
lambda: np_utils.cond(
|
lambda: np_utils.cond( # pylint: disable=g-long-lambda
|
||||||
math_ops.equal(array_ops.rank(x1), 1), # pylint: disable=g-long-lambda
|
math_ops.equal(array_ops.rank(x1), 1),
|
||||||
lambda: math_ops.tensordot( # pylint: disable=g-long-lambda
|
lambda: math_ops.tensordot( # pylint: disable=g-long-lambda
|
||||||
x1, x2, axes=[[0], [-2]]),
|
x1, x2, axes=[[0], [-2]]),
|
||||||
lambda: math_ops.matmul(x1, x2)))
|
lambda: math_ops.matmul(x1, x2)))
|
||||||
@ -352,14 +352,30 @@ def hypot(x1, x2):
|
|||||||
def kron(a, b): # pylint: disable=missing-function-docstring
|
def kron(a, b): # pylint: disable=missing-function-docstring
|
||||||
# pylint: disable=protected-access,g-complex-comprehension
|
# pylint: disable=protected-access,g-complex-comprehension
|
||||||
a, b = np_array_ops._promote_dtype(a, b)
|
a, b = np_array_ops._promote_dtype(a, b)
|
||||||
ndim = max(a.ndim, b.ndim)
|
t_a = np_utils.cond(
|
||||||
if a.ndim < ndim:
|
a.ndim < b.ndim,
|
||||||
a = np_array_ops.reshape(a, np_array_ops._pad_left_to(ndim, a.shape))
|
lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda
|
||||||
if b.ndim < ndim:
|
a.data, np_array_ops._pad_left_to(b.ndim, a.shape)),
|
||||||
b = np_array_ops.reshape(b, np_array_ops._pad_left_to(ndim, b.shape))
|
lambda: a.data)
|
||||||
a_reshaped = np_array_ops.reshape(a, [i for d in a.shape for i in (d, 1)])
|
t_b = np_utils.cond(
|
||||||
b_reshaped = np_array_ops.reshape(b, [i for d in b.shape for i in (1, d)])
|
b.ndim < a.ndim,
|
||||||
out_shape = tuple(np.multiply(a.shape, b.shape))
|
lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda
|
||||||
|
b.data, np_array_ops._pad_left_to(a.ndim, b.shape)),
|
||||||
|
lambda: b.data)
|
||||||
|
|
||||||
|
def _make_shape(shape, prepend):
|
||||||
|
ones = array_ops.ones_like(shape)
|
||||||
|
if prepend:
|
||||||
|
shapes = [ones, shape]
|
||||||
|
else:
|
||||||
|
shapes = [shape, ones]
|
||||||
|
return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])
|
||||||
|
|
||||||
|
a_shape = array_ops.shape(t_a)
|
||||||
|
b_shape = array_ops.shape(t_b)
|
||||||
|
a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
|
||||||
|
b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
|
||||||
|
out_shape = a_shape * b_shape
|
||||||
return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
|
return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
|
||||||
|
|
||||||
|
|
||||||
@ -454,7 +470,8 @@ def _tf_gcd(x1, x2): # pylint: disable=missing-function-docstring
|
|||||||
if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
|
if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
|
||||||
not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
|
not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
|
||||||
raise ValueError('Arguments to gcd must be integers.')
|
raise ValueError('Arguments to gcd must be integers.')
|
||||||
shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
|
shape = array_ops.broadcast_dynamic_shape(
|
||||||
|
array_ops.shape(x1), array_ops.shape(x2))
|
||||||
x1 = array_ops.broadcast_to(x1, shape)
|
x1 = array_ops.broadcast_to(x1, shape)
|
||||||
x2 = array_ops.broadcast_to(x2, shape)
|
x2 = array_ops.broadcast_to(x2, shape)
|
||||||
value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
|
value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
|
||||||
@ -607,7 +624,7 @@ def signbit(x):
|
|||||||
|
|
||||||
def f(x):
|
def f(x):
|
||||||
if x.dtype == dtypes.bool:
|
if x.dtype == dtypes.bool:
|
||||||
return array_ops.fill(x.shape, False)
|
return array_ops.fill(array_ops.shape(x), False)
|
||||||
return x < 0
|
return x < 0
|
||||||
|
|
||||||
return _scalar(f, x)
|
return _scalar(f, x)
|
||||||
@ -866,7 +883,11 @@ def square(x):
|
|||||||
def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
|
def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(a):
|
def f(a):
|
||||||
|
# TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
|
||||||
|
# TODO(agarwal): avoid depending on static rank.
|
||||||
nd = a.shape.rank
|
nd = a.shape.rank
|
||||||
|
if nd is None:
|
||||||
|
raise ValueError('diff currently requires known rank for input `a`')
|
||||||
if (axis + nd if axis < 0 else axis) >= nd:
|
if (axis + nd if axis < 0 else axis) >= nd:
|
||||||
raise ValueError('axis %s is out of bounds for array of dimension %s' %
|
raise ValueError('axis %s is out of bounds for array of dimension %s' %
|
||||||
(axis, nd))
|
(axis, nd))
|
||||||
@ -887,8 +908,10 @@ def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
|
|||||||
|
|
||||||
|
|
||||||
def _flip_args(f):
|
def _flip_args(f):
|
||||||
|
|
||||||
def _f(a, b):
|
def _f(a, b):
|
||||||
return f(b, a)
|
return f(b, a)
|
||||||
|
|
||||||
return _f
|
return _f
|
||||||
|
|
||||||
|
|
||||||
@ -910,6 +933,7 @@ setattr(np_arrays.ndarray, '__rtruediv__', _flip_args(true_divide))
|
|||||||
|
|
||||||
|
|
||||||
def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
|
def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
|
||||||
|
"""Helper function for comparision."""
|
||||||
dtype = np_utils.result_type(x1, x2)
|
dtype = np_utils.result_type(x1, x2)
|
||||||
# Cast x1 and x2 to the result_type if needed.
|
# Cast x1 and x2 to the result_type if needed.
|
||||||
x1 = np_array_ops.array(x1, dtype=dtype)
|
x1 = np_array_ops.array(x1, dtype=dtype)
|
||||||
@ -953,12 +977,18 @@ def less_equal(x1, x2):
|
|||||||
|
|
||||||
|
|
||||||
@np_utils.np_doc(np.array_equal)
|
@np_utils.np_doc(np.array_equal)
|
||||||
def array_equal(a1, a2):
|
def array_equal(a1, a2): # pylint: disable=missing-function-docstring
|
||||||
|
|
||||||
def f(a1, a2):
|
def f(x1, x2):
|
||||||
if a1.shape != a2.shape:
|
return np_utils.cond(
|
||||||
return constant_op.constant(False)
|
math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)),
|
||||||
return math_ops.reduce_all(math_ops.equal(a1, a2))
|
lambda: np_utils.cond( # pylint: disable=g-long-lambda
|
||||||
|
np_utils.reduce_all(
|
||||||
|
math_ops.equal(array_ops.shape(x1), array_ops.shape(x2))
|
||||||
|
),
|
||||||
|
lambda: math_ops.reduce_all(math_ops.equal(x1, x2)),
|
||||||
|
lambda: constant_op.constant(False)),
|
||||||
|
lambda: constant_op.constant(False))
|
||||||
|
|
||||||
return _comparison(f, a1, a2)
|
return _comparison(f, a1, a2)
|
||||||
|
|
||||||
@ -1001,7 +1031,13 @@ setattr(np_arrays.ndarray, '__ne__', not_equal)
|
|||||||
|
|
||||||
@np_utils.np_doc(np.linspace)
|
@np_utils.np_doc(np.linspace)
|
||||||
def linspace( # pylint: disable=missing-docstring
|
def linspace( # pylint: disable=missing-docstring
|
||||||
start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0):
|
start,
|
||||||
|
stop,
|
||||||
|
num=50,
|
||||||
|
endpoint=True,
|
||||||
|
retstep=False,
|
||||||
|
dtype=float,
|
||||||
|
axis=0):
|
||||||
if dtype:
|
if dtype:
|
||||||
dtype = np_utils.result_type(dtype)
|
dtype = np_utils.result_type(dtype)
|
||||||
start = np_array_ops.array(start, dtype=dtype).data
|
start = np_array_ops.array(start, dtype=dtype).data
|
||||||
@ -1054,10 +1090,14 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint
|
|||||||
start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
|
start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
|
||||||
stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
|
stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
|
||||||
signflip = 1 - start_sign * stop_sign // 2
|
signflip = 1 - start_sign * stop_sign // 2
|
||||||
res = signflip * logspace(log10(signflip * start),
|
res = signflip * logspace(
|
||||||
log10(signflip * stop), num,
|
log10(signflip * start),
|
||||||
endpoint=endpoint, base=10.0,
|
log10(signflip * stop),
|
||||||
dtype=computation_dtype, axis=0)
|
num,
|
||||||
|
endpoint=endpoint,
|
||||||
|
base=10.0,
|
||||||
|
dtype=computation_dtype,
|
||||||
|
axis=0)
|
||||||
if axis != 0:
|
if axis != 0:
|
||||||
res = np_array_ops.moveaxis(res, 0, axis)
|
res = np_array_ops.moveaxis(res, 0, axis)
|
||||||
return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))
|
return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))
|
||||||
|
@ -47,7 +47,7 @@ def _canonicalize_axes(axes, rank):
|
|||||||
canonicalizer = (
|
canonicalizer = (
|
||||||
lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
|
lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
|
||||||
else:
|
else:
|
||||||
canonicalizer = lambda axis: axis+rank if axis < 0 else axis
|
canonicalizer = lambda axis: axis + rank if axis < 0 else axis
|
||||||
|
|
||||||
return [canonicalizer(axis) for axis in axes]
|
return [canonicalizer(axis) for axis in axes]
|
||||||
|
|
||||||
@ -100,9 +100,16 @@ def finfo(dtype):
|
|||||||
|
|
||||||
def isscalar(val):
|
def isscalar(val):
|
||||||
"""Returns whether `val` is a scalar value or scalar Tensor."""
|
"""Returns whether `val` is a scalar value or scalar Tensor."""
|
||||||
if isinstance(val, (np.ndarray, np_arrays.ndarray, ops.Tensor)):
|
if isinstance(val, np_arrays.ndarray):
|
||||||
return len(val.shape) == 0 # pylint: disable=g-explicit-length-test
|
val = val.data
|
||||||
return np.isscalar(val)
|
if isinstance(val, ops.Tensor):
|
||||||
|
ndims = val.shape.ndims
|
||||||
|
if ndims is not None:
|
||||||
|
return ndims == 0
|
||||||
|
else:
|
||||||
|
return math_ops.equal(array_ops.rank(val), 0)
|
||||||
|
else:
|
||||||
|
return np.isscalar(val)
|
||||||
|
|
||||||
|
|
||||||
# Can't use np_doc because np.result_type is a builtin function.
|
# Can't use np_doc because np.result_type is a builtin function.
|
||||||
@ -119,8 +126,8 @@ def result_type(*arrays_and_dtypes):
|
|||||||
def maybe_get_dtype(x):
|
def maybe_get_dtype(x):
|
||||||
# Don't put np.ndarray in this list, because np.result_type looks at the
|
# Don't put np.ndarray in this list, because np.result_type looks at the
|
||||||
# value (not just dtype) of np.ndarray to decide the result type.
|
# value (not just dtype) of np.ndarray to decide the result type.
|
||||||
if isinstance(x, (np_arrays.ndarray, ops.Tensor,
|
if isinstance(
|
||||||
indexed_slices.IndexedSlices)):
|
x, (np_arrays.ndarray, ops.Tensor, indexed_slices.IndexedSlices)):
|
||||||
return _to_numpy_type(x.dtype)
|
return _to_numpy_type(x.dtype)
|
||||||
elif isinstance(x, dtypes.DType):
|
elif isinstance(x, dtypes.DType):
|
||||||
return _to_numpy_type(x)
|
return _to_numpy_type(x)
|
||||||
@ -277,8 +284,11 @@ def np_doc(np_fun, np_fun_name=None):
|
|||||||
# for name in np_sig.parameters:
|
# for name in np_sig.parameters:
|
||||||
# if name not in sig.parameters:
|
# if name not in sig.parameters:
|
||||||
# unsupported_params.append(name)
|
# unsupported_params.append(name)
|
||||||
f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name,
|
f.__doc__ = _np_doc_helper(
|
||||||
unsupported_params=unsupported_params)
|
f,
|
||||||
|
np_fun,
|
||||||
|
np_fun_name=np_fun_name,
|
||||||
|
unsupported_params=unsupported_params)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@ -287,9 +297,9 @@ def np_doc(np_fun, np_fun_name=None):
|
|||||||
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
|
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
|
||||||
"""Helper to get docs."""
|
"""Helper to get docs."""
|
||||||
if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f):
|
if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f):
|
||||||
# TODO(wangpeng): It looks like code snippets in numpy doc don't work
|
# TODO(wangpeng): It looks like code snippets in numpy doc don't work
|
||||||
# correctly with doctest. Fix that and remove the reformatting of the np_f
|
# correctly with doctest. Fix that and remove the reformatting of the np_f
|
||||||
# comment, here and below.
|
# comment, here and below.
|
||||||
return np_f.__doc__.replace('>>>', '>')
|
return np_f.__doc__.replace('>>>', '>')
|
||||||
assert np_f or np_fun_name
|
assert np_f or np_fun_name
|
||||||
if not np_fun_name:
|
if not np_fun_name:
|
||||||
|
@ -134,7 +134,7 @@ def _sort_or_argsort(values, axis, direction, return_argsort):
|
|||||||
# Axis must be an integer, not a Tensor.
|
# Axis must be an integer, not a Tensor.
|
||||||
axis = framework_ops.convert_to_tensor(axis, name='axis')
|
axis = framework_ops.convert_to_tensor(axis, name='axis')
|
||||||
axis_static = tensor_util.constant_value(axis)
|
axis_static = tensor_util.constant_value(axis)
|
||||||
if axis.shape.ndims != 0 or axis_static is None:
|
if axis.shape.ndims not in (None, 0) or axis_static is None:
|
||||||
raise ValueError('axis must be a constant scalar')
|
raise ValueError('axis must be a constant scalar')
|
||||||
axis_static = int(axis_static) # Avoids NumPy casting error
|
axis_static = int(axis_static) # Avoids NumPy casting error
|
||||||
|
|
||||||
@ -184,18 +184,8 @@ def _descending_sort(values, axis, return_argsort=False):
|
|||||||
name='transposition')
|
name='transposition')
|
||||||
else:
|
else:
|
||||||
# Generate the transposition array from the tensors.
|
# Generate the transposition array from the tensors.
|
||||||
transposition = array_ops.concat(
|
transposition = array_ops.tensor_scatter_update(
|
||||||
[
|
math_ops.range(rank), [[axis], [rank-1]], [rank-1, axis])
|
||||||
# Axes up to axis are unchanged.
|
|
||||||
math_ops.range(axis),
|
|
||||||
# Swap axis and rank - 1.
|
|
||||||
[rank - 1],
|
|
||||||
# Axes in [axis + 1, rank - 1) are unchanged.
|
|
||||||
math_ops.range(axis + 1, rank - 1),
|
|
||||||
# Swap axis and rank - 1.
|
|
||||||
[axis]
|
|
||||||
],
|
|
||||||
axis=0)
|
|
||||||
top_k_input = array_ops.transpose(values, transposition)
|
top_k_input = array_ops.transpose(values, transposition)
|
||||||
|
|
||||||
values, indices = nn_ops.top_k(top_k_input, k)
|
values, indices = nn_ops.top_k(top_k_input, k)
|
||||||
|
Loading…
Reference in New Issue
Block a user