Add ShardedVariable class.
PiperOrigin-RevId: 272745815
This commit is contained in:
parent
84f9d53683
commit
89e33e5ef3
tensorflow/python
@ -1,6 +1,5 @@
|
||||
load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
|
||||
load("//tensorflow/core/platform:default/distribute.bzl", "distribute_py_test")
|
||||
|
||||
package(
|
||||
@ -132,6 +131,7 @@ py_library(
|
||||
":distribute_lib",
|
||||
":mirrored_strategy",
|
||||
":one_device_strategy",
|
||||
":sharded_variable",
|
||||
"//tensorflow/python/distribute/experimental",
|
||||
],
|
||||
)
|
||||
@ -778,6 +778,32 @@ cuda_py_test(
|
||||
grpc_enabled = True,
|
||||
)
|
||||
|
||||
# Library target for sharded_variable.py; deps mirror the modules that file imports.
py_library(
|
||||
name = "sharded_variable",
|
||||
srcs = ["sharded_variable.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/training/saving:saveable_object_util",
|
||||
"//tensorflow/python/training/tracking:base",
|
||||
],
|
||||
)
|
||||
|
||||
# Small test target for sharded_variable_test.py; depends on :sharded_variable plus the modules the test imports.
tf_py_test(
|
||||
name = "sharded_variable_test",
|
||||
size = "small",
|
||||
srcs = ["sharded_variable_test.py"],
|
||||
additional_deps = [
|
||||
":sharded_variable",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/training/tracking:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "strategy_test_lib",
|
||||
srcs = ["strategy_test_lib.py"],
|
||||
|
139
tensorflow/python/distribute/sharded_variable.py
Normal file
139
tensorflow/python/distribute/sharded_variable.py
Normal file
@ -0,0 +1,139 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ShardedVariable class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
|
||||
|
||||
class ShardedVariable(trackable.Trackable):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Sharding is only supported along the first dimension.
  """

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariable(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A non-empty list of `ResourceVariable`s that comprise this
        sharded variable. Variables should not be shared between different
        `ShardedVariable` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".

    Raises:
      ValueError: If `variables` is empty, contains an element that is not a
        `Variable`, mixes dtypes, contains shapes that differ on any axis
        other than the first, or contains a variable that already has
        `SaveSliceInfo` set.
    """
    super(ShardedVariable, self).__init__()

    # Reject an empty list up front: everything below reads `variables[0]`,
    # which would otherwise surface as an uninformative IndexError.
    if not variables:
      raise ValueError(
          'Expected a non-empty list of `Variable`s, found: {}'.format(
              variables))

    if any(not isinstance(v, variables_lib.Variable) for v in variables):
      raise ValueError(
          'Expected a list of `Variable`s, found: {}'.format(variables))

    self._variables = variables
    self._name = name

    first_var = variables[0]

    dtypes = {v.dtype for v in variables}
    if len(dtypes) > 1:
      raise ValueError(
          'All `Variable`s must have the same dtype, found: {}'.format(
              [v.dtype for v in variables]))
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All `Variables`s must have the same shapes except for the first '
          'axis, found {}'.format([v.shape for v in variables]))

    # Work on plain Python lists so the concatenation below does not depend on
    # `list + TensorShape` being supported by `TensorShape`.
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape(
        [first_dim] + first_var.shape.as_list()[1:])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError('`SaveSliceInfo` should not be set for `Variable`s. '
                       '`ShardedVariable` will infer `SaveSliceInfo` according '
                       'to the order of the `Variable`s in the list passed to '
                       'the constructor. Found {}'.format(save_slice_info))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`.

    Returns:
      A dict mapping `trackable.VARIABLE_VALUE_KEY` to a factory that builds
      one `ResourceVariableSaveable` per shard.
    """

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`.

      Args:
        name: Name passed to each `ResourceVariableSaveable`. Defaults to the
          name of this `ShardedVariable`.

      Returns:
        A list with one `ResourceVariableSaveable` per shard, in shard order.
      """
      saveables = []
      dims = len(self._variables[0].shape)
      # Offset of the current shard inside the full variable. Only axis 0
      # ever advances because sharding is along the first dimension.
      var_offset = [0] * dims
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            # Copy so later increments do not alter this shard's offset.
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
|
146
tensorflow/python/distribute/sharded_variable_test.py
Normal file
146
tensorflow/python/distribute/sharded_variable_test.py
Normal file
@ -0,0 +1,146 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for ShardedVariable."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.distribute import sharded_variable
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.tracking import util
|
||||
|
||||
|
||||
class ShardedVariableTest(test.TestCase):
  """Tests construction, validation and checkpointing of `ShardedVariable`."""

  def test_sharded_variable_simple(self):
    """Two length-1 shards expose their parts, combined shape, dtype, name."""
    first = variables_lib.Variable([0])
    second = variables_lib.Variable([1])
    sharded = sharded_variable.ShardedVariable([first, second], name='s')
    self.assertEqual(sharded.variables[0], first)
    self.assertEqual(sharded.variables[1], second)
    self.assertEqual(sharded.shape.as_list(), [2])
    self.assertEqual(sharded.dtype, first.dtype)
    self.assertEqual(sharded.name, 's')

  def test_save_restore(self):
    """A value overwritten after `write` comes back after `restore`."""
    ckpt_path = os.path.join(self.get_temp_dir(), 'checkpoint')
    shards = [variables_lib.Variable([i]) for i in range(4)]
    sharded = sharded_variable.ShardedVariable(shards, name='s')

    checkpoint = util.Checkpoint(s=sharded)
    self.assertEqual(self.evaluate(checkpoint.s.variables[0]), [0])
    checkpoint.write(ckpt_path)

    self.evaluate(checkpoint.s.variables[0].assign([4]))
    self.assertEqual(self.evaluate(checkpoint.s.variables[0]), [4])

    checkpoint.restore(ckpt_path)
    # Tests that the original weights are restored.
    self.assertEqual(self.evaluate(checkpoint.s.variables[0]), [0])

  def test_save_restore_different_partitions(self):
    """Checkpoints round-trip between a 4-shard and a 1-shard layout."""
    ckpt_path = os.path.join(self.get_temp_dir(), 'checkpoint')
    four_shards = [variables_lib.Variable([i]) for i in range(4)]
    sharded_by_four = sharded_variable.ShardedVariable(four_shards, name='s')

    checkpoint_four = util.Checkpoint(s=sharded_by_four)
    checkpoint_four.write(ckpt_path)

    single_shard = [variables_lib.Variable([0, 0, 0, 0])]
    sharded_by_one = sharded_variable.ShardedVariable(single_shard, name='s')

    # Restore from 4 partitions into 1.
    checkpoint_one = util.Checkpoint(s=sharded_by_one)
    checkpoint_one.restore(ckpt_path)
    self.assertAllEqual(
        self.evaluate(checkpoint_one.s.variables[0]), [0, 1, 2, 3])

    self.evaluate(checkpoint_one.s.variables[0].assign([5, 10, 15, 20]))
    checkpoint_one.write(ckpt_path)

    # Restore 1 partition into 4.
    checkpoint_four.restore(ckpt_path)
    for shard, expected in zip(checkpoint_four.s.variables, [5, 10, 15, 20]):
      self.assertEqual(self.evaluate(shard), [expected])

  def test_save_restore_4_to_2_partitions(self):
    """A 4-shard checkpoint restores into two shards of length 2."""
    ckpt_path = os.path.join(self.get_temp_dir(), 'checkpoint')
    four_shards = [variables_lib.Variable([i]) for i in range(4)]
    sharded_by_four = sharded_variable.ShardedVariable(four_shards, name='s')
    checkpoint_four = util.Checkpoint(s=sharded_by_four)
    checkpoint_four.write(ckpt_path)

    two_shards = [variables_lib.Variable([0, 0]) for _ in range(2)]
    sharded_by_two = sharded_variable.ShardedVariable(two_shards, name='s')
    checkpoint_two = util.Checkpoint(s=sharded_by_two)
    checkpoint_two.restore(ckpt_path)
    # Assert that weights from the 4 partitions were loaded here.
    self.assertLen(checkpoint_two.s.variables, 2)
    self.assertAllEqual(self.evaluate(checkpoint_two.s.variables[0]), [0, 1])
    self.assertAllEqual(self.evaluate(checkpoint_two.s.variables[1]), [2, 3])

  def test_validation_errors(self):
    """Each constructor precondition raises `ValueError` with its message."""
    # An element that is not a `Variable`.
    with self.assertRaisesRegexp(ValueError, 'Expected a list of '):
      sharded_variable.ShardedVariable(
          [variables_lib.Variable([0]), 'not-a-variable'])

    # Shards with different dtypes.
    with self.assertRaisesRegexp(ValueError, 'must have the same dtype'):
      mixed_dtypes = [
          variables_lib.Variable([0], dtype='int64'),
          variables_lib.Variable([1], dtype='int32'),
      ]
      sharded_variable.ShardedVariable(mixed_dtypes)

    # Shards whose shapes differ on an axis other than the first.
    with self.assertRaisesRegexp(ValueError, 'the same shapes except'):
      mismatched_shapes = [
          variables_lib.Variable(array_ops.ones((5, 10))),
          variables_lib.Variable(array_ops.ones((5, 20))),
      ]
      sharded_variable.ShardedVariable(mismatched_shapes)

    # A shard that already has `SaveSliceInfo` set.
    with self.assertRaisesRegexp(ValueError, '`SaveSliceInfo` should not'):
      presliced = variables_lib.Variable([0])
      slice_info = variables_lib.Variable.SaveSliceInfo(
          full_name='s', full_shape=[2], var_offset=[0], var_shape=[1])
      presliced._set_save_slice_info(slice_info)
      sharded_variable.ShardedVariable([presliced])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
v2_compat.enable_v2_behavior()
|
||||
test.main()
|
@ -28,6 +28,7 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
|
||||
|
||||
@ -147,6 +148,9 @@ def saveable_objects_for_op(op, name):
|
||||
slice_name = None
|
||||
# pylint: disable=protected-access
|
||||
for variable in op:
|
||||
if isinstance(variable, saveable_object.SaveableObject):
|
||||
yield variable
|
||||
continue
|
||||
if not isinstance(variable, variables.Variable):
|
||||
raise ValueError("Slices must all be Variables: %s" % variable)
|
||||
if not variable._save_slice_info:
|
||||
@ -210,7 +214,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
||||
"""Create a dictionary of names to operation lists.
|
||||
|
||||
Args:
|
||||
op_list: A list, tuple, or set of Variables or SaveableObjects.
|
||||
op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
|
||||
convert_variable_to_tensor: Whether or not to convert single Variables
|
||||
with no slice info into Tensors.
|
||||
|
||||
@ -226,6 +230,8 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
||||
if not isinstance(op_list, (list, tuple, set)):
|
||||
raise TypeError("Variables to save should be passed in a dict or a "
|
||||
"list: %s" % op_list)
|
||||
# List casting is necessary to support sets.
|
||||
op_list = nest.flatten(list(op_list))
|
||||
# When ResourceVariables are converted to Tensors, read ops are added to the
|
||||
# graph. Sorting the op_list ensures that the resulting graph is always
|
||||
# constructed in a deterministic way:
|
||||
|
Loading…
Reference in New Issue
Block a user