tf.numpy: Improve ndarray.__getitem__ to match numpy semantics.

PiperOrigin-RevId: 317256717
Change-Id: Ie89b81689f96242e3e9b01568e13937b80aaffc7
A. Unique TensorFlower 2020-06-18 23:35:25 -07:00 committed by TensorFlower Gardener
parent 539e9cb3a2
commit 8e654afea4
2 changed files with 245 additions and 153 deletions
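For reference, the indexing semantics this change targets are NumPy's own. A minimal sketch of those semantics using numpy itself (plain numpy is used here only because the tf.numpy module path was still internal at the time of this commit):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)

    x[0, 1:3, ::2]        # basic indexing with integers and slices
    x[..., np.newaxis]    # ellipsis plus axis insertion, shape (2, 3, 4, 1)
    x[x > 10]             # boolean mask flattens to a 1-D result
    x[[1, 0], [0, 1]]     # contiguous advanced indices, shape (2, 4)
    x[[1, 0], :, [0, 2]]  # non-contiguous: advanced dims move to front, shape (2, 3)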

tensorflow/python/ops/numpy_ops/np_array_ops.py

@@ -20,12 +20,15 @@ from __future__ import division
from __future__ import print_function
import math
import numbers
from typing import Sequence
import numpy as np
import six
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
@@ -164,9 +167,11 @@ def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None): #
@np_utils.np_doc_only(np.array)
def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-outer-name
"""Since Tensors are immutable, a copy is made only if val is placed on a
different device than the current one. Even if `copy` is False, a new Tensor
may need to be built to satisfy `dtype` and `ndim`. This is used only if `val`
-  is an ndarray or a Tensor.""" # pylint:disable=g-docstring-missing-newline
+  is an ndarray or a Tensor.
+  """ # pylint:disable=g-docstring-missing-newline
if dtype:
dtype = np_utils.result_type(dtype)
if isinstance(val, np_arrays.ndarray):
@@ -215,6 +220,8 @@ def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-out
result_t = np_utils.cond(
np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
return np_arrays.tensor_to_ndarray(result_t)
# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
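The copy and ndmin semantics in the docstring above mirror numpy's; a quick numpy reference (note the tf.numpy version differs in that a copy is forced by device placement rather than by copy=True alone):

    import numpy as np

    a = np.array([1, 2, 3], ndmin=3)   # missing dims are prepended: shape (1, 1, 3)
    b = np.array(a, copy=False)        # same dtype and ndim, so no new buffer
    c = np.array(a, dtype=np.float32)  # dtype change forces a new buffer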
@@ -1446,14 +1453,13 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring
# broadcast.
arr_shape_original = array_ops.shape(arr)
indices_shape_original = array_ops.shape(indices)
-  arr_shape = array_ops.tensor_scatter_update(
-      arr_shape_original, [[axis]], [1])
-  indices_shape = array_ops.tensor_scatter_update(
-      indices_shape_original, [[axis]], [1])
-  broadcasted_shape = array_ops.broadcast_dynamic_shape(
-      arr_shape, indices_shape)
-  arr_shape = array_ops.tensor_scatter_update(
-      broadcasted_shape, [[axis]], [arr_shape_original[axis]])
+  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
+  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
+                                                  [[axis]], [1])
+  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
+                                                        indices_shape)
+  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
+                                              [arr_shape_original[axis]])
indices_shape = array_ops.tensor_scatter_update(
broadcasted_shape, [[axis]], [indices_shape_original[axis]])
arr = array_ops.broadcast_to(arr, arr_shape)
@@ -1468,10 +1474,10 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring
swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data
dont_move_axis_to_end = math_ops.equal(axis, rank - 1)
-  arr = np_utils.cond(
-      dont_move_axis_to_end, lambda: arr, lambda: swapaxes_(arr))
-  indices = np_utils.cond(
-      dont_move_axis_to_end, lambda: indices, lambda: swapaxes_(indices))
+  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
+                      lambda: swapaxes_(arr))
+  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
+                          lambda: swapaxes_(indices))
arr_shape = array_ops.shape(arr)
arr = array_ops.reshape(arr, [-1, arr_shape[-1]])
@@ -1481,8 +1487,231 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring
result = array_ops.gather(arr, indices, batch_dims=1)
result = array_ops.reshape(result, indices_shape)
-  result = np_utils.cond(
-      dont_move_axis_to_end, lambda: result, lambda: swapaxes_(result))
+  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
+                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)
  return np_utils.tensor_to_ndarray(result)
_SLICE_ERROR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')


def _as_index(idx, need_scalar=True):
"""Helper function to parse idx as an index.
Args:
idx: index
need_scalar: If idx needs to be a scalar value.
Returns:
A pair, (indx, bool). First one is the parsed index and can be a tensor,
or scalar integer / Dimension. Second one is True if rank is known to be 0.
Raises:
IndexError: For incorrect indices.
"""
if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
return idx, True
data = asarray(idx).data
if data.dtype == dtypes.bool:
if data.shape.ndims != 1:
# TODO(agarwal): handle higher rank boolean masks.
raise NotImplementedError('Need rank 1 for bool index %s' % idx)
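    # A rank-1 boolean mask is converted to the integer positions of its
    # True entries (as np.nonzero would do), so the rest of the pipeline
    # only ever sees integer indices.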
data = array_ops.where_v2(data)
data = array_ops.reshape(data, [-1])
if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
np_dtype = data.dtype.as_numpy_dtype
if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
if data.dtype not in (dtypes.int64, dtypes.int32):
# TF slicing can only handle int32/int64. So we need to cast.
promoted_dtype = np.promote_types(np.int32, np_dtype)
if promoted_dtype == np.int32:
data = math_ops.cast(data, dtypes.int32)
elif promoted_dtype == np.int64:
data = math_ops.cast(data, dtypes.int64)
else:
      raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0


def _slice_helper(tensor, slice_spec):
  """Helper function for __getitem__."""
begin, end, strides = [], [], []
new_axis_mask, shrink_axis_mask = 0, 0
begin_mask, end_mask = 0, 0
ellipsis_mask = 0
advanced_indices = []
shrink_indices = []
for index, s in enumerate(slice_spec):
if isinstance(s, slice):
if s.start is not None:
begin.append(_as_index(s.start)[0])
else:
begin.append(0)
begin_mask |= (1 << index)
if s.stop is not None:
end.append(_as_index(s.stop)[0])
else:
end.append(0)
end_mask |= (1 << index)
if s.step is not None:
strides.append(_as_index(s.step)[0])
else:
strides.append(1)
elif s is Ellipsis:
begin.append(0)
end.append(0)
strides.append(1)
ellipsis_mask |= (1 << index)
elif s is array_ops.newaxis:
begin.append(0)
end.append(0)
strides.append(1)
new_axis_mask |= (1 << index)
else:
s, is_scalar = _as_index(s, False)
if is_scalar:
begin.append(s)
end.append(s + 1)
strides.append(1)
shrink_axis_mask |= (1 << index)
shrink_indices.append(index)
else:
begin.append(0)
end.append(0)
strides.append(1)
begin_mask |= (1 << index)
end_mask |= (1 << index)
advanced_indices.append((index, s, ellipsis_mask != 0))
  # `stack` possibly involves no tensors, so we must use op_scope to select
  # the correct graph.
with ops.name_scope(
None,
'strided_slice', [tensor] + begin + end + strides,
skip_on_eager=False) as name:
if begin:
packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
array_ops.stack(end),
array_ops.stack(strides))
if (packed_begin.dtype == dtypes.int64 or
packed_end.dtype == dtypes.int64 or
packed_strides.dtype == dtypes.int64):
if packed_begin.dtype != dtypes.int64:
packed_begin = math_ops.cast(packed_begin, dtypes.int64)
if packed_end.dtype != dtypes.int64:
packed_end = math_ops.cast(packed_end, dtypes.int64)
if packed_strides.dtype != dtypes.int64:
packed_strides = math_ops.cast(packed_strides, dtypes.int64)
else:
var_empty = constant_op.constant([], dtype=dtypes.int32)
packed_begin = packed_end = packed_strides = var_empty
# TODO(agarwal): set_shape on tensor to set rank.
tensor = array_ops.strided_slice(
tensor,
packed_begin,
packed_end,
packed_strides,
begin_mask=begin_mask,
end_mask=end_mask,
shrink_axis_mask=shrink_axis_mask,
new_axis_mask=new_axis_mask,
ellipsis_mask=ellipsis_mask,
name=name)
if not advanced_indices:
return tensor
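  # Each advanced index was replaced above by a full slice plus an entry in
  # `advanced_indices`, so the strided_slice result keeps one dimension per
  # advanced index. Map each index's position in `slice_spec` to its
  # dimension in that result: positions after an ellipsis are counted from
  # the end, and dimensions removed by scalar (shrink) indices are skipped.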
advanced_indices_map = {}
for index, data, had_ellipsis in advanced_indices:
if had_ellipsis:
num_shrink = len([x for x in shrink_indices if x > index])
dim = index - len(slice_spec) + num_shrink
else:
num_shrink = len([x for x in shrink_indices if x < index])
dim = index - num_shrink
advanced_indices_map[dim] = data
dims = sorted(advanced_indices_map.keys())
dims_contiguous = True
if len(dims) > 1:
if dims[0] < 0 and dims[-1] >= 0: # not all same sign
dims_contiguous = False
else:
for i in range(len(dims) - 1):
if dims[i] + 1 != dims[i + 1]:
dims_contiguous = False
break
indices = [advanced_indices_map[x] for x in dims]
indices = [x.data for x in _promote_dtype(*indices)]
indices = np_utils.tf_broadcast(*indices)
stacked_indices = array_ops.stack(indices, axis=-1)
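  # NumPy rule: when the advanced indices are not contiguous (e.g. separated
  # by a slice), the dimensions they produce move to the front of the result.
  # Mirror that by moving the indexed axes to the front and gathering over
  # the leading dimensions with a single gather_nd.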
if not dims_contiguous:
tensor = moveaxis(tensor, dims, range(len(dims))).data
tensor_shape_prefix = array_ops.shape(
tensor, out_type=stacked_indices.dtype)[:len(dims)]
stacked_indices = array_ops.where_v2(
stacked_indices < 0, stacked_indices + tensor_shape_prefix,
stacked_indices)
return array_ops.gather_nd(tensor, stacked_indices)
  # Note that gather_nd does not support gathering along axes other than the
  # leading ones. To avoid shuffling data back and forth, we transform the
  # indices and do a gather instead.
rank = np_utils._maybe_static(array_ops.rank(tensor)) # pylint: disable=protected-access
dims = [(x + rank if x < 0 else x) for x in dims]
shape_tensor = array_ops.shape(tensor, out_type=stacked_indices.dtype)
dim_sizes = array_ops.gather(shape_tensor, dims)
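  # Negative index values count from the end of each dimension, so shift
  # them by the corresponding dimension sizes before gathering.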
if len(dims) == 1:
stacked_indices = indices[0]
stacked_indices = array_ops.where_v2(stacked_indices < 0,
stacked_indices + dim_sizes,
stacked_indices)
axis = dims[0]
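  # With several contiguous advanced dimensions, fold them into one flat
  # index per element: scale each component by its row-major stride (an
  # exclusive reverse cumprod of the dimension sizes, as np.ravel_multi_index
  # does) and sum via tensordot, then reshape so those dimensions collapse
  # into the single axis that gather operates on.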
if len(dims) > 1:
index_scaling = math_ops.cumprod(
dim_sizes, reverse=True, exclusive=True)
stacked_indices = math_ops.tensordot(
stacked_indices, index_scaling, axes=1)
flat_shape = array_ops.concat(
[shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
axis=0)
tensor = array_ops.reshape(tensor, flat_shape)
  return array_ops.gather(tensor, stacked_indices, axis=axis)


def _as_spec_tuple(slice_spec):
  """Convert slice_spec to tuple."""
if isinstance(slice_spec,
Sequence) and not isinstance(slice_spec, np.ndarray):
is_index = True
for s in slice_spec:
if s is None or s is Ellipsis or isinstance(s, (Sequence, slice)):
is_index = False
break
elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
is_index = False
break
if not is_index:
return tuple(slice_spec)
  return (slice_spec,)


def _getitem(self, slice_spec):
  """Implementation of ndarray.__getitem__."""
if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
slice_spec.dtype == dtypes.bool) or
(isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
slice_spec.dtype == np.bool)):
return np_utils.tensor_to_ndarray(
array_ops.boolean_mask(tensor=self.data, mask=slice_spec))
if not isinstance(slice_spec, tuple):
slice_spec = _as_spec_tuple(slice_spec)
result_t = _slice_helper(self.data, slice_spec)
  return np_utils.tensor_to_ndarray(result_t)


setattr(np_arrays.ndarray, '__getitem__', _getitem)
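Once installed via setattr, the hook above gives tf.numpy ndarrays numpy-style indexing. A usage sketch (the import path is illustrative; at this commit the module lived under tensorflow.python.ops.numpy_ops and was not yet public API):

    from tensorflow.python.ops.numpy_ops import np_array_ops as anp

    x = anp.reshape(anp.arange(24), (2, 3, 4))
    y = x[0, 1:3, ::2]  # basic indexing goes through _slice_helper
    m = x[x > 10]       # boolean specs short-circuit to boolean_mask
    r = x[[1, 0]]       # a list of scalars is one advanced index over axis 0
    s = x[1, 0]         # a tuple is a multi-axis spec: row x[1][0], shape (4,)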

tensorflow/python/ops/numpy_ops/np_arrays.py

@@ -20,138 +20,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numbers
import numpy as np
import six
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.util import nest


_SLICE_TYPE_ERROR = (
'Only integers, slices (`:`), ellipsis (`...`), '
'tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid '
'indices')
_SUPPORTED_SLICE_DTYPES = (dtypes.int32, dtypes.int32_ref, dtypes.int64,
                           dtypes.int64_ref)


def _check_index(idx):
"""Check if a given value is a valid index into a tensor."""
if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
return
# Optimistic check. Assumptions:
# * any object with a dtype is supported
# * any object with a dtype has a sizeable shape attribute.
dtype = getattr(idx, 'dtype', None)
if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or
idx.shape and len(idx.shape) == 1):
# TODO(slebedev): IndexError seems more appropriate here, but it
# will break `_slice_helper` contract.
    raise TypeError(_SLICE_TYPE_ERROR + ', got {!r}'.format(idx))


def _is_undefined_dimension(d):
  return isinstance(d, tensor_shape.Dimension) and d.value is None


def _slice_helper(tensor, slice_spec, var=None):
"""Copied from array_ops._slice_helper, will be merged back later."""
if isinstance(slice_spec, bool) or \
(isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
(isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)
if not isinstance(slice_spec, (list, tuple)):
slice_spec = [slice_spec]
begin, end, strides = [], [], []
index = 0
new_axis_mask, shrink_axis_mask = 0, 0
begin_mask, end_mask = 0, 0
ellipsis_mask = 0
for s in slice_spec:
if isinstance(s, slice):
if s.start is not None and not _is_undefined_dimension(s.start):
_check_index(s.start)
begin.append(s.start)
else:
begin.append(0)
begin_mask |= (1 << index)
if s.stop is not None and not _is_undefined_dimension(s.stop):
_check_index(s.stop)
end.append(s.stop)
else:
end.append(0)
end_mask |= (1 << index)
if s.step is not None and not _is_undefined_dimension(s.step):
_check_index(s.step)
strides.append(s.step)
else:
strides.append(1)
elif s is Ellipsis:
begin.append(0)
end.append(0)
strides.append(1)
ellipsis_mask |= (1 << index)
elif s is array_ops.newaxis:
begin.append(0)
end.append(0)
strides.append(1)
new_axis_mask |= (1 << index)
else:
_check_index(s)
begin.append(s)
end.append(s + 1)
strides.append(1)
shrink_axis_mask |= (1 << index)
index += 1
  # `stack` possibly involves no tensors, so we must use op_scope to select
  # the correct graph.
with ops.name_scope(
None,
'strided_slice', [tensor] + begin + end + strides,
skip_on_eager=False) as name:
if begin:
packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
array_ops.stack(end),
array_ops.stack(strides))
if (packed_begin.dtype == dtypes.int64 or
packed_end.dtype == dtypes.int64 or
packed_strides.dtype == dtypes.int64):
if packed_begin.dtype != dtypes.int64:
packed_begin = math_ops.cast(packed_begin, dtypes.int64)
if packed_end.dtype != dtypes.int64:
packed_end = math_ops.cast(packed_end, dtypes.int64)
if packed_strides.dtype != dtypes.int64:
packed_strides = math_ops.cast(packed_strides, dtypes.int64)
else:
var_empty = constant_op.constant([], dtype=dtypes.int32)
packed_begin = packed_end = packed_strides = var_empty
return array_ops.strided_slice(
tensor,
packed_begin,
packed_end,
packed_strides,
begin_mask=begin_mask,
end_mask=end_mask,
shrink_axis_mask=shrink_axis_mask,
new_axis_mask=new_axis_mask,
ellipsis_mask=ellipsis_mask,
var=var,
        name=name)


def convert_to_tensor(value, dtype=None, dtype_hint=None):
@@ -361,22 +240,6 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
def __bool__(self):
    return self.__nonzero__()

  def __getitem__(self, slice_spec):
# TODO(srbs): Need to support better indexing.
def _gettensor(x):
if isinstance(x, ndarray):
x = x.data
if isinstance(x, ops.Tensor) and x.dtype not in (
dtypes.int32, dtypes.int64):
# Currently _slice_helper will only work with int32/int64 tensors, but
# type inference by numpy can create {u,}int{8,16}, so just cast.
x = math_ops.cast(x, dtypes.int32)
return x
slice_spec = nest.map_structure(_gettensor, slice_spec)
result_t = _slice_helper(self.data, slice_spec)
    return tensor_to_ndarray(result_t)

  def __iter__(self):
if not isinstance(self.data, ops.EagerTensor):
raise TypeError('Iteration over symbolic tensor is not allowed')