tf.numpy: Change a bunch of ops to handle unknown shapes.
Fix logic in sort_ops to handle 1D values.

PiperOrigin-RevId: 316620637
Change-Id: Iedc2ba8aad7673bbe210661bb741bf0660f047aa
parent e74010b4e8
commit 42a734170d
@@ -223,9 +223,10 @@ def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
   Raises:
     ValueError: if `fill_value` can not be broadcast to shape `shape`.
   """
+  if not isinstance(shape, np_arrays.ndarray):
+    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
+  shape = atleast_1d(shape).data
   fill_value = asarray(fill_value, dtype=dtype)
-  if np_utils.isscalar(shape):
-    shape = array_ops.reshape(shape, [1])
   return np_arrays.tensor_to_ndarray(
       array_ops.broadcast_to(fill_value.data, shape))
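The hunk above normalizes `shape` before broadcasting, so `full` also works when the shape arrives as a Python scalar, a list, or a tensor whose value is only known at run time. A minimal sketch of the same pattern with the public TF API (the `tf.*` calls stand in for the internal `array_ops`/`np_arrays` helpers; `full_sketch` is a made-up name for illustration):

    import tensorflow as tf

    def full_sketch(shape, fill_value):
      # Accept ints, lists, or tensors; prefer int32 for the shape.
      shape = tf.convert_to_tensor(shape, dtype_hint=tf.int32)
      # atleast_1d: a scalar shape becomes a rank-1 shape of length 1.
      shape = tf.reshape(shape, [-1])
      return tf.broadcast_to(fill_value, shape)

    print(full_sketch(3, 7.0))       # [7. 7. 7.]
    print(full_sketch([2, 2], 1.0))  # [[1. 1.] [1. 1.]]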
@@ -808,16 +809,21 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: d
 @np_utils.np_doc(np.std)
 def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
   return _reduce(
-      math_ops.reduce_std, a, axis=axis, dtype=None, keepdims=keepdims,
+      math_ops.reduce_std,
+      a,
+      axis=axis,
+      dtype=None,
+      keepdims=keepdims,
       promote_int=_TO_FLOAT)
 
 
 @np_utils.np_doc(np.ravel)
 def ravel(a):  # pylint: disable=missing-docstring
   a = asarray(a)
-  if a.ndim == 1:
-    return a
-  return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
+  out = np_utils.cond(
+      math_ops.equal(a.ndim, 1), lambda: a.data,
+      lambda: array_ops.reshape(a.data, [-1]))
+  return np_utils.tensor_to_ndarray(out)
 
 
 setattr(np_arrays.ndarray, 'ravel', ravel)
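When the input has unknown rank, `a.ndim` can itself be a tensor, so the Python `if` in the old `ravel` cannot be decided at trace time; the new version selects the branch inside the graph. A rough equivalent with public ops, treating `np_utils.cond` as `tf.cond` (an assumption about its behavior):

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec(shape=None)])  # unknown rank
    def ravel_sketch(a):
      return tf.cond(
          tf.equal(tf.rank(a), 1),
          lambda: a,
          lambda: tf.reshape(a, [-1]))

    print(ravel_sketch(tf.ones([2, 3])).shape)  # (6,)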
@@ -846,7 +852,8 @@ def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
   a = asarray(a).data
   original_shape = a._shape_as_list()  # pylint: disable=protected-access
   # Best effort recovery of the shape.
-  if original_shape is not None and None not in original_shape:
+  known_shape = original_shape is not None and None not in original_shape
+  if known_shape:
     if not original_shape:
       original_shape = (repeats,)
     else:
@@ -865,7 +872,8 @@ def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
 
   repeats = asarray(repeats).data
   result = array_ops.repeat(a, repeats, axis)
-  result.set_shape(original_shape)
+  if known_shape:
+    result.set_shape(original_shape)
 
   return np_utils.tensor_to_ndarray(result)
 
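`set_shape` only re-attaches static shape information, so it is worth calling only when the original shape was fully known; with an unknown input shape there is nothing useful to restore. A small illustration of the guarded pattern (the helper name is made up for the example):

    import tensorflow as tf

    def repeat_rows_sketch(a, repeats):
      original_shape = a.shape.as_list() if a.shape.rank is not None else None
      known_shape = original_shape is not None and None not in original_shape
      result = tf.repeat(a, repeats, axis=0)
      if known_shape:
        # Safe: every entry of original_shape is a concrete int here.
        result.set_shape([original_shape[0] * repeats] + original_shape[1:])
      return result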
@@ -1287,7 +1295,13 @@ def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
 
 
 @np_utils.np_doc(np.stack)
-def stack(arrays, axis=0):
+def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
+  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
+    arrays = asarray(arrays)
+    if axis == 0:
+      return arrays
+    else:
+      return swapaxes(arrays, 0, axis)
   arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
   unwrapped_arrays = [
       a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
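The new fast path for a single array input leans on NumPy behavior: `np.stack` on one array stacks its first-axis slices along `axis`, which for a 2-D input is exactly a swap of axes 0 and 1, so no unstack/restack is needed. Checked directly against NumPy:

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    assert np.array_equal(np.stack(a, axis=0), a)
    assert np.array_equal(np.stack(a, axis=1), np.swapaxes(a, 0, 1))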
@@ -1450,6 +1464,8 @@ def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-doc
 @np_utils.np_doc(np.tril)
 def tril(m, k=0):  # pylint: disable=missing-docstring
   m = asarray(m).data
+  if m.shape.ndims is None:
+    raise ValueError('Argument to tril should have known rank')
   m_shape = m.shape.as_list()
 
   if len(m_shape) < 2:
@@ -1470,6 +1486,8 @@ def tril(m, k=0):  # pylint: disable=missing-docstring
 @np_utils.np_doc(np.triu)
 def triu(m, k=0):  # pylint: disable=missing-docstring
   m = asarray(m).data
+  if m.shape.ndims is None:
+    raise ValueError('Argument to triu should have known rank')
   m_shape = m.shape.as_list()
 
   if len(m_shape) < 2:
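`tril`/`triu` build their banded mask from the static shape list, so they cannot support unknown rank; the added check turns a confusing failure inside `shape.as_list()` into an explicit error. Tensors whose rank is known, even with dynamic dimension sizes, remain fine; a sketch using the public band-part op:

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None])])
    def tril_sketch(m):
      # Rank is static (2) even though the dimension sizes are dynamic.
      return tf.linalg.band_part(m, -1, 0)  # keep the lower triangle

    print(tril_sketch(tf.ones([3, 3])))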
@@ -13,6 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 """ndarray class."""
+
+# pylint: disable=g-direct-tensorflow-import
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -151,13 +154,16 @@ def _slice_helper(tensor, slice_spec, var=None):
 def convert_to_tensor(value, dtype=None, dtype_hint=None):
   """Wrapper over `tf.convert_to_tensor`.
 
-  Args:
-    value: value to convert
-    dtype: (optional) the type we would like it to be converted to.
-    dtype_hint: (optional) soft preference for the type we would like it to
-      be converted to. `tf.convert_to_tensor` will attempt to convert value
-      to this type first, but will not fail if conversion is not possible
-      falling back to inferring the type instead.
+  Args:
+    value: value to convert
+    dtype: (optional) the type we would like it to be converted to.
+    dtype_hint: (optional) soft preference for the type we would like it to be
+      converted to. `tf.convert_to_tensor` will attempt to convert value to this
+      type first, but will not fail if conversion is not possible falling back
+      to inferring the type instead.
+
+  Returns:
+    Value converted to tf.Tensor.
   """
   # A safer version of `tf.convert_to_tensor` to work around b/149876037.
   # TODO(wangpeng): Remove this function once the bug is fixed.
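The `dtype_hint` described in the docstring mirrors the same argument on the public `tf.convert_to_tensor`: it is applied when the value can actually take that type and silently ignored otherwise, unlike `dtype`, which would raise. For example:

    import tensorflow as tf

    print(tf.convert_to_tensor(5, dtype_hint=tf.int32).dtype)    # int32: hint applies
    print(tf.convert_to_tensor(5.5, dtype_hint=tf.int32).dtype)  # float32: hint not applicable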
@@ -250,8 +256,12 @@ class ndarray(object):  # pylint: disable=invalid-name
 
   @property
   def shape(self):
-    """Returns a tuple of array dimensions."""
-    return self.data._shape_tuple()  # pylint: disable=protected-access
+    """Returns a tuple or tf.Tensor of array dimensions."""
+    shape = self.data.shape
+    if shape.is_fully_defined():
+      return tuple(shape.as_list())
+    else:
+      return array_ops.shape(self.data)
 
   @property
   def dtype(self):
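The property now degrades gracefully: a plain Python tuple when the shape is fully known (matching NumPy), and a dynamic shape tensor otherwise. The decision logic, written against a bare tf.Tensor as a sketch:

    import tensorflow as tf

    def shape_of(t):
      if t.shape.is_fully_defined():
        return tuple(t.shape.as_list())  # static: a Python tuple of ints
      return tf.shape(t)                 # dynamic: a rank-1 int32 tensor

    print(shape_of(tf.ones([2, 3])))  # (2, 3)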
@@ -259,19 +269,30 @@ class ndarray(object):  # pylint: disable=invalid-name
 
   @property
   def ndim(self):
-    return self.data.shape.ndims
+    ndims = self.data.shape.ndims
+    if ndims is None:
+      return array_ops.rank(self.data)
+    else:
+      return ndims
 
   @property
   def size(self):
     """Returns the number of elements in the array."""
-    return np.prod(self.shape)
+    shape = self.shape
+    if isinstance(shape, ops.Tensor):
+      return array_ops.size(self.data)
+    else:
+      return np.prod(self.shape)
 
   @property
   def T(self):  # pylint: disable=invalid-name
     return self.transpose()
 
   def __len__(self):
-    if self.shape:
+    shape = self.shape
+    if isinstance(shape, ops.Tensor):
+      raise TypeError('len() of symbolic tensor undefined')
+    elif shape:
       return self.shape[0]
     else:
       raise TypeError('len() of unsized object.')
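Unlike `shape`, `ndim`, and `size`, `__len__` is constrained by the Python protocol to return an int, so it cannot fall back to a symbolic value; raising TypeError is the only honest option when the leading dimension is unknown. Roughly:

    import tensorflow as tf

    def length(t):
      if t.shape.rank is None or t.shape[0] is None:
        raise TypeError('len() of symbolic tensor undefined')
      return t.shape[0]  # a concrete Python int

    print(length(tf.ones([4, 2])))  # 4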
@@ -320,6 +341,8 @@ class ndarray(object):  # pylint: disable=invalid-name
     return tensor_to_ndarray(result_t)
 
   def __iter__(self):
+    if not isinstance(self.data, ops.EagerTensor):
+      raise TypeError('Iteration over symbolic tensor is not allowed')
     for i in range(self.shape[0]):
       result_t = self.data[i]
       yield tensor_to_ndarray(result_t)
@@ -356,6 +379,8 @@ class ndarray(object):  # pylint: disable=invalid-name
       ValueError: If the array does not have size 1.
     """
     # TODO(wangpeng): Handle graph mode
+    if not isinstance(self.data, ops.EagerTensor):
+      raise TypeError('Indexing using symbolic tensor is not allowed')
     return np.asscalar(self.data.numpy())
 
   def tolist(self):
@@ -384,5 +409,3 @@ def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
 
 
 ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
-
-
@@ -50,9 +50,9 @@ def dot(a, b):  # pylint: disable=missing-docstring
           math_ops.equal(array_ops.rank(b), 0)),
       lambda: a * b,
       lambda: np_utils.cond(  # pylint: disable=g-long-lambda
-          math_ops.equal(array_ops.rank(b), 1), lambda: math_ops.tensordot(
-              a, b, axes=[[-1], [-1]]), lambda: math_ops.tensordot(
-                  a, b, axes=[[-1], [-2]])))
+          math_ops.equal(array_ops.rank(b), 1),
+          lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]),
+          lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]])))
 
   return _bin_op(f, a, b)
 
@@ -204,8 +204,8 @@ def matmul(x1, x2):  # pylint: disable=missing-docstring
   return np_utils.cond(
       math_ops.equal(array_ops.rank(x2), 1),
       lambda: math_ops.tensordot(x1, x2, axes=1),
-      lambda: np_utils.cond(
-          math_ops.equal(array_ops.rank(x1), 1),  # pylint: disable=g-long-lambda
+      lambda: np_utils.cond(  # pylint: disable=g-long-lambda
+          math_ops.equal(array_ops.rank(x1), 1),
           lambda: math_ops.tensordot(  # pylint: disable=g-long-lambda
               x1, x2, axes=[[0], [-2]]),
           lambda: math_ops.matmul(x1, x2)))
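The rank dispatch here mirrors np.matmul's promotion rules for 1-D operands, and the axes passed to tensordot encode them: a 1-D `x1` has its single axis contracted against `x2`'s second-to-last axis. The identity being relied on, checked in NumPy:

    import numpy as np

    x1 = np.array([1., 2.])
    x2 = np.array([[3., 4.], [5., 6.]])
    assert np.allclose(np.matmul(x1, x2),
                       np.tensordot(x1, x2, axes=[[0], [-2]]))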
@@ -352,14 +352,30 @@ def hypot(x1, x2):
 def kron(a, b):  # pylint: disable=missing-function-docstring
   # pylint: disable=protected-access,g-complex-comprehension
   a, b = np_array_ops._promote_dtype(a, b)
-  ndim = max(a.ndim, b.ndim)
-  if a.ndim < ndim:
-    a = np_array_ops.reshape(a, np_array_ops._pad_left_to(ndim, a.shape))
-  if b.ndim < ndim:
-    b = np_array_ops.reshape(b, np_array_ops._pad_left_to(ndim, b.shape))
-  a_reshaped = np_array_ops.reshape(a, [i for d in a.shape for i in (d, 1)])
-  b_reshaped = np_array_ops.reshape(b, [i for d in b.shape for i in (1, d)])
-  out_shape = tuple(np.multiply(a.shape, b.shape))
+  t_a = np_utils.cond(
+      a.ndim < b.ndim,
+      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
+          a.data, np_array_ops._pad_left_to(b.ndim, a.shape)),
+      lambda: a.data)
+  t_b = np_utils.cond(
+      b.ndim < a.ndim,
+      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
+          b.data, np_array_ops._pad_left_to(a.ndim, b.shape)),
+      lambda: b.data)
+
+  def _make_shape(shape, prepend):
+    ones = array_ops.ones_like(shape)
+    if prepend:
+      shapes = [ones, shape]
+    else:
+      shapes = [shape, ones]
+    return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])
+
+  a_shape = array_ops.shape(t_a)
+  b_shape = array_ops.shape(t_b)
+  a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
+  b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
+  out_shape = a_shape * b_shape
   return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
 
 
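Both versions implement the same reshape-and-broadcast construction of the Kronecker product: interleave each factor's dimensions with ones, in opposite phase, multiply, and collapse each pair of dimensions. The identity, checked in NumPy for the 2-D case:

    import numpy as np

    a = np.array([[1, 2], [3, 4]])
    b = np.array([[0, 10], [20, 30]])
    a_r = a.reshape(2, 1, 2, 1)  # dims of a interleaved with 1s: (d, 1)
    b_r = b.reshape(1, 2, 1, 2)  # 1s interleaved with dims of b: (1, d)
    assert np.array_equal((a_r * b_r).reshape(4, 4), np.kron(a, b))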
@@ -454,7 +470,8 @@ def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring
   if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
       not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
     raise ValueError('Arguments to gcd must be integers.')
-  shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
+  shape = array_ops.broadcast_dynamic_shape(
+      array_ops.shape(x1), array_ops.shape(x2))
   x1 = array_ops.broadcast_to(x1, shape)
   x2 = array_ops.broadcast_to(x2, shape)
   value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
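`broadcast_static_shape` operates on TensorShape objects and needs them known at trace time; `broadcast_dynamic_shape` does the same arithmetic on shape tensors at run time, which is what makes the unknown-shape case work. With the public API:

    import tensorflow as tf

    x1 = tf.ones([3, 1])
    x2 = tf.ones([1, 4])
    shape = tf.broadcast_dynamic_shape(tf.shape(x1), tf.shape(x2))
    print(shape)                             # tf.Tensor([3 4], ...)
    print(tf.broadcast_to(x1, shape).shape)  # (3, 4)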
@@ -607,7 +624,7 @@ def signbit(x):
 
   def f(x):
     if x.dtype == dtypes.bool:
-      return array_ops.fill(x.shape, False)
+      return array_ops.fill(array_ops.shape(x), False)
     return x < 0
 
   return _scalar(f, x)
@@ -866,7 +883,11 @@ def square(x):
 def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring
 
   def f(a):
+    # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
+    # TODO(agarwal): avoid depending on static rank.
     nd = a.shape.rank
+    if nd is None:
+      raise ValueError('diff currently requires known rank for input `a`')
     if (axis + nd if axis < 0 else axis) >= nd:
       raise ValueError('axis %s is out of bounds for array of dimension %s' %
                        (axis, nd))
@@ -887,8 +908,10 @@ def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring
 
 
 def _flip_args(f):
+
   def _f(a, b):
     return f(b, a)
+
   return _f
 
 
@@ -910,6 +933,7 @@ setattr(np_arrays.ndarray, '__rtruediv__', _flip_args(true_divide))
 
 
 def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
+  """Helper function for comparision."""
   dtype = np_utils.result_type(x1, x2)
   # Cast x1 and x2 to the result_type if needed.
   x1 = np_array_ops.array(x1, dtype=dtype)
@@ -953,12 +977,18 @@ def less_equal(x1, x2):
 
 
 @np_utils.np_doc(np.array_equal)
-def array_equal(a1, a2):
+def array_equal(a1, a2):  # pylint: disable=missing-function-docstring
 
-  def f(a1, a2):
-    if a1.shape != a2.shape:
-      return constant_op.constant(False)
-    return math_ops.reduce_all(math_ops.equal(a1, a2))
+  def f(x1, x2):
+    return np_utils.cond(
+        math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)),
+        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
+            np_utils.reduce_all(
+                math_ops.equal(array_ops.shape(x1), array_ops.shape(x2))
+            ),
+            lambda: math_ops.reduce_all(math_ops.equal(x1, x2)),
+            lambda: constant_op.constant(False)),
+        lambda: constant_op.constant(False))
 
   return _comparison(f, a1, a2)
 
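The nested cond reproduces np.array_equal's contract: ranks must match, then shapes, then every element. With unknown ranks, all three checks now happen inside the graph instead of in Python. The semantics being matched:

    import numpy as np

    assert not np.array_equal(np.zeros([2]), np.zeros([2, 1]))  # rank mismatch
    assert not np.array_equal(np.zeros([2]), np.zeros([3]))     # shape mismatch
    assert np.array_equal(np.zeros([2]), np.zeros([2]))         # equal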
@@ -1001,7 +1031,13 @@ setattr(np_arrays.ndarray, '__ne__', not_equal)
 
 @np_utils.np_doc(np.linspace)
 def linspace(  # pylint: disable=missing-docstring
-    start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0):
+    start,
+    stop,
+    num=50,
+    endpoint=True,
+    retstep=False,
+    dtype=float,
+    axis=0):
   if dtype:
     dtype = np_utils.result_type(dtype)
   start = np_array_ops.array(start, dtype=dtype).data
@@ -1054,10 +1090,14 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint
   start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
   stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
   signflip = 1 - start_sign * stop_sign // 2
-  res = signflip * logspace(log10(signflip * start),
-                            log10(signflip * stop), num,
-                            endpoint=endpoint, base=10.0,
-                            dtype=computation_dtype, axis=0)
+  res = signflip * logspace(
+      log10(signflip * start),
+      log10(signflip * stop),
+      num,
+      endpoint=endpoint,
+      base=10.0,
+      dtype=computation_dtype,
+      axis=0)
   if axis != 0:
     res = np_array_ops.moveaxis(res, 0, axis)
   return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))
@@ -47,7 +47,7 @@ def _canonicalize_axes(axes, rank):
     canonicalizer = (
         lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
   else:
-    canonicalizer = lambda axis: axis+rank if axis < 0 else axis
+    canonicalizer = lambda axis: axis + rank if axis < 0 else axis
 
   return [canonicalizer(axis) for axis in axes]
 
@@ -100,9 +100,16 @@ def finfo(dtype):
 
 def isscalar(val):
   """Returns whether `val` is a scalar value or scalar Tensor."""
-  if isinstance(val, (np.ndarray, np_arrays.ndarray, ops.Tensor)):
-    return len(val.shape) == 0  # pylint: disable=g-explicit-length-test
-  return np.isscalar(val)
+  if isinstance(val, np_arrays.ndarray):
+    val = val.data
+  if isinstance(val, ops.Tensor):
+    ndims = val.shape.ndims
+    if ndims is not None:
+      return ndims == 0
+    else:
+      return math_ops.equal(array_ops.rank(val), 0)
+  else:
+    return np.isscalar(val)
 
 
 # Can't use np_doc because np.result_type is a builtin function.
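The rewritten check prefers a static answer and only builds a symbolic one when the rank is genuinely unknown, so eager-mode callers still get a plain Python bool. The core dispatch as a standalone sketch:

    import numpy as np
    import tensorflow as tf

    def isscalar_sketch(val):
      if isinstance(val, tf.Tensor):
        ndims = val.shape.ndims
        if ndims is not None:
          return ndims == 0               # static rank: a Python bool
        return tf.equal(tf.rank(val), 0)  # unknown rank: a boolean tensor
      return np.isscalar(val)

    print(isscalar_sketch(tf.constant(3.0)))  # True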
@@ -119,8 +126,8 @@ def result_type(*arrays_and_dtypes):
   def maybe_get_dtype(x):
     # Don't put np.ndarray in this list, because np.result_type looks at the
     # value (not just dtype) of np.ndarray to decide the result type.
-    if isinstance(x, (np_arrays.ndarray, ops.Tensor,
-                      indexed_slices.IndexedSlices)):
+    if isinstance(
+        x, (np_arrays.ndarray, ops.Tensor, indexed_slices.IndexedSlices)):
       return _to_numpy_type(x.dtype)
     elif isinstance(x, dtypes.DType):
       return _to_numpy_type(x)
@@ -277,8 +284,11 @@ def np_doc(np_fun, np_fun_name=None):
     # for name in np_sig.parameters:
     #   if name not in sig.parameters:
     #     unsupported_params.append(name)
-    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name,
-                               unsupported_params=unsupported_params)
+    f.__doc__ = _np_doc_helper(
+        f,
+        np_fun,
+        np_fun_name=np_fun_name,
+        unsupported_params=unsupported_params)
     return f
 
   return decorator
@@ -287,9 +297,9 @@ def np_doc(np_fun, np_fun_name=None):
 def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
   """Helper to get docs."""
   if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f):
-    # TODO(wangpeng): It looks like code snippets in numpy doc don't work
-    # correctly with doctest. Fix that and remove the reformatting of the np_f
-    # comment, here and below.
+    # TODO(wangpeng): It looks like code snippets in numpy doc don't work
+    # correctly with doctest. Fix that and remove the reformatting of the np_f
+    # comment, here and below.
     return np_f.__doc__.replace('>>>', '>')
   assert np_f or np_fun_name
   if not np_fun_name:
@@ -134,7 +134,7 @@ def _sort_or_argsort(values, axis, direction, return_argsort):
   # Axis must be an integer, not a Tensor.
   axis = framework_ops.convert_to_tensor(axis, name='axis')
   axis_static = tensor_util.constant_value(axis)
-  if axis.shape.ndims != 0 or axis_static is None:
+  if axis.shape.ndims not in (None, 0) or axis_static is None:
     raise ValueError('axis must be a constant scalar')
   axis_static = int(axis_static)  # Avoids NumPy casting error
 
@@ -184,18 +184,8 @@ def _descending_sort(values, axis, return_argsort=False):
         name='transposition')
   else:
     # Generate the transposition array from the tensors.
-    transposition = array_ops.concat(
-        [
-            # Axes up to axis are unchanged.
-            math_ops.range(axis),
-            # Swap axis and rank - 1.
-            [rank - 1],
-            # Axes in [axis + 1, rank - 1) are unchanged.
-            math_ops.range(axis + 1, rank - 1),
-            # Swap axis and rank - 1.
-            [axis]
-        ],
-        axis=0)
+    transposition = array_ops.tensor_scatter_update(
+        math_ops.range(rank), [[axis], [rank-1]], [rank-1, axis])
   top_k_input = array_ops.transpose(values, transposition)
 
   values, indices = nn_ops.top_k(top_k_input, k)