Support slicing in ShardedVariable. The slicing semantics are identical to those of Tensor/Variable.
PiperOrigin-RevId: 355249212 Change-Id: Ic9a14b5ae5cc0a446142eaa529f052c09c445396
This commit is contained in:
parent c21185ae1c
commit b4bf78ffec
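
As a quick illustration of the new behavior, here is a minimal sketch based on the test cases in this diff. The module path tensorflow.python.distribute.sharded_variable is an assumption (the diff only shows file contents, not locations); the shard values are taken from test_slicing below.

import tensorflow as tf
from tensorflow.python.distribute import sharded_variable  # assumed path

v = [
    tf.Variable([[1, 2], [3, 4], [5, 6]]),
    tf.Variable([[7, 8], [9, 10], [11, 12]]),
    tf.Variable([[13, 14], [15, 16]]),
]
sv = sharded_variable.ShardedVariable(v)

# Slicing and indexing read like Tensor/Variable slicing and return a Tensor
# assembled from the relevant shards (expected values per the tests below).
print(sv[:2])      # [[1, 2], [3, 4]]
print(sv[::3])     # [[1, 2], [7, 8], [13, 14]]
print(sv[7:1:-2])  # [[15, 16], [11, 12], [7, 8]]
print(sv[:, 0])    # [1, 3, 5, 7, 9, 11, 13, 15]
print(sv[2, 1])    # 6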
@@ -1131,6 +1131,8 @@ py_library(
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:partitioned_variables",
@@ -1139,10 +1141,12 @@ py_library(
        "//tensorflow/python:type_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/saved_model:revived_types",
        "//tensorflow/python/saved_model:save_context",
        "//tensorflow/python/training/saving:saveable_object_util",
        "//tensorflow/python/training/tracking:base",
        "//tensorflow/python/util:tf_export",
        "//third_party/py/numpy",
    ],
)

@@ -1160,10 +1164,10 @@ tf_py_test(
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/eager:def_function",
@@ -18,8 +18,12 @@ from __future__ import division
from __future__ import print_function

import copy
import math
import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
@@ -282,8 +286,8 @@ class ShardedVariableMixin(trackable.Trackable):
      raise ValueError(
          'Expected a list of `Variable`s, found: {}'.format(variables))

-    dtypes = {v.dtype for v in variables}
-    if len(dtypes) > 1:
+    var_dtypes = {v.dtype for v in variables}
+    if len(var_dtypes) > 1:
      raise ValueError(
          'All `Variable`s must have the same dtype, found: {}'.format(
              [v.dtype for v in variables]))
@@ -322,6 +326,171 @@ class ShardedVariableMixin(trackable.Trackable):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing
        of the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bounds.
      TypeError: If `slice_spec` contains a Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError('slice index %d of dimension 0 out of bounds.' % s)
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s > self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first-dimension partitioning, thus
    `slice_spec` must be for the first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

      For example, given component variables:
        v0 = [0, 1, 2]
        v1 = [3, 4, 5]
        v2 = [6, 7, 8, 9]

      If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
        v0[returned[0]] = [0, 1, 2]
        v1[returned[1]] = [3, 4, 5]
        v2[returned[2]] = [6, 7, 8, 9]
      If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
        v0[returned[0]] = [2]
        v1[returned[1]] = [5]
        returned[2] == None
      If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
        returned[0] == None
        v1[returned[1]] = [5]
        v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, end and step.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local slice_spec is
    # whichever is smaller between the global slice end and the variable's
    # range end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
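
For reference, a plain-Python illustration of what _decompose_slice_spec computes, using the example from its docstring. The local slices were worked out by hand from the algorithm above; this is a standalone sketch, not part of the commit.

# Component shards from the docstring example.
v0, v1, v2 = [0, 1, 2], [3, 4, 5], [6, 7, 8, 9]
full = v0 + v1 + v2

# The global slice slice(2, 8, 3) decomposes into one local slice per shard
# (None means that shard contributes nothing).
local = [slice(2, 3, 3), slice(2, 3, 3), None]

# Applying the local slices shard by shard reproduces the global slice.
parts = [shard[s] for shard, s in zip([v0, v1, v2], local) if s is not None]
assert sum(parts, []) == full[2:8:3]  # both are [2, 5]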
@@ -449,6 +449,99 @@ class ShardedVariableTest(test.TestCase):
    self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
    self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]])

  def test_slicing(self):
    v = [
        variables_lib.Variable([[1, 2], [3, 4], [5, 6]]),
        variables_lib.Variable([[7, 8], [9, 10], [11, 12]]),
        variables_lib.Variable([[13, 14], [15, 16]])
    ]
    sv = sharded_variable.ShardedVariable(v)
    empty = v[0][0:0]

    # Test cases: positive step
    self.assertAllEqual(sv[:], array_ops.concat(v, axis=0))
    self.assertAllEqual(sv[:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[-8:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[-10:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[5:], [[11, 12], [13, 14], [15, 16]])
    self.assertAllEqual(sv[5:-1], [[11, 12], [13, 14]])
    self.assertAllEqual(sv[::3], [[1, 2], [7, 8], [13, 14]])
    self.assertAllEqual(sv[::5], [[1, 2], [11, 12]])
    self.assertAllEqual(sv[1::6], [[3, 4], [15, 16]])
    self.assertAllEqual(sv[1:5:6], [[3, 4]])
    self.assertAllEqual(sv[1::7], [[3, 4]])
    self.assertAllEqual(sv[2:7], [[5, 6], [7, 8], [9, 10], [11, 12], [13, 14]])
    self.assertAllEqual(sv[2:7:2], [[5, 6], [9, 10], [13, 14]])
    self.assertAllEqual(sv[2:7:3], [[5, 6], [11, 12]])

    # Test cases: negative step
    self.assertAllEqual(
        sv[::-1], array_ops.reverse(array_ops.concat(v, axis=0), axis=[0]))
    self.assertAllEqual(sv[2::-1], [[5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[2:-8:-1], [[5, 6], [3, 4]])
    self.assertAllEqual(sv[2:-10:-1], [[5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[4::-1], [[9, 10], [7, 8], [5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[-1:-3:-1], [[15, 16], [13, 14]])
    self.assertAllEqual(sv[::-5], [[15, 16], [5, 6]])
    self.assertAllEqual(sv[6::-6], [[13, 14], [1, 2]])
    self.assertAllEqual(sv[6:5:-6], [[13, 14]])
    self.assertAllEqual(sv[6::-7], [[13, 14]])
    self.assertAllEqual(sv[7:1:-1],
                        [[15, 16], [13, 14], [11, 12], [9, 10], [7, 8], [5, 6]])
    self.assertAllEqual(sv[7:1:-2], [[15, 16], [11, 12], [7, 8]])
    self.assertAllEqual(sv[7:1:-4], [[15, 16], [7, 8]])

    # Test cases: empty slice
    self.assertAllEqual(sv[0:0], empty)
    self.assertAllEqual(sv[5:3], empty)
    self.assertAllEqual(sv[3:5:-1], empty)
    self.assertAllEqual(sv[-1:0], empty)
    self.assertAllEqual(sv[2:-1:-1], empty)

    # Test cases: slicing other dimensions
    self.assertAllEqual(sv[:, 0], [1, 3, 5, 7, 9, 11, 13, 15])
    self.assertAllEqual(sv[:, 0:1], [[1], [3], [5], [7], [9], [11], [13], [15]])

    # Test cases: normal indexing
    self.assertAllEqual(sv[2], [5, 6])
    self.assertAllEqual(sv[6], [13, 14])
    self.assertAllEqual(sv[2, 1], 6)
    self.assertAllEqual(sv[-2], [13, 14])
    with self.assertRaisesRegex(IndexError, 'out of bounds'):
      _ = sv[100]
    with self.assertRaisesRegex(IndexError, 'out of bounds'):
      _ = sv[-100]

    # Test cases: Ellipsis
    self.assertAllEqual(sv[...], array_ops.concat(v, axis=0))
    self.assertAllEqual(sv[..., 0], [1, 3, 5, 7, 9, 11, 13, 15])
    self.assertAllEqual(sv[0:1, ...], [[1, 2]])

    # Test cases: newaxis
    self.assertAllEqual(
        sv[array_ops.newaxis, ...],
        array_ops.expand_dims_v2(array_ops.concat(v, axis=0), axis=0))

    # Test cases: boolean masks
    self.assertAllEqual(sv[ops.convert_to_tensor(sv) > 10],
                        [11, 12, 13, 14, 15, 16])

    # Test cases: tensor input
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[constant_op.constant(1)::]
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[:constant_op.constant(1):]
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[constant_op.constant(1)]

    # Test cases: inside tf.function
    @def_function.function
    def func():
      a = sv[:, 0]
      return a

    self.assertAllEqual(func(), [1, 3, 5, 7, 9, 11, 13, 15])


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()