tf.numpy: Change a bunch of ops to handle unknown shapes.

Fix logic in sort_ops to handle 1D values.

PiperOrigin-RevId: 316620637
Change-Id: Iedc2ba8aad7673bbe210661bb741bf0660f047aa
This commit is contained in:
A. Unique TensorFlower 2020-06-15 23:00:10 -07:00 committed by TensorFlower Gardener
parent e74010b4e8
commit 42a734170d
5 changed files with 153 additions and 72 deletions

View File

@ -223,9 +223,10 @@ def full(shape, fill_value, dtype=None): # pylint: disable=redefined-outer-name
Raises:
ValueError: if `fill_value` can not be broadcast to shape `shape`.
"""
if not isinstance(shape, np_arrays.ndarray):
shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
shape = atleast_1d(shape).data
fill_value = asarray(fill_value, dtype=dtype)
if np_utils.isscalar(shape):
shape = array_ops.reshape(shape, [1])
return np_arrays.tensor_to_ndarray(
array_ops.broadcast_to(fill_value.data, shape))
@ -808,16 +809,21 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint: d
@np_utils.np_doc(np.std)
def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstring
return _reduce(
math_ops.reduce_std, a, axis=axis, dtype=None, keepdims=keepdims,
math_ops.reduce_std,
a,
axis=axis,
dtype=None,
keepdims=keepdims,
promote_int=_TO_FLOAT)
@np_utils.np_doc(np.ravel)
def ravel(a): # pylint: disable=missing-docstring
a = asarray(a)
if a.ndim == 1:
return a
return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
out = np_utils.cond(
math_ops.equal(a.ndim, 1), lambda: a.data,
lambda: array_ops.reshape(a.data, [-1]))
return np_utils.tensor_to_ndarray(out)
setattr(np_arrays.ndarray, 'ravel', ravel)
@ -846,7 +852,8 @@ def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
a = asarray(a).data
original_shape = a._shape_as_list() # pylint: disable=protected-access
# Best effort recovery of the shape.
if original_shape is not None and None not in original_shape:
known_shape = original_shape is not None and None not in original_shape
if known_shape:
if not original_shape:
original_shape = (repeats,)
else:
@ -865,7 +872,8 @@ def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
repeats = asarray(repeats).data
result = array_ops.repeat(a, repeats, axis)
result.set_shape(original_shape)
if known_shape:
result.set_shape(original_shape)
return np_utils.tensor_to_ndarray(result)
@ -1287,7 +1295,13 @@ def broadcast_to(array, shape): # pylint: disable=redefined-outer-name
@np_utils.np_doc(np.stack)
def stack(arrays, axis=0):
def stack(arrays, axis=0): # pylint: disable=missing-function-docstring
if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
arrays = asarray(arrays)
if axis == 0:
return arrays
else:
return swapaxes(arrays, 0, axis)
arrays = _promote_dtype(*arrays) # pylint: disable=protected-access
unwrapped_arrays = [
a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
@ -1450,6 +1464,8 @@ def tri(N, M=None, k=0, dtype=None): # pylint: disable=invalid-name,missing-doc
@np_utils.np_doc(np.tril)
def tril(m, k=0): # pylint: disable=missing-docstring
m = asarray(m).data
if m.shape.ndims is None:
raise ValueError('Argument to tril should have known rank')
m_shape = m.shape.as_list()
if len(m_shape) < 2:
@ -1470,6 +1486,8 @@ def tril(m, k=0): # pylint: disable=missing-docstring
@np_utils.np_doc(np.triu)
def triu(m, k=0): # pylint: disable=missing-docstring
m = asarray(m).data
if m.shape.ndims is None:
raise ValueError('Argument to triu should have known rank')
m_shape = m.shape.as_list()
if len(m_shape) < 2:

View File

@ -13,6 +13,9 @@
# limitations under the License.
# ==============================================================================
"""ndarray class."""
# pylint: disable=g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -151,13 +154,16 @@ def _slice_helper(tensor, slice_spec, var=None):
def convert_to_tensor(value, dtype=None, dtype_hint=None):
"""Wrapper over `tf.convert_to_tensor`.
Args:
value: value to convert
dtype: (optional) the type we would like it to be converted to.
dtype_hint: (optional) soft preference for the type we would like it to
be converted to. `tf.convert_to_tensor` will attempt to convert value
to this type first, but will not fail if conversion is not possible
falling back to inferring the type instead.
Args:
value: value to convert
dtype: (optional) the type we would like it to be converted to.
dtype_hint: (optional) soft preference for the type we would like it to be
converted to. `tf.convert_to_tensor` will attempt to convert value to this
type first, but will not fail if conversion is not possible falling back
to inferring the type instead.
Returns:
Value converted to tf.Tensor.
"""
# A safer version of `tf.convert_to_tensor` to work around b/149876037.
# TODO(wangpeng): Remove this function once the bug is fixed.
@ -250,8 +256,12 @@ class ndarray(object): # pylint: disable=invalid-name
@property
def shape(self):
"""Returns a tuple of array dimensions."""
return self.data._shape_tuple() # pylint: disable=protected-access
"""Returns a tuple or tf.Tensor of array dimensions."""
shape = self.data.shape
if shape.is_fully_defined():
return tuple(shape.as_list())
else:
return array_ops.shape(self.data)
@property
def dtype(self):
@ -259,19 +269,30 @@ class ndarray(object): # pylint: disable=invalid-name
@property
def ndim(self):
return self.data.shape.ndims
ndims = self.data.shape.ndims
if ndims is None:
return array_ops.rank(self.data)
else:
return ndims
@property
def size(self):
"""Returns the number of elements in the array."""
return np.prod(self.shape)
shape = self.shape
if isinstance(shape, ops.Tensor):
return array_ops.size(self.data)
else:
return np.prod(self.shape)
@property
def T(self): # pylint: disable=invalid-name
return self.transpose()
def __len__(self):
if self.shape:
shape = self.shape
if isinstance(shape, ops.Tensor):
raise TypeError('len() of symbolic tensor undefined')
elif shape:
return self.shape[0]
else:
raise TypeError('len() of unsized object.')
@ -320,6 +341,8 @@ class ndarray(object): # pylint: disable=invalid-name
return tensor_to_ndarray(result_t)
def __iter__(self):
if not isinstance(self.data, ops.EagerTensor):
raise TypeError('Iteration over symbolic tensor is not allowed')
for i in range(self.shape[0]):
result_t = self.data[i]
yield tensor_to_ndarray(result_t)
@ -356,6 +379,8 @@ class ndarray(object): # pylint: disable=invalid-name
ValueError: If the array does not have size 1.
"""
# TODO(wangpeng): Handle graph mode
if not isinstance(self.data, ops.EagerTensor):
raise TypeError('Indexing using symbolic tensor is not allowed')
return np.asscalar(self.data.numpy())
def tolist(self):
@ -384,5 +409,3 @@ def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)

View File

@ -50,9 +50,9 @@ def dot(a, b): # pylint: disable=missing-docstring
math_ops.equal(array_ops.rank(b), 0)),
lambda: a * b,
lambda: np_utils.cond( # pylint: disable=g-long-lambda
math_ops.equal(array_ops.rank(b), 1), lambda: math_ops.tensordot(
a, b, axes=[[-1], [-1]]), lambda: math_ops.tensordot(
a, b, axes=[[-1], [-2]])))
math_ops.equal(array_ops.rank(b), 1),
lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]),
lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]])))
return _bin_op(f, a, b)
@ -204,8 +204,8 @@ def matmul(x1, x2): # pylint: disable=missing-docstring
return np_utils.cond(
math_ops.equal(array_ops.rank(x2), 1),
lambda: math_ops.tensordot(x1, x2, axes=1),
lambda: np_utils.cond(
math_ops.equal(array_ops.rank(x1), 1), # pylint: disable=g-long-lambda
lambda: np_utils.cond( # pylint: disable=g-long-lambda
math_ops.equal(array_ops.rank(x1), 1),
lambda: math_ops.tensordot( # pylint: disable=g-long-lambda
x1, x2, axes=[[0], [-2]]),
lambda: math_ops.matmul(x1, x2)))
@ -352,14 +352,30 @@ def hypot(x1, x2):
def kron(a, b): # pylint: disable=missing-function-docstring
# pylint: disable=protected-access,g-complex-comprehension
a, b = np_array_ops._promote_dtype(a, b)
ndim = max(a.ndim, b.ndim)
if a.ndim < ndim:
a = np_array_ops.reshape(a, np_array_ops._pad_left_to(ndim, a.shape))
if b.ndim < ndim:
b = np_array_ops.reshape(b, np_array_ops._pad_left_to(ndim, b.shape))
a_reshaped = np_array_ops.reshape(a, [i for d in a.shape for i in (d, 1)])
b_reshaped = np_array_ops.reshape(b, [i for d in b.shape for i in (1, d)])
out_shape = tuple(np.multiply(a.shape, b.shape))
t_a = np_utils.cond(
a.ndim < b.ndim,
lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda
a.data, np_array_ops._pad_left_to(b.ndim, a.shape)),
lambda: a.data)
t_b = np_utils.cond(
b.ndim < a.ndim,
lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda
b.data, np_array_ops._pad_left_to(a.ndim, b.shape)),
lambda: b.data)
def _make_shape(shape, prepend):
ones = array_ops.ones_like(shape)
if prepend:
shapes = [ones, shape]
else:
shapes = [shape, ones]
return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])
a_shape = array_ops.shape(t_a)
b_shape = array_ops.shape(t_b)
a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
out_shape = a_shape * b_shape
return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
@ -454,7 +470,8 @@ def _tf_gcd(x1, x2): # pylint: disable=missing-function-docstring
if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
raise ValueError('Arguments to gcd must be integers.')
shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
shape = array_ops.broadcast_dynamic_shape(
array_ops.shape(x1), array_ops.shape(x2))
x1 = array_ops.broadcast_to(x1, shape)
x2 = array_ops.broadcast_to(x2, shape)
value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
@ -607,7 +624,7 @@ def signbit(x):
def f(x):
if x.dtype == dtypes.bool:
return array_ops.fill(x.shape, False)
return array_ops.fill(array_ops.shape(x), False)
return x < 0
return _scalar(f, x)
@ -866,7 +883,11 @@ def square(x):
def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
def f(a):
# TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
# TODO(agarwal): avoid depending on static rank.
nd = a.shape.rank
if nd is None:
raise ValueError('diff currently requires known rank for input `a`')
if (axis + nd if axis < 0 else axis) >= nd:
raise ValueError('axis %s is out of bounds for array of dimension %s' %
(axis, nd))
@ -887,8 +908,10 @@ def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring
def _flip_args(f):
def _f(a, b):
return f(b, a)
return _f
@ -910,6 +933,7 @@ setattr(np_arrays.ndarray, '__rtruediv__', _flip_args(true_divide))
def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
"""Helper function for comparision."""
dtype = np_utils.result_type(x1, x2)
# Cast x1 and x2 to the result_type if needed.
x1 = np_array_ops.array(x1, dtype=dtype)
@ -953,12 +977,18 @@ def less_equal(x1, x2):
@np_utils.np_doc(np.array_equal)
def array_equal(a1, a2):
def array_equal(a1, a2): # pylint: disable=missing-function-docstring
def f(a1, a2):
if a1.shape != a2.shape:
return constant_op.constant(False)
return math_ops.reduce_all(math_ops.equal(a1, a2))
def f(x1, x2):
return np_utils.cond(
math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)),
lambda: np_utils.cond( # pylint: disable=g-long-lambda
np_utils.reduce_all(
math_ops.equal(array_ops.shape(x1), array_ops.shape(x2))
),
lambda: math_ops.reduce_all(math_ops.equal(x1, x2)),
lambda: constant_op.constant(False)),
lambda: constant_op.constant(False))
return _comparison(f, a1, a2)
@ -1001,7 +1031,13 @@ setattr(np_arrays.ndarray, '__ne__', not_equal)
@np_utils.np_doc(np.linspace)
def linspace( # pylint: disable=missing-docstring
start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0):
start,
stop,
num=50,
endpoint=True,
retstep=False,
dtype=float,
axis=0):
if dtype:
dtype = np_utils.result_type(dtype)
start = np_array_ops.array(start, dtype=dtype).data
@ -1054,10 +1090,14 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint
start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
signflip = 1 - start_sign * stop_sign // 2
res = signflip * logspace(log10(signflip * start),
log10(signflip * stop), num,
endpoint=endpoint, base=10.0,
dtype=computation_dtype, axis=0)
res = signflip * logspace(
log10(signflip * start),
log10(signflip * stop),
num,
endpoint=endpoint,
base=10.0,
dtype=computation_dtype,
axis=0)
if axis != 0:
res = np_array_ops.moveaxis(res, 0, axis)
return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))

View File

@ -47,7 +47,7 @@ def _canonicalize_axes(axes, rank):
canonicalizer = (
lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
else:
canonicalizer = lambda axis: axis+rank if axis < 0 else axis
canonicalizer = lambda axis: axis + rank if axis < 0 else axis
return [canonicalizer(axis) for axis in axes]
@ -100,9 +100,16 @@ def finfo(dtype):
def isscalar(val):
"""Returns whether `val` is a scalar value or scalar Tensor."""
if isinstance(val, (np.ndarray, np_arrays.ndarray, ops.Tensor)):
return len(val.shape) == 0 # pylint: disable=g-explicit-length-test
return np.isscalar(val)
if isinstance(val, np_arrays.ndarray):
val = val.data
if isinstance(val, ops.Tensor):
ndims = val.shape.ndims
if ndims is not None:
return ndims == 0
else:
return math_ops.equal(array_ops.rank(val), 0)
else:
return np.isscalar(val)
# Can't use np_doc because np.result_type is a builtin function.
@ -119,8 +126,8 @@ def result_type(*arrays_and_dtypes):
def maybe_get_dtype(x):
# Don't put np.ndarray in this list, because np.result_type looks at the
# value (not just dtype) of np.ndarray to decide the result type.
if isinstance(x, (np_arrays.ndarray, ops.Tensor,
indexed_slices.IndexedSlices)):
if isinstance(
x, (np_arrays.ndarray, ops.Tensor, indexed_slices.IndexedSlices)):
return _to_numpy_type(x.dtype)
elif isinstance(x, dtypes.DType):
return _to_numpy_type(x)
@ -277,8 +284,11 @@ def np_doc(np_fun, np_fun_name=None):
# for name in np_sig.parameters:
# if name not in sig.parameters:
# unsupported_params.append(name)
f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name,
unsupported_params=unsupported_params)
f.__doc__ = _np_doc_helper(
f,
np_fun,
np_fun_name=np_fun_name,
unsupported_params=unsupported_params)
return f
return decorator
@ -287,9 +297,9 @@ def np_doc(np_fun, np_fun_name=None):
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
"""Helper to get docs."""
if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f):
# TODO(wangpeng): It looks like code snippets in numpy doc don't work
# correctly with doctest. Fix that and remove the reformatting of the np_f
# comment, here and below.
# TODO(wangpeng): It looks like code snippets in numpy doc don't work
# correctly with doctest. Fix that and remove the reformatting of the np_f
# comment, here and below.
return np_f.__doc__.replace('>>>', '>')
assert np_f or np_fun_name
if not np_fun_name:

View File

@ -134,7 +134,7 @@ def _sort_or_argsort(values, axis, direction, return_argsort):
# Axis must be an integer, not a Tensor.
axis = framework_ops.convert_to_tensor(axis, name='axis')
axis_static = tensor_util.constant_value(axis)
if axis.shape.ndims != 0 or axis_static is None:
if axis.shape.ndims not in (None, 0) or axis_static is None:
raise ValueError('axis must be a constant scalar')
axis_static = int(axis_static) # Avoids NumPy casting error
@ -184,18 +184,8 @@ def _descending_sort(values, axis, return_argsort=False):
name='transposition')
else:
# Generate the transposition array from the tensors.
transposition = array_ops.concat(
[
# Axes up to axis are unchanged.
math_ops.range(axis),
# Swap axis and rank - 1.
[rank - 1],
# Axes in [axis + 1, rank - 1) are unchanged.
math_ops.range(axis + 1, rank - 1),
# Swap axis and rank - 1.
[axis]
],
axis=0)
transposition = array_ops.tensor_scatter_update(
math_ops.range(rank), [[axis], [rank-1]], [rank-1, axis])
top_k_input = array_ops.transpose(values, transposition)
values, indices = nn_ops.top_k(top_k_input, k)