From 8e654afea4adba36b94b0f7a3d33a23e788612e0 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 18 Jun 2020 23:35:25 -0700
Subject: [PATCH] tf.numpy: Improve ndarray.__getitem__ to match numpy semantics.

PiperOrigin-RevId: 317256717
Change-Id: Ie89b81689f96242e3e9b01568e13937b80aaffc7
---
 .../python/ops/numpy_ops/np_array_ops.py     | 261 ++++++++++++++++--
 tensorflow/python/ops/numpy_ops/np_arrays.py | 137 ---------
 2 files changed, 245 insertions(+), 153 deletions(-)

diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py
index 906e53c556d..47236d45561 100644
--- a/tensorflow/python/ops/numpy_ops/np_array_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py
@@ -20,12 +20,15 @@ from __future__ import division
 from __future__ import print_function
 
 import math
+import numbers
+from typing import Sequence
 import numpy as np
 import six
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
@@ -164,9 +167,11 @@ def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  #
 @np_utils.np_doc_only(np.array)
 def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
   """Since Tensors are immutable, a copy is made only if val is placed on a
+
   different device than the current one. Even if `copy` is False, a new Tensor
   may need to be built to satisfy `dtype` and `ndim`. This is used only if `val`
-  is an ndarray or a Tensor."""  # pylint:disable=g-docstring-missing-newline
+  is an ndarray or a Tensor.
+  """  # pylint:disable=g-docstring-missing-newline
   if dtype:
     dtype = np_utils.result_type(dtype)
   if isinstance(val, np_arrays.ndarray):
@@ -215,6 +220,8 @@ def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-out
   result_t = np_utils.cond(
       np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
   return np_arrays.tensor_to_ndarray(result_t)
+
+
 # pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
 
 
@@ -1446,14 +1453,13 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
   # broadcast.
   arr_shape_original = array_ops.shape(arr)
   indices_shape_original = array_ops.shape(indices)
-  arr_shape = array_ops.tensor_scatter_update(
-      arr_shape_original, [[axis]], [1])
-  indices_shape = array_ops.tensor_scatter_update(
-      indices_shape_original, [[axis]], [1])
-  broadcasted_shape = array_ops.broadcast_dynamic_shape(
-      arr_shape, indices_shape)
-  arr_shape = array_ops.tensor_scatter_update(
-      broadcasted_shape, [[axis]], [arr_shape_original[axis]])
+  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
+  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
+                                                  [[axis]], [1])
+  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
+                                                        indices_shape)
+  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
+                                              [arr_shape_original[axis]])
   indices_shape = array_ops.tensor_scatter_update(
       broadcasted_shape, [[axis]], [indices_shape_original[axis]])
   arr = array_ops.broadcast_to(arr, arr_shape)
@@ -1468,10 +1474,10 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
   swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data
 
   dont_move_axis_to_end = math_ops.equal(axis, rank - 1)
-  arr = np_utils.cond(
-      dont_move_axis_to_end, lambda: arr, lambda: swapaxes_(arr))
-  indices = np_utils.cond(
-      dont_move_axis_to_end, lambda: indices, lambda: swapaxes_(indices))
+  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
+                      lambda: swapaxes_(arr))
+  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
+                          lambda: swapaxes_(indices))
 
   arr_shape = array_ops.shape(arr)
   arr = array_ops.reshape(arr, [-1, arr_shape[-1]])
@@ -1481,8 +1487,231 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
 
   result = array_ops.gather(arr, indices, batch_dims=1)
   result = array_ops.reshape(result, indices_shape)
-  result = np_utils.cond(
-      dont_move_axis_to_end, lambda: result, lambda: swapaxes_(result))
+  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
+                         lambda: swapaxes_(result))
   result.set_shape(possible_result_shape)
 
-  return np_utils.tensor_to_ndarray(result)
+  return np_utils.tensor_to_ndarray(result)
+
+
+_SLICE_ERROR = (
+    'only integers, slices (`:`), ellipsis (`...`), '
+    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')
+
+
+def _as_index(idx, need_scalar=True):
+  """Helper function to parse idx as an index.
+
+  Args:
+    idx: index
+    need_scalar: If idx needs to be a scalar value.
+
+  Returns:
+    A pair, (index, bool). First one is the parsed index and can be a tensor,
+    or scalar integer / Dimension. Second one is True if rank is known to be 0.
+
+  Raises:
+    IndexError: For incorrect indices.
+  """
+  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
+    return idx, True
+  data = asarray(idx).data
+  if data.dtype == dtypes.bool:
+    if data.shape.ndims != 1:
+      # TODO(agarwal): handle higher rank boolean masks.
+      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
+    data = array_ops.where_v2(data)
+    data = array_ops.reshape(data, [-1])
+  if need_scalar and data.shape.rank not in (None, 0):
+    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
+  np_dtype = data.dtype.as_numpy_dtype
+  if not np.issubdtype(np_dtype, np.integer):
+    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
+  if data.dtype not in (dtypes.int64, dtypes.int32):
+    # TF slicing can only handle int32/int64. So we need to cast.
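+    # np.promote_types returns the smallest dtype that can safely hold both
+    # arguments; e.g. promote_types(int32, uint8) is int32 and
+    # promote_types(int32, uint32) is int64, while promote_types(int32,
+    # uint64) is float64, which falls through to the IndexError below.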
+    promoted_dtype = np.promote_types(np.int32, np_dtype)
+    if promoted_dtype == np.int32:
+      data = math_ops.cast(data, dtypes.int32)
+    elif promoted_dtype == np.int64:
+      data = math_ops.cast(data, dtypes.int64)
+    else:
+      raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
+  return data, data.shape.rank == 0
+
+
+def _slice_helper(tensor, slice_spec):
+  """Helper function for __getitem__."""
+  begin, end, strides = [], [], []
+  new_axis_mask, shrink_axis_mask = 0, 0
+  begin_mask, end_mask = 0, 0
+  ellipsis_mask = 0
+  advanced_indices = []
+  shrink_indices = []
+  for index, s in enumerate(slice_spec):
+    if isinstance(s, slice):
+      if s.start is not None:
+        begin.append(_as_index(s.start)[0])
+      else:
+        begin.append(0)
+        begin_mask |= (1 << index)
+      if s.stop is not None:
+        end.append(_as_index(s.stop)[0])
+      else:
+        end.append(0)
+        end_mask |= (1 << index)
+      if s.step is not None:
+        strides.append(_as_index(s.step)[0])
+      else:
+        strides.append(1)
+    elif s is Ellipsis:
+      begin.append(0)
+      end.append(0)
+      strides.append(1)
+      ellipsis_mask |= (1 << index)
+    elif s is array_ops.newaxis:
+      begin.append(0)
+      end.append(0)
+      strides.append(1)
+      new_axis_mask |= (1 << index)
+    else:
+      s, is_scalar = _as_index(s, False)
+      if is_scalar:
+        begin.append(s)
+        end.append(s + 1)
+        strides.append(1)
+        shrink_axis_mask |= (1 << index)
+        shrink_indices.append(index)
+      else:
+        begin.append(0)
+        end.append(0)
+        strides.append(1)
+        begin_mask |= (1 << index)
+        end_mask |= (1 << index)
+        advanced_indices.append((index, s, ellipsis_mask != 0))
+
+  # stack possibly involves no tensors, so we must use op_scope to get the
+  # correct graph.
+  with ops.name_scope(
+      None,
+      'strided_slice', [tensor] + begin + end + strides,
+      skip_on_eager=False) as name:
+    if begin:
+      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
+                                                  array_ops.stack(end),
+                                                  array_ops.stack(strides))
+      if (packed_begin.dtype == dtypes.int64 or
+          packed_end.dtype == dtypes.int64 or
+          packed_strides.dtype == dtypes.int64):
+        if packed_begin.dtype != dtypes.int64:
+          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
+        if packed_end.dtype != dtypes.int64:
+          packed_end = math_ops.cast(packed_end, dtypes.int64)
+        if packed_strides.dtype != dtypes.int64:
+          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
+    else:
+      var_empty = constant_op.constant([], dtype=dtypes.int32)
+      packed_begin = packed_end = packed_strides = var_empty
+    # TODO(agarwal): set_shape on tensor to set rank.
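+    # Each *_mask below is a bitmask over positions in slice_spec: bit i of
+    # begin_mask (end_mask) means item i had no explicit start (stop),
+    # shrink_axis_mask marks scalar indices whose axis is removed from the
+    # result, new_axis_mask marks np.newaxis entries, and ellipsis_mask marks
+    # the position of `...`.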
+    tensor = array_ops.strided_slice(
+        tensor,
+        packed_begin,
+        packed_end,
+        packed_strides,
+        begin_mask=begin_mask,
+        end_mask=end_mask,
+        shrink_axis_mask=shrink_axis_mask,
+        new_axis_mask=new_axis_mask,
+        ellipsis_mask=ellipsis_mask,
+        name=name)
+    if not advanced_indices:
+      return tensor
+    advanced_indices_map = {}
+    for index, data, had_ellipsis in advanced_indices:
+      if had_ellipsis:
+        num_shrink = len([x for x in shrink_indices if x > index])
+        dim = index - len(slice_spec) + num_shrink
+      else:
+        num_shrink = len([x for x in shrink_indices if x < index])
+        dim = index - num_shrink
+      advanced_indices_map[dim] = data
+    dims = sorted(advanced_indices_map.keys())
+    dims_contiguous = True
+    if len(dims) > 1:
+      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
+        dims_contiguous = False
+      else:
+        for i in range(len(dims) - 1):
+          if dims[i] + 1 != dims[i + 1]:
+            dims_contiguous = False
+            break
+    indices = [advanced_indices_map[x] for x in dims]
+    indices = [x.data for x in _promote_dtype(*indices)]
+    indices = np_utils.tf_broadcast(*indices)
+    stacked_indices = array_ops.stack(indices, axis=-1)
+    if not dims_contiguous:
+      tensor = moveaxis(tensor, dims, range(len(dims))).data
+      tensor_shape_prefix = array_ops.shape(
+          tensor, out_type=stacked_indices.dtype)[:len(dims)]
+      stacked_indices = array_ops.where_v2(
+          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
+          stacked_indices)
+      return array_ops.gather_nd(tensor, stacked_indices)
+    # Note that gather_nd does not support gathering from inside the array.
+    # To avoid shuffling data back and forth, we transform the indices and
+    # do a gather instead.
+    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
+    dims = [(x + rank if x < 0 else x) for x in dims]
+    shape_tensor = array_ops.shape(tensor, out_type=stacked_indices.dtype)
+    dim_sizes = array_ops.gather(shape_tensor, dims)
+    if len(dims) == 1:
+      stacked_indices = indices[0]
+    stacked_indices = array_ops.where_v2(stacked_indices < 0,
+                                         stacked_indices + dim_sizes,
+                                         stacked_indices)
+    axis = dims[0]
+    if len(dims) > 1:
+      index_scaling = math_ops.cumprod(
+          dim_sizes, reverse=True, exclusive=True)
+      stacked_indices = math_ops.tensordot(
+          stacked_indices, index_scaling, axes=1)
+      flat_shape = array_ops.concat(
+          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
+          axis=0)
+      tensor = array_ops.reshape(tensor, flat_shape)
+
+    return array_ops.gather(tensor, stacked_indices, axis=axis)
+
+
+def _as_spec_tuple(slice_spec):
+  """Convert slice_spec to tuple."""
+  if isinstance(slice_spec,
+                Sequence) and not isinstance(slice_spec, np.ndarray):
+    is_index = True
+    for s in slice_spec:
+      if s is None or s is Ellipsis or isinstance(s, (Sequence, slice)):
+        is_index = False
+        break
+      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
+        is_index = False
+        break
+    if not is_index:
+      return tuple(slice_spec)
+  return (slice_spec,)
+
+
+def _getitem(self, slice_spec):
+  """Implementation of ndarray.__getitem__."""
+  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
+                                       slice_spec.dtype == dtypes.bool) or
+      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
+       slice_spec.dtype == np.bool)):
+    return np_utils.tensor_to_ndarray(
+        array_ops.boolean_mask(tensor=self.data, mask=slice_spec))
+
+  if not isinstance(slice_spec, tuple):
+    slice_spec = _as_spec_tuple(slice_spec)
+
+  result_t = _slice_helper(self.data, slice_spec)
+  return np_utils.tensor_to_ndarray(result_t)
+
+
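+# Illustrative summary of the dispatch above (shapes assume `x` wraps
+# np.arange(12).reshape(3, 4)):
+#   x[1]              -> shape (4,): basic indexing, one strided_slice.
+#   x[:, [0, 2]]      -> shape (3, 2): one advanced index, a single gather.
+#   x[[0, 2], [1, 3]] -> shape (2,): contiguous advanced indices, flattened
+#                        into one gather via index scaling.
+#   Non-contiguous advanced indices (e.g. y[[0, 1], :, [0, 2]] on a 3-D `y`)
+#   move the indexed axes to the front and use gather_nd.
+#   x[x > 5]          -> 1-D result: boolean mask, dispatched to boolean_mask.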
+setattr(np_arrays.ndarray, '__getitem__', _getitem)
diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py
index 8bec8a469a2..88bf4e7499a 100644
--- a/tensorflow/python/ops/numpy_ops/np_arrays.py
+++ b/tensorflow/python/ops/numpy_ops/np_arrays.py
@@ -20,138 +20,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numbers
 import numpy as np
 import six
 
 from tensorflow.python.framework import composite_tensor
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.numpy_ops import np_dtypes
-from tensorflow.python.util import nest
-
-
-_SLICE_TYPE_ERROR = (
-    'Only integers, slices (`:`), ellipsis (`...`), '
-    'tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid '
-    'indices')
-
-_SUPPORTED_SLICE_DTYPES = (dtypes.int32, dtypes.int32_ref, dtypes.int64,
-                           dtypes.int64_ref)
-
-
-def _check_index(idx):
-  """Check if a given value is a valid index into a tensor."""
-  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
-    return
-
-  # Optimistic check. Assumptions:
-  # * any object with a dtype is supported
-  # * any object with a dtype has a sizeable shape attribute.
-  dtype = getattr(idx, 'dtype', None)
-  if (dtype is None or dtypes.as_dtype(dtype) not in _SUPPORTED_SLICE_DTYPES or
-      idx.shape and len(idx.shape) == 1):
-    # TODO(slebedev): IndexError seems more appropriate here, but it
-    # will break `_slice_helper` contract.
-    raise TypeError(_SLICE_TYPE_ERROR + ', got {!r}'.format(idx))
-
-
-def _is_undefined_dimension(d):
-  return isinstance(d, tensor_shape.Dimension) and d.value is None
-
-
-def _slice_helper(tensor, slice_spec, var=None):
-  """Copied from array_ops._slice_helper, will be merged back later."""
-  if isinstance(slice_spec, bool) or \
-  (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
-  (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
-    return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)
-
-  if not isinstance(slice_spec, (list, tuple)):
-    slice_spec = [slice_spec]
-
-  begin, end, strides = [], [], []
-  index = 0
-
-  new_axis_mask, shrink_axis_mask = 0, 0
-  begin_mask, end_mask = 0, 0
-  ellipsis_mask = 0
-  for s in slice_spec:
-    if isinstance(s, slice):
-      if s.start is not None and not _is_undefined_dimension(s.start):
-        _check_index(s.start)
-        begin.append(s.start)
-      else:
-        begin.append(0)
-        begin_mask |= (1 << index)
-      if s.stop is not None and not _is_undefined_dimension(s.stop):
-        _check_index(s.stop)
-        end.append(s.stop)
-      else:
-        end.append(0)
-        end_mask |= (1 << index)
-      if s.step is not None and not _is_undefined_dimension(s.step):
-        _check_index(s.step)
-        strides.append(s.step)
-      else:
-        strides.append(1)
-    elif s is Ellipsis:
-      begin.append(0)
-      end.append(0)
-      strides.append(1)
-      ellipsis_mask |= (1 << index)
-    elif s is array_ops.newaxis:
-      begin.append(0)
-      end.append(0)
-      strides.append(1)
-      new_axis_mask |= (1 << index)
-    else:
-      _check_index(s)
-      begin.append(s)
-      end.append(s + 1)
-      strides.append(1)
-      shrink_axis_mask |= (1 << index)
-    index += 1
-
-  # stack possibly involves no tensors, so we must use op_scope correct graph.
-  with ops.name_scope(
-      None,
-      'strided_slice', [tensor] + begin + end + strides,
-      skip_on_eager=False) as name:
-    if begin:
-      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
-                                                  array_ops.stack(end),
-                                                  array_ops.stack(strides))
-      if (packed_begin.dtype == dtypes.int64 or
-          packed_end.dtype == dtypes.int64 or
-          packed_strides.dtype == dtypes.int64):
-        if packed_begin.dtype != dtypes.int64:
-          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
-        if packed_end.dtype != dtypes.int64:
-          packed_end = math_ops.cast(packed_end, dtypes.int64)
-        if packed_strides.dtype != dtypes.int64:
-          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
-    else:
-      var_empty = constant_op.constant([], dtype=dtypes.int32)
-      packed_begin = packed_end = packed_strides = var_empty
-    return array_ops.strided_slice(
-        tensor,
-        packed_begin,
-        packed_end,
-        packed_strides,
-        begin_mask=begin_mask,
-        end_mask=end_mask,
-        shrink_axis_mask=shrink_axis_mask,
-        new_axis_mask=new_axis_mask,
-        ellipsis_mask=ellipsis_mask,
-        var=var,
-        name=name)
 
 
 def convert_to_tensor(value, dtype=None, dtype_hint=None):
@@ -361,22 +240,6 @@ class ndarray(composite_tensor.CompositeTensor):  # pylint: disable=invalid-name
   def __bool__(self):
     return self.__nonzero__()
 
-  def __getitem__(self, slice_spec):
-    # TODO(srbs): Need to support better indexing.
-    def _gettensor(x):
-      if isinstance(x, ndarray):
-        x = x.data
-      if isinstance(x, ops.Tensor) and x.dtype not in (
-          dtypes.int32, dtypes.int64):
-        # Currently _slice_helper will only work with int32/int64 tensors, but
-        # type inference by numpy can create {u,}int{8,16}, so just cast.
-        x = math_ops.cast(x, dtypes.int32)
-      return x
-    slice_spec = nest.map_structure(_gettensor, slice_spec)
-
-    result_t = _slice_helper(self.data, slice_spec)
-    return tensor_to_ndarray(result_t)
-
   def __iter__(self):
     if not isinstance(self.data, ops.EagerTensor):
       raise TypeError('Iteration over symbolic tensor is not allowed')
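
A minimal usage sketch of the indexing behavior this patch implements (not part of the patch itself; it assumes a TensorFlow build containing this change and uses the internal module path from the diff):

  import numpy as np

  from tensorflow.python.ops.numpy_ops import np_array_ops

  # np_array_ops.asarray returns the np_arrays.ndarray wrapper type.
  x = np_array_ops.asarray(np.arange(24).reshape(2, 3, 4))

  # Basic indexing (slices, Ellipsis, np.newaxis) lowers to strided_slice.
  print(x[0, 1:3, ...].shape)      # (2, 4): the scalar index shrinks axis 0
  print(x[..., np.newaxis].shape)  # (2, 3, 4, 1)

  # Integer-array (advanced) indexing takes the gather-based paths.
  print(x[[1, 0]].shape)           # (2, 3, 4)
  print(x[0, [0, 2], [1, 3]])      # elements x[0, 0, 1] and x[0, 2, 3]

  # A boolean mask dispatches to boolean_mask and yields a 1-D result.
  print(x[np.arange(24).reshape(2, 3, 4) > 20])  # [21 22 23]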