From 89e33e5ef3c3a14197fef6b27c45f9ff4a4d500c Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Thu, 3 Oct 2019 14:45:29 -0700 Subject: [PATCH] Add `ShardedVariable` class. PiperOrigin-RevId: 272745815 --- tensorflow/python/distribute/BUILD | 28 +++- .../python/distribute/sharded_variable.py | 139 +++++++++++++++++ .../distribute/sharded_variable_test.py | 146 ++++++++++++++++++ .../training/saving/saveable_object_util.py | 8 +- 4 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 tensorflow/python/distribute/sharded_variable.py create mode 100644 tensorflow/python/distribute/sharded_variable_test.py diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 185495b5c5d..60c6ae84956 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1,6 +1,5 @@ load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") load("//tensorflow/core/platform:default/distribute.bzl", "distribute_py_test") package( @@ -132,6 +131,7 @@ py_library( ":distribute_lib", ":mirrored_strategy", ":one_device_strategy", + ":sharded_variable", "//tensorflow/python/distribute/experimental", ], ) @@ -778,6 +778,32 @@ cuda_py_test( grpc_enabled = True, ) +py_library( + name = "sharded_variable", + srcs = ["sharded_variable.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variables", + "//tensorflow/python/training/saving:saveable_object_util", + "//tensorflow/python/training/tracking:base", + ], +) + +tf_py_test( + name = "sharded_variable_test", + size = "small", + srcs = ["sharded_variable_test.py"], + additional_deps = [ + ":sharded_variable", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:variables", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/training/tracking:util", + ], +) + py_library( name = "strategy_test_lib", srcs = ["strategy_test_lib.py"], diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py new file mode 100644 index 00000000000..9886e42a8b3 --- /dev/null +++ b/tensorflow/python/distribute/sharded_variable.py @@ -0,0 +1,139 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ShardedVariable class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.training.saving import saveable_object_util +from tensorflow.python.training.tracking import base as trackable + + +class ShardedVariable(trackable.Trackable): + """A container for `Variables` that should be treated as shards. + + Variables that are too large to fit on a single device (e.g., large + embeddings) + may need to be sharded over multiple devices. This class maintains a list of + smaller variables that can be independently stored on separate devices (eg, + multiple parameter servers), and saves and restores those variables as if they + were a single larger variable. + + Objects of this class can be saved with a given number of shards and then + restored from a checkpoint into a different number of shards. + + Sharding is only supported along the first dimension. + """ + + def __init__(self, variables, name='ShardedVariable'): + """Treats `variables` as shards of a larger Variable. + + + Example: + + ``` + variables = [ + tf.Variable(..., shape=(10, 100), dtype=tf.float32), + tf.Variable(..., shape=(15, 100), dtype=tf.float32), + tf.Variable(..., shape=(5, 100), dtype=tf.float32) + ] + sharded_variable = ShardedVariable(variables) + assert sharded_variable.shape.as_list() == [30, 100] + ``` + + Args: + variables: A list of `ResourceVariable`s that comprise this sharded + variable. Variables should not be shared between different + `ShardedVariable` objects. + name: String. Name of this container. Defaults to "ShardedVariable". + """ + super(ShardedVariable, self).__init__() + self._variables = variables + self._name = name + + first_var = variables[0] + + if any(not isinstance(v, variables_lib.Variable) for v in variables): + raise ValueError( + 'Expected a list of `Variable`s, found: {}'.format(variables)) + + dtypes = {v.dtype for v in variables} + if len(dtypes) > 1: + raise ValueError( + 'All `Variable`s must have the same dtype, found: {}'.format( + [v.dtype for v in variables])) + self._dtype = first_var.dtype + + # All variables must have the same shape for axes > 0. + higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables} + if len(higher_dim_shapes) > 1: + raise ValueError( + 'All `Variables`s must have the same shapes except for the first ' + 'axis, found {}'.format([v.shape for v in variables])) + first_dim = sum(int(v.shape[0]) for v in variables) + self._shape = tensor_shape.TensorShape([first_dim] + first_var.shape[1:]) + + save_slice_info = [v._get_save_slice_info() for v in variables] # pylint: disable=protected-access + if any(slice_info is not None for slice_info in save_slice_info): + raise ValueError('`SaveSliceInfo` should not be set for `Variable`s. ' + '`ShardedVariable` will infer `SaveSliceInfo` according ' + 'to the order of the `Variable`s in the list passed to ' + 'the constructor. Found {}'.format(save_slice_info)) + + @property + def variables(self): + """The list of `Variable`s that make up the shards of this object.""" + return self._variables + + @property + def name(self): + """The name of this object. Used for checkpointing.""" + return self._name + + @property + def dtype(self): + """The dtype of all `Variable`s in this object.""" + return self._dtype + + @property + def shape(self): + """The overall shape, combining all shards along axis `0`.""" + return self._shape + + def _gather_saveables_for_checkpoint(self): + """Return a `Saveable` for each shard. See `Trackable`.""" + + def _saveable_factory(name=self.name): + """Creates `SaveableObject`s for this `ShardedVariable`.""" + saveables = [] + dims = len(self._variables[0].shape) + var_offset = [0 for _ in range(dims)] + for v in self._variables: + save_slice_info = variables_lib.Variable.SaveSliceInfo( + full_name=self.name, + full_shape=self.shape.as_list(), + var_offset=copy.copy(var_offset), + var_shape=v.shape.as_list()) + saveables.append( + saveable_object_util.ResourceVariableSaveable( + v, save_slice_info.spec, name)) # pylint: disable=protected-access + var_offset[0] += int(v.shape[0]) + return saveables + + return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} diff --git a/tensorflow/python/distribute/sharded_variable_test.py b/tensorflow/python/distribute/sharded_variable_test.py new file mode 100644 index 00000000000..7110a9ff1fe --- /dev/null +++ b/tensorflow/python/distribute/sharded_variable_test.py @@ -0,0 +1,146 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ShardedVariable.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.compat import v2_compat +from tensorflow.python.distribute import sharded_variable +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import test +from tensorflow.python.training.tracking import util + + +class ShardedVariableTest(test.TestCase): + + def test_sharded_variable_simple(self): + v0 = variables_lib.Variable([0]) + v1 = variables_lib.Variable([1]) + s = sharded_variable.ShardedVariable([v0, v1], name='s') + self.assertEqual(s.variables[0], v0) + self.assertEqual(s.variables[1], v1) + self.assertEqual(s.shape.as_list(), [2]) + self.assertEqual(s.dtype, v0.dtype) + self.assertEqual(s.name, 's') + + def test_save_restore(self): + fname = os.path.join(self.get_temp_dir(), 'checkpoint') + variables = [ + variables_lib.Variable([0]), + variables_lib.Variable([1]), + variables_lib.Variable([2]), + variables_lib.Variable([3]) + ] + s = sharded_variable.ShardedVariable(variables, name='s') + + cp = util.Checkpoint(s=s) + self.assertEqual(self.evaluate(cp.s.variables[0]), [0]) + cp.write(fname) + + self.evaluate(cp.s.variables[0].assign([4])) + self.assertEqual(self.evaluate(cp.s.variables[0]), [4]) + + cp.restore(fname) + # Tests that the original weights are restored. + self.assertEqual(self.evaluate(cp.s.variables[0]), [0]) + + def test_save_restore_different_partitions(self): + fname = os.path.join(self.get_temp_dir(), 'checkpoint') + variables = [ + variables_lib.Variable([0]), + variables_lib.Variable([1]), + variables_lib.Variable([2]), + variables_lib.Variable([3]) + ] + s = sharded_variable.ShardedVariable(variables, name='s') + + cp = util.Checkpoint(s=s) + cp.write(fname) + + variables2 = [variables_lib.Variable([0, 0, 0, 0])] + s2 = sharded_variable.ShardedVariable(variables2, name='s') + + # Restore from 4 partitions into 1. + cp2 = util.Checkpoint(s=s2) + cp2.restore(fname) + self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1, 2, 3]) + + self.evaluate(cp2.s.variables[0].assign([5, 10, 15, 20])) + cp2.write(fname) + + # Restore 1 partition into 4. + cp.restore(fname) + self.assertEqual(self.evaluate(cp.s.variables[0]), [5]) + self.assertEqual(self.evaluate(cp.s.variables[1]), [10]) + self.assertEqual(self.evaluate(cp.s.variables[2]), [15]) + self.assertEqual(self.evaluate(cp.s.variables[3]), [20]) + + def test_save_restore_4_to_2_partitions(self): + fname = os.path.join(self.get_temp_dir(), 'checkpoint') + variables = [ + variables_lib.Variable([0]), + variables_lib.Variable([1]), + variables_lib.Variable([2]), + variables_lib.Variable([3]) + ] + s = sharded_variable.ShardedVariable(variables, name='s') + cp = util.Checkpoint(s=s) + cp.write(fname) + + variables2 = [ + variables_lib.Variable([0, 0]), + variables_lib.Variable([0, 0]) + ] + s2 = sharded_variable.ShardedVariable(variables2, name='s') + cp2 = util.Checkpoint(s=s2) + cp2.restore(fname) + # Assert that weights from the 4 partitions were loaded here. + self.assertLen(cp2.s.variables, 2) + self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1]) + self.assertAllEqual(self.evaluate(cp2.s.variables[1]), [2, 3]) + + def test_validation_errors(self): + with self.assertRaisesRegexp(ValueError, 'Expected a list of '): + sharded_variable.ShardedVariable( + [variables_lib.Variable([0]), 'not-a-variable']) + + with self.assertRaisesRegexp(ValueError, 'must have the same dtype'): + sharded_variable.ShardedVariable([ + variables_lib.Variable([0], dtype='int64'), + variables_lib.Variable([1], dtype='int32') + ]) + + with self.assertRaisesRegexp(ValueError, 'the same shapes except'): + sharded_variable.ShardedVariable([ + variables_lib.Variable(array_ops.ones((5, 10))), + variables_lib.Variable(array_ops.ones((5, 20))) + ]) + + with self.assertRaisesRegexp(ValueError, '`SaveSliceInfo` should not'): + v = variables_lib.Variable([0]) + v._set_save_slice_info( + variables_lib.Variable.SaveSliceInfo( + full_name='s', full_shape=[2], var_offset=[0], var_shape=[1])) + sharded_variable.ShardedVariable([v]) + + +if __name__ == '__main__': + v2_compat.enable_v2_behavior() + test.main() diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py index 099fcf0548d..f4c5ee7f3c2 100644 --- a/tensorflow/python/training/saving/saveable_object_util.py +++ b/tensorflow/python/training/saving/saveable_object_util.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.util import nest from tensorflow.python.util import object_identity @@ -147,6 +148,9 @@ def saveable_objects_for_op(op, name): slice_name = None # pylint: disable=protected-access for variable in op: + if isinstance(variable, saveable_object.SaveableObject): + yield variable + continue if not isinstance(variable, variables.Variable): raise ValueError("Slices must all be Variables: %s" % variable) if not variable._save_slice_info: @@ -210,7 +214,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True): """Create a dictionary of names to operation lists. Args: - op_list: A list, tuple, or set of Variables or SaveableObjects. + op_list: A (nested) list, tuple, or set of Variables or SaveableObjects. convert_variable_to_tensor: Whether or not to convert single Variables with no slice info into Tensors. @@ -226,6 +230,8 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True): if not isinstance(op_list, (list, tuple, set)): raise TypeError("Variables to save should be passed in a dict or a " "list: %s" % op_list) + # List casting is necessary to support sets. + op_list = nest.flatten(list(op_list)) # When ResourceVariables are converted to Tensors, read ops are added to the # graph. Sorting the op_list ensures that the resulting graph is always # constructed in a deterministic way: