tf.numpy: Improve ndarray.__getitem__ to match numpy semantics.
PiperOrigin-RevId: 317256717 Change-Id: Ie89b81689f96242e3e9b01568e13937b80aaffc7
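For reference, the NumPy subscript forms this change brings ndarray.__getitem__ in line with: basic slices, `...`/`newaxis`, boolean masks, and advanced integer-array indexing. A minimal plain-NumPy sketch (illustrative only, not the tf.numpy API):

import numpy as np

x = np.arange(12).reshape(3, 4)

x[1:3, ::2]            # basic slicing: rows 1-2, every other column
x[..., np.newaxis]     # Ellipsis plus a new trailing axis -> shape (3, 4, 1)
x[x > 5]               # boolean mask -> 1-D array of the selected elements
x[[0, 2], [1, 3]]      # advanced indexing: elements at (0, 1) and (2, 3)
x[[0, 2]]              # integer-array index along axis 0 -> rows 0 and 2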
parent 539e9cb3a2
commit 8e654afea4

@@ -20,12 +20,15 @@ from __future__ import division
from __future__ import print_function

import math
import numbers
from typing import Sequence
import numpy as np
import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
@@ -164,9 +167,11 @@ def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  #

@np_utils.np_doc_only(np.array)
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Since Tensors are immutable, a copy is made only if val is placed on a

  different device than the current one. Even if `copy` is False, a new Tensor
  may need to be built to satisfy `dtype` and `ndim`. This is used only if `val`
  is an ndarray or a Tensor."""  # pylint:disable=g-docstring-missing-newline
  is an ndarray or a Tensor.
  """  # pylint:disable=g-docstring-missing-newline
  if dtype:
    dtype = np_utils.result_type(dtype)
  if isinstance(val, np_arrays.ndarray):

@@ -215,6 +220,8 @@ def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  result_t = np_utils.cond(
      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
  return np_arrays.tensor_to_ndarray(result_t)


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args

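The reworked docstring above follows NumPy's `np.array` contract: even with `copy=False`, a fresh array may be produced when `dtype` or `ndmin` cannot be satisfied otherwise. A plain-NumPy sketch of that behaviour (illustrative; the tf.numpy version reasons about Tensors and device placement rather than buffers):

import numpy as np

base = np.array([1, 2, 3], dtype=np.int32)

same = np.array(base, copy=False)         # already a matching ndarray: no new buffer
wider = np.array(base, dtype=np.float64)  # dtype conversion always builds a new array
deeper = np.array(base, ndmin=2)          # rank is padded on the left: shape (1, 3)

print(same is base, wider.dtype, deeper.shape)   # True float64 (1, 3)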
@@ -1446,14 +1453,13 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  # broadcast.
  arr_shape_original = array_ops.shape(arr)
  indices_shape_original = array_ops.shape(indices)
  arr_shape = array_ops.tensor_scatter_update(
      arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(
      indices_shape_original, [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(
      arr_shape, indices_shape)
  arr_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [arr_shape_original[axis]])
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
                                                  [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
                                                        indices_shape)
  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
                                              [arr_shape_original[axis]])
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
  arr = array_ops.broadcast_to(arr, arr_shape)
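The shape juggling above implements NumPy's `take_along_axis` broadcasting rule: `arr` and `indices` are broadcast against each other on every dimension except `axis`. A rough plain-NumPy illustration of the shapes involved (hypothetical sizes, for exposition only):

import numpy as np

arr = np.random.rand(2, 5, 3)
idx = np.zeros((1, 2, 3), dtype=np.intp)   # size-1 leading dim broadcasts against 2

# Sizes along axis=1 may differ (5 vs 2); every other axis is broadcast.
out = np.take_along_axis(arr, idx, axis=1)
print(out.shape)                           # (2, 2, 3)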
@@ -1468,10 +1474,10 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data

  dont_move_axis_to_end = math_ops.equal(axis, rank - 1)
  arr = np_utils.cond(
      dont_move_axis_to_end, lambda: arr, lambda: swapaxes_(arr))
  indices = np_utils.cond(
      dont_move_axis_to_end, lambda: indices, lambda: swapaxes_(indices))
  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
                      lambda: swapaxes_(arr))
  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
                          lambda: swapaxes_(indices))

  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

@@ -1481,8 +1487,231 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring

  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  result = np_utils.cond(
      dont_move_axis_to_end, lambda: result, lambda: swapaxes_(result))
  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)

  return np_utils.tensor_to_ndarray(result)
  return np_utils.tensor_to_ndarray(result)

_SLICE_ERORR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')


def _as_index(idx, need_scalar=True):
  """Helper function to parse idx as an index.

  Args:
    idx: index
    need_scalar: If idx needs to be a scalar value.

  Returns:
    A pair, (indx, bool). First one is the parsed index and can be a tensor,
    or scalar integer / Dimension. Second one is True if rank is known to be 0.

  Raises:
    IndexError: For incorrect indices.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True
  data = asarray(idx).data
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    data = array_ops.where_v2(data)
    data = array_ops.reshape(data, [-1])
  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing can only handle int32/int64. So we need to cast.
    promoted_dtype = np.promote_types(np.int32, np_dtype)
    if promoted_dtype == np.int32:
      data = math_ops.cast(data, dtypes.int32)
    elif promoted_dtype == np.int64:
      data = math_ops.cast(data, dtypes.int64)
    else:
      raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0

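`_as_index` above normalises whatever appears in a subscript: Python ints and Dimensions pass through, a rank-1 boolean array becomes the integer positions of its True entries, and anything else is cast to int32/int64 or rejected. A plain-NumPy analogue of the boolean branch (illustrative only):

import numpy as np

mask = np.array([True, False, True, True])

# Equivalent of the array_ops.where_v2 + reshape([-1]) step above:
positions = np.flatnonzero(mask)       # array([0, 2, 3])

data = np.arange(4) * 10
print(data[mask], data[positions])     # both: [ 0 20 30]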
def _slice_helper(tensor, slice_spec):
  """Helper function for __getitem__."""
  begin, end, strides = [], [], []
  new_axis_mask, shrink_axis_mask = 0, 0
  begin_mask, end_mask = 0, 0
  ellipsis_mask = 0
  advanced_indices = []
  shrink_indices = []
  for index, s in enumerate(slice_spec):
    if isinstance(s, slice):
      if s.start is not None:
        begin.append(_as_index(s.start)[0])
      else:
        begin.append(0)
        begin_mask |= (1 << index)
      if s.stop is not None:
        end.append(_as_index(s.stop)[0])
      else:
        end.append(0)
        end_mask |= (1 << index)
      if s.step is not None:
        strides.append(_as_index(s.step)[0])
      else:
        strides.append(1)
    elif s is Ellipsis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      ellipsis_mask |= (1 << index)
    elif s is array_ops.newaxis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      new_axis_mask |= (1 << index)
    else:
      s, is_scalar = _as_index(s, False)
      if is_scalar:
        begin.append(s)
        end.append(s + 1)
        strides.append(1)
        shrink_axis_mask |= (1 << index)
        shrink_indices.append(index)
      else:
        begin.append(0)
        end.append(0)
        strides.append(1)
        begin_mask |= (1 << index)
        end_mask |= (1 << index)
        advanced_indices.append((index, s, ellipsis_mask != 0))

  # stack possibly involves no tensors, so we must use op_scope correct graph.
  with ops.name_scope(
      None,
      'strided_slice', [tensor] + begin + end + strides,
      skip_on_eager=False) as name:
    if begin:
      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
                                                  array_ops.stack(end),
                                                  array_ops.stack(strides))
      if (packed_begin.dtype == dtypes.int64 or
          packed_end.dtype == dtypes.int64 or
          packed_strides.dtype == dtypes.int64):
        if packed_begin.dtype != dtypes.int64:
          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
        if packed_end.dtype != dtypes.int64:
          packed_end = math_ops.cast(packed_end, dtypes.int64)
        if packed_strides.dtype != dtypes.int64:
          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
    else:
      var_empty = constant_op.constant([], dtype=dtypes.int32)
      packed_begin = packed_end = packed_strides = var_empty
    # TODO(agarwal): set_shape on tensor to set rank.
    tensor = array_ops.strided_slice(
        tensor,
        packed_begin,
        packed_end,
        packed_strides,
        begin_mask=begin_mask,
        end_mask=end_mask,
        shrink_axis_mask=shrink_axis_mask,
        new_axis_mask=new_axis_mask,
        ellipsis_mask=ellipsis_mask,
        name=name)
    if not advanced_indices:
      return tensor
    advanced_indices_map = {}
    for index, data, had_ellipsis in advanced_indices:
      if had_ellipsis:
        num_shrink = len([x for x in shrink_indices if x > index])
        dim = index - len(slice_spec) + num_shrink
      else:
        num_shrink = len([x for x in shrink_indices if x < index])
        dim = index - num_shrink
      advanced_indices_map[dim] = data
    dims = sorted(advanced_indices_map.keys())
    dims_contiguous = True
    if len(dims) > 1:
      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
        dims_contiguous = False
      else:
        for i in range(len(dims) - 1):
          if dims[i] + 1 != dims[i + 1]:
            dims_contiguous = False
            break
    indices = [advanced_indices_map[x] for x in dims]
    indices = [x.data for x in _promote_dtype(*indices)]
    indices = np_utils.tf_broadcast(*indices)
    stacked_indices = array_ops.stack(indices, axis=-1)
    if not dims_contiguous:
      tensor = moveaxis(tensor, dims, range(len(dims))).data
      tensor_shape_prefix = array_ops.shape(
          tensor, out_type=stacked_indices.dtype)[:len(dims)]
      stacked_indices = array_ops.where_v2(
          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
          stacked_indices)
      return array_ops.gather_nd(tensor, stacked_indices)
    # Note that gather_nd does not support gathering from inside the array.
    # To avoid shuffling data back and forth, we transform the indices and
    # do a gather instead.
    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
    dims = [(x + rank if x < 0 else x) for x in dims]
    shape_tensor = array_ops.shape(tensor, out_type=stacked_indices.dtype)
    dim_sizes = array_ops.gather(shape_tensor, dims)
    if len(dims) == 1:
      stacked_indices = indices[0]
    stacked_indices = array_ops.where_v2(stacked_indices < 0,
                                         stacked_indices + dim_sizes,
                                         stacked_indices)
    axis = dims[0]
    if len(dims) > 1:
      index_scaling = math_ops.cumprod(
          dim_sizes, reverse=True, exclusive=True)
      stacked_indices = math_ops.tensordot(
          stacked_indices, index_scaling, axes=1)
    flat_shape = array_ops.concat(
        [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
        axis=0)
    tensor = array_ops.reshape(tensor, flat_shape)

    return array_ops.gather(tensor, stacked_indices, axis=axis)

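For contiguous advanced-index dimensions, `_slice_helper` above avoids `gather_nd`: it collapses the indexed axes into one, turns the stacked per-axis indices into flat offsets by dotting them with row-major strides (the exclusive, reversed `cumprod` of the axis sizes), and issues a single `gather`. A small plain-NumPy sketch of that index arithmetic (illustrative only, not the TF code path):

import numpy as np

tensor = np.arange(24).reshape(2, 3, 4)   # advanced indices target axes 0 and 1
i = np.array([0, 1, 1])                   # indices into axis 0
j = np.array([2, 0, 2])                   # indices into axis 1

# Row-major strides of the indexed axes: exclusive, reversed cumulative
# product of their sizes ([2, 3] -> [3, 1]).
dim_sizes = np.array([2, 3])
strides = np.append(np.cumprod(dim_sizes[::-1])[::-1][1:], 1)

flat = np.tensordot(np.stack([i, j], axis=-1), strides, axes=1)  # i * 3 + j
flattened = tensor.reshape(-1, 4)                                # collapse axes 0 and 1
result = np.take(flattened, flat, axis=0)                        # one gather

assert np.array_equal(result, tensor[i, j])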
def _as_spec_tuple(slice_spec):
  """Convert slice_spec to tuple."""
  if isinstance(slice_spec,
                Sequence) and not isinstance(slice_spec, np.ndarray):
    is_index = True
    for s in slice_spec:
      if s is None or s is Ellipsis or isinstance(s, (Sequence, slice)):
        is_index = False
        break
      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
        is_index = False
        break
    if not is_index:
      return tuple(slice_spec)
  return (slice_spec,)

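`_as_spec_tuple` above mirrors NumPy's reading of a bare sequence of scalars: it is one advanced index, not a per-axis tuple. A plain-NumPy illustration:

import numpy as np

a = np.arange(9).reshape(3, 3)

print(a[[0, 2]])     # list of scalars: one advanced index along axis 0 -> rows 0 and 2
print(a[(0, 2)])     # tuple: per-axis indices, same as a[0, 2] -> the scalar 2
print(a[[0, 2], :])  # a list inside a tuple is still an advanced index for axis 0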
def _getitem(self, slice_spec):
  """Implementation of ndarray.__getitem__."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool)):
    return np_utils.tensor_to_ndarray(
        array_ops.boolean_mask(tensor=self.data, mask=slice_spec))

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  result_t = _slice_helper(self.data, slice_spec)
  return np_utils.tensor_to_ndarray(result_t)


setattr(np_arrays.ndarray, '__getitem__', _getitem)
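With `__getitem__` patched onto `np_arrays.ndarray`, the tf.numpy array type accepts the same subscript forms end to end. A hedged usage sketch, assuming the later public alias `tf.experimental.numpy` (this commit itself lives under `tensorflow.python.ops.numpy_ops`):

import tensorflow.experimental.numpy as tnp  # assumed public alias of these modules

x = tnp.reshape(tnp.arange(12), (3, 4))

print(x[1:, ::2])         # basic slicing via strided_slice
print(x[x > 5])           # boolean mask via boolean_mask
print(x[[0, 2], [1, 3]])  # advanced integer indexing via the new gather path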

@@ -20,138 +20,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numbers
import numpy as np
import six

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.util import nest


_SLICE_TYPE_ERROR = (
    'Only integers, slices (`:`), ellipsis (`...`), '
    'tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid '
    'indices')

_SUPPORTED_SLICE_DTYPES = (dtypes.int32, dtypes.int32_ref, dtypes.int64,
                           dtypes.int64_ref)

def _check_index(idx):
  """Check if a given value is a valid index into a tensor."""
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return

  # Optimistic check. Assumptions:
  # * any object with a dtype is supported
  # * any object with a dtype has a sizeable shape attribute.
  dtype = getattr(idx, 'dtype', None)
  if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or
      idx.shape and len(idx.shape) == 1):
    # TODO(slebedev): IndexError seems more appropriate here, but it
    # will break `_slice_helper` contract.
    raise TypeError(_SLICE_TYPE_ERROR + ', got {!r}'.format(idx))


def _is_undefined_dimension(d):
  return isinstance(d, tensor_shape.Dimension) and d.value is None

def _slice_helper(tensor, slice_spec, var=None):
  """Copied from array_ops._slice_helper, will be merged back later."""
  if isinstance(slice_spec, bool) or \
  (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
  (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
    return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

  if not isinstance(slice_spec, (list, tuple)):
    slice_spec = [slice_spec]

  begin, end, strides = [], [], []
  index = 0

  new_axis_mask, shrink_axis_mask = 0, 0
  begin_mask, end_mask = 0, 0
  ellipsis_mask = 0
  for s in slice_spec:
    if isinstance(s, slice):
      if s.start is not None and not _is_undefined_dimension(s.start):
        _check_index(s.start)
        begin.append(s.start)
      else:
        begin.append(0)
        begin_mask |= (1 << index)
      if s.stop is not None and not _is_undefined_dimension(s.stop):
        _check_index(s.stop)
        end.append(s.stop)
      else:
        end.append(0)
        end_mask |= (1 << index)
      if s.step is not None and not _is_undefined_dimension(s.step):
        _check_index(s.step)
        strides.append(s.step)
      else:
        strides.append(1)
    elif s is Ellipsis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      ellipsis_mask |= (1 << index)
    elif s is array_ops.newaxis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      new_axis_mask |= (1 << index)
    else:
      _check_index(s)
      begin.append(s)
      end.append(s + 1)
      strides.append(1)
      shrink_axis_mask |= (1 << index)
    index += 1

  # stack possibly involves no tensors, so we must use op_scope correct graph.
  with ops.name_scope(
      None,
      'strided_slice', [tensor] + begin + end + strides,
      skip_on_eager=False) as name:
    if begin:
      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
                                                  array_ops.stack(end),
                                                  array_ops.stack(strides))
      if (packed_begin.dtype == dtypes.int64 or
          packed_end.dtype == dtypes.int64 or
          packed_strides.dtype == dtypes.int64):
        if packed_begin.dtype != dtypes.int64:
          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
        if packed_end.dtype != dtypes.int64:
          packed_end = math_ops.cast(packed_end, dtypes.int64)
        if packed_strides.dtype != dtypes.int64:
          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
    else:
      var_empty = constant_op.constant([], dtype=dtypes.int32)
      packed_begin = packed_end = packed_strides = var_empty
    return array_ops.strided_slice(
        tensor,
        packed_begin,
        packed_end,
        packed_strides,
        begin_mask=begin_mask,
        end_mask=end_mask,
        shrink_axis_mask=shrink_axis_mask,
        new_axis_mask=new_axis_mask,
        ellipsis_mask=ellipsis_mask,
        var=var,
        name=name)

def convert_to_tensor(value, dtype=None, dtype_hint=None):

@@ -361,22 +240,6 @@ class ndarray(composite_tensor.CompositeTensor):  # pylint: disable=invalid-name
  def __bool__(self):
    return self.__nonzero__()

  def __getitem__(self, slice_spec):
    # TODO(srbs): Need to support better indexing.
    def _gettensor(x):
      if isinstance(x, ndarray):
        x = x.data
      if isinstance(x, ops.Tensor) and x.dtype not in (
          dtypes.int32, dtypes.int64):
        # Currently _slice_helper will only work with int32/int64 tensors, but
        # type inference by numpy can create {u,}int{8,16}, so just cast.
        x = math_ops.cast(x, dtypes.int32)
      return x
    slice_spec = nest.map_structure(_gettensor, slice_spec)

    result_t = _slice_helper(self.data, slice_spec)
    return tensor_to_ndarray(result_t)

  def __iter__(self):
    if not isinstance(self.data, ops.EagerTensor):
      raise TypeError('Iteration over symbolic tensor is not allowed')