Support slicing in ShardedVariable. The slicing semantics are identical to those of Tensor/Variable.

PiperOrigin-RevId: 355249212
Change-Id: Ic9a14b5ae5cc0a446142eaa529f052c09c445396
Chenkai Kuang 2021-02-02 14:10:17 -08:00 committed by TensorFlower Gardener
parent c21185ae1c
commit b4bf78ffec
3 changed files with 269 additions and 3 deletions
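
For context, here is a minimal usage sketch of the behavior this commit adds. It is not part of the diff; the shard values are illustrative:

# Sketch: slicing a ShardedVariable behaves like slicing the concatenation
# of its shards along the first (partitioned) dimension.
import tensorflow as tf
from tensorflow.python.distribute import sharded_variable

sv = sharded_variable.ShardedVariable([
    tf.Variable([[1, 2], [3, 4], [5, 6]]),  # shard 0: rows 0-2
    tf.Variable([[7, 8], [9, 10]]),         # shard 1: rows 3-4
])

print(sv[1:4].numpy())   # rows spanning both shards: [[3 4] [5 6] [7 8]]
print(sv[::-1].numpy())  # reversed rows, same semantics as Tensor/Variable
print(sv[3, 0].numpy())  # integer indexing resolves into shard 1: 7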

tensorflow/python/distribute/BUILD

@@ -1131,6 +1131,8 @@ py_library(
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:partitioned_variables",
@@ -1139,10 +1141,12 @@ py_library(
        "//tensorflow/python:type_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/saved_model:revived_types",
        "//tensorflow/python/saved_model:save_context",
        "//tensorflow/python/training/saving:saveable_object_util",
        "//tensorflow/python/training/tracking:base",
        "//tensorflow/python/util:tf_export",
        "//third_party/py/numpy",
    ],
)
@@ -1160,10 +1164,10 @@ tf_py_test(
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/eager:def_function",

tensorflow/python/distribute/sharded_variable.py

@@ -18,8 +18,12 @@ from __future__ import division
from __future__ import print_function
import copy
import math
import numpy as np
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
@@ -282,8 +286,8 @@ class ShardedVariableMixin(trackable.Trackable):
      raise ValueError(
          'Expected a list of `Variable`s, found: {}'.format(variables))
-    dtypes = {v.dtype for v in variables}
-    if len(dtypes) > 1:
+    var_dtypes = {v.dtype for v in variables}
+    if len(var_dtypes) > 1:
      raise ValueError(
          'All `Variable`s must have the same dtype, found: {}'.format(
              [v.dtype for v in variables]))
@@ -322,6 +326,171 @@ class ShardedVariableMixin(trackable.Trackable):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to `__getitem__`, specifying the global
        slicing of the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bounds.
      TypeError: If `slice_spec` contains a Tensor.
    """
    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice
    # assign.
    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)
    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      # Slicing the first (partitioned) dimension: decompose the global slice
      # into per-shard slices and concatenate the results.
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError('slice index %d of dimension 0 out of bounds.' % s)
      # Integer indexing: locate the shard whose range contains `s`.
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first-dimension partitioning, thus
    `slice_spec` must be for the first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing
      for each component variable. None means no slicing.

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, stop and step.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative indices, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local slice_spec is
    # determined by using whatever is smaller between the global slice end
    # and the variable's range end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))
      result.reverse()
    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
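
To make the decomposition algorithm above concrete, here is a small self-contained sketch. It is an illustrative reimplementation, not TensorFlow code; the names `decompose` and `shard_sizes` are hypothetical, and it uses `slice.indices` for normalization, which matches the hand-rolled normalization above for in-range bounds:

import math

def decompose(spec, shard_sizes):
  """Returns one local slice (or None) per shard for a global slice `spec`."""
  n = sum(shard_sizes)
  offsets = [sum(shard_sizes[:i]) for i in range(len(shard_sizes))]
  # slice.indices also yields stop == -1 for a negative step with no explicit
  # stop, mirroring the normalization in _decompose_slice_spec.
  start, stop, step = spec.indices(n)
  result = []
  cur = start
  order = range(len(offsets)) if step > 0 else range(len(offsets) - 1, -1, -1)
  for i in order:
    var_start = offsets[i]
    var_end = offsets[i + 1] if i < len(offsets) - 1 else n
    if step > 0:
      if cur < var_start:  # advance the cursor to the first step point in range
        cur += step * int(math.ceil((var_start - cur) / step))
      if cur >= var_end or cur >= stop:
        result.append(None)
      else:
        result.append(slice(cur - var_start, min(stop, var_end) - var_start,
                            step))
    else:
      if cur >= var_end:  # move the cursor backwards into this shard's range
        cur += step * int(math.ceil((var_end - cur - 1) / step))
      if cur < var_start or cur <= stop:
        result.append(None)
      else:
        end = stop - var_start if stop >= var_start else None
        result.append(slice(cur - var_start, end, step))
  if step < 0:
    result.reverse()  # report per-shard slices in ascending shard order
  return result

# Check the three docstring examples against plain Python list slicing.
shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
flat = sum(shards, [])
for spec in (slice(None), slice(2, 8, 3), slice(9, 3, -2)):
  local = decompose(spec, [len(s) for s in shards])
  pieces = [s[l] for s, l in zip(shards, local) if l is not None]
  if (spec.step or 1) < 0:
    pieces.reverse()  # shard blocks come out ascending; restore global order
  assert sum(pieces, []) == flat[spec], (spec, pieces)
print('decomposition matches built-in slicing')

Because the cursor `cur` always stays on the global stride's arithmetic progression, each shard's local start falls out directly, without materializing any index list.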

tensorflow/python/distribute/sharded_variable_test.py

@@ -449,6 +449,99 @@ class ShardedVariableTest(test.TestCase):
    self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
    self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]])

  def test_slicing(self):
    v = [
        variables_lib.Variable([[1, 2], [3, 4], [5, 6]]),
        variables_lib.Variable([[7, 8], [9, 10], [11, 12]]),
        variables_lib.Variable([[13, 14], [15, 16]])
    ]
    sv = sharded_variable.ShardedVariable(v)
    empty = v[0][0:0]

    # Test cases: positive step
    self.assertAllEqual(sv[:], array_ops.concat(v, axis=0))
    self.assertAllEqual(sv[:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[-8:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[-10:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[5:], [[11, 12], [13, 14], [15, 16]])
    self.assertAllEqual(sv[5:-1], [[11, 12], [13, 14]])
    self.assertAllEqual(sv[::3], [[1, 2], [7, 8], [13, 14]])
    self.assertAllEqual(sv[::5], [[1, 2], [11, 12]])
    self.assertAllEqual(sv[1::6], [[3, 4], [15, 16]])
    self.assertAllEqual(sv[1:5:6], [[3, 4]])
    self.assertAllEqual(sv[1::7], [[3, 4]])
    self.assertAllEqual(sv[2:7], [[5, 6], [7, 8], [9, 10], [11, 12], [13, 14]])
    self.assertAllEqual(sv[2:7:2], [[5, 6], [9, 10], [13, 14]])
    self.assertAllEqual(sv[2:7:3], [[5, 6], [11, 12]])

    # Test cases: negative step
    self.assertAllEqual(
        sv[::-1], array_ops.reverse(array_ops.concat(v, axis=0), axis=[0]))
    self.assertAllEqual(sv[2::-1], [[5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[2:-8:-1], [[5, 6], [3, 4]])
    self.assertAllEqual(sv[2:-10:-1], [[5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[4::-1], [[9, 10], [7, 8], [5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[-1:-3:-1], [[15, 16], [13, 14]])
    self.assertAllEqual(sv[::-5], [[15, 16], [5, 6]])
    self.assertAllEqual(sv[6::-6], [[13, 14], [1, 2]])
    self.assertAllEqual(sv[6:5:-6], [[13, 14]])
    self.assertAllEqual(sv[6::-7], [[13, 14]])
    self.assertAllEqual(sv[7:1:-1],
                        [[15, 16], [13, 14], [11, 12], [9, 10], [7, 8], [5, 6]])
    self.assertAllEqual(sv[7:1:-2], [[15, 16], [11, 12], [7, 8]])
    self.assertAllEqual(sv[7:1:-4], [[15, 16], [7, 8]])

    # Test cases: empty slice
    self.assertAllEqual(sv[0:0], empty)
    self.assertAllEqual(sv[5:3], empty)
    self.assertAllEqual(sv[3:5:-1], empty)
    self.assertAllEqual(sv[-1:0], empty)
    self.assertAllEqual(sv[2:-1:-1], empty)

    # Test cases: slicing other dimensions
    self.assertAllEqual(sv[:, 0], [1, 3, 5, 7, 9, 11, 13, 15])
    self.assertAllEqual(sv[:, 0:1], [[1], [3], [5], [7], [9], [11], [13], [15]])

    # Test cases: normal indexing
    self.assertAllEqual(sv[2], [5, 6])
    self.assertAllEqual(sv[6], [13, 14])
    self.assertAllEqual(sv[2, 1], 6)
    self.assertAllEqual(sv[-2], [13, 14])
    with self.assertRaisesRegex(IndexError, 'out of bounds'):
      _ = sv[100]
    with self.assertRaisesRegex(IndexError, 'out of bounds'):
      _ = sv[-100]

    # Test cases: Ellipsis
    self.assertAllEqual(sv[...], array_ops.concat(v, axis=0))
    self.assertAllEqual(sv[..., 0], [1, 3, 5, 7, 9, 11, 13, 15])
    self.assertAllEqual(sv[0:1, ...], [[1, 2]])

    # Test cases: newaxis
    self.assertAllEqual(
        sv[array_ops.newaxis, ...],
        array_ops.expand_dims_v2(array_ops.concat(v, axis=0), axis=0))

    # Test cases: boolean masks
    self.assertAllEqual(sv[ops.convert_to_tensor(sv) > 10],
                        [11, 12, 13, 14, 15, 16])

    # Test cases: tensor input
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[constant_op.constant(1)::]
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[:constant_op.constant(1):]
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[constant_op.constant(1)]

    # Test cases: inside tf.function
    @def_function.function
    def func():
      a = sv[:, 0]
      return a

    self.assertAllEqual(func(), [1, 3, 5, 7, 9, 11, 13, 15])


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
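
As a footnote to the boolean-mask test above, here is a minimal sketch of the equivalence it relies on (not part of the diff; the shard values are illustrative):

# The boolean-mask branch of __getitem__ materializes the full tensor and
# defers to boolean_mask, so sv[mask] matches masking the concatenation.
import tensorflow as tf
from tensorflow.python.distribute import sharded_variable

sv = sharded_variable.ShardedVariable(
    [tf.Variable([[1, 2], [3, 4]]), tf.Variable([[5, 6]])])
full = tf.convert_to_tensor(sv)  # concatenation of the shards
mask = full > 3                  # elementwise boolean mask

assert (sv[mask].numpy() == tf.boolean_mask(full, mask).numpy()).all()
print(sv[mask].numpy())          # [4 5 6]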