From a4f4855c82e135d04d11f7a234d5a26e9f1ab2bf Mon Sep 17 00:00:00 2001 From: Cesar Crusius Date: Tue, 6 Oct 2020 10:36:01 -0700 Subject: [PATCH] Optionally save MirroredVariable components. Write MirroredVariable components to the newly introduced `experimental_distributed_variable_components` protobuf field when the EXPAND_DISTRIBUTED_VARIABLES SaveOption is set. This is currently not supported by any loader. PiperOrigin-RevId: 335670847 Change-Id: I1c38ae132e4b2cda52adafa819c1779488031f20 --- tensorflow/python/distribute/values.py | 24 ++++++++++++ tensorflow/python/saved_model/save.py | 14 +++++-- tensorflow/python/saved_model/save_options.py | 12 +++--- tensorflow/python/saved_model/save_test.py | 37 ++++++++++--------- 4 files changed, 58 insertions(+), 29 deletions(-) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 36aaaa6c98c..61fb933cdbf 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -1061,6 +1061,30 @@ class MirroredVariable(DistributedVariable, Mirrored): return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} + def _write_object_proto(self, proto, options): + """Update a SavedObject proto for this object. + + If an object defines this method, it will be called when saving with a + pre-built `SavedObject` proto representing the object, plus an instance of + `SaveOptions`. This method is then free to modify that proto instance. + + `MirroredVariables` optionally write out information about their components + to the `experimental_distributed_variable_components` field of a + `SavedVariable` (depending on the `SaveOptions` variable policy). + + Args: + proto: A pre-built `SavedObject` proto for this object. It is assumed this + will be a `SavedVariable` instance. + options: A `SaveOptions` instance. + """ + if options.experimental_variable_policy._expand_distributed_variables( # pylint: disable=protected-access + ): + for var in self.values: + var_proto = ( + proto.variable.experimental_distributed_variable_components.add()) + var_proto.name = var.name.split(":")[0] + var_proto.device = var.device + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): """Converts a variable to a tensor.""" # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 361883adc22..29969056a4f 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -786,6 +786,16 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): # pylint:enable=protected-access proto.user_object.CopyFrom(registered_type_proto) + # Give the object a chance to modify the SavedObject proto. + # This is currently used by MirroredVariables to optionally write their + # component variables to the proto. + # + # This is not yet an official Trackable method, the only current use case + # being MirroredVariables. See the method implementation there for more + # documentation. + if hasattr(obj, "_write_object_proto"): + obj._write_object_proto(proto, options) # pylint: disable=protected-access + def _export_debug_info(exported_graph): """Exports debug information from a graph. @@ -991,10 +1001,6 @@ def save(obj, export_dir, signatures=None, options=None): @end_compatibility """ options = options or save_options.SaveOptions() - if options.experimental_variable_policy._expand_distributed_variables(): # pylint:disable=protected-access - raise NotImplementedError( - "The VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES option is " - "not implemented in saved_model.save.") # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x # compatible (no sessions) and share it with this export API rather than # making a SavedModel proto and writing it directly. diff --git a/tensorflow/python/saved_model/save_options.py b/tensorflow/python/saved_model/save_options.py index ae4421bf022..f6330848441 100644 --- a/tensorflow/python/saved_model/save_options.py +++ b/tensorflow/python/saved_model/save_options.py @@ -56,13 +56,11 @@ class VariablePolicy(enum.Enum): Distributed variables are still saved as one variable under this policy. EXPAND_DISTRIBUTED_VARIABLES - Distributed variables will be explicitly expanded into their respective - distributed replicas, and their assigned devices will be saved. This is - useful when one wants to use the model for training in environments where - the original distribution strategy is not available. Checkpoints are - currently incompatible with this option, so it is not implemented in - `saved_model.save` (only the internal `saved_model.export_meta_graph` API - supports it for now). + Distributed variables will be saved with information about their components, + allowing for their restoration on load. Also, the saved graph will contain + references to those variables. This is useful when one wants to use the + model for training in environments where the original distribution strategy + is not available. """ NONE = None diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index f3d78881429..522720096af 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -556,6 +556,7 @@ class SaveTest(test.TestCase, parameterized.TestCase): save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES), ("_DiscardDistributedVariables", save_options.VariablePolicy.NONE)) def test_expand_distributed_variables(self, expand_strategy): + # 1. Create a context with both CPU:0 and CPU:1. context._reset_context() cpus = context.context().list_physical_devices("CPU") if len(cpus) == 1: @@ -566,6 +567,7 @@ class SaveTest(test.TestCase, parameterized.TestCase): ]) context.ensure_initialized() + # 2. Create and save a model under a mirrored strategy. file_name = os.path.join(self.get_temp_dir(), "saved_model.pb") with mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]).scope(): root = tracking.AutoTrackable() @@ -582,36 +584,35 @@ class SaveTest(test.TestCase, parameterized.TestCase): filename=file_name, options=save_options.SaveOptions( experimental_variable_policy=expand_strategy)) - graph_def = meta_graph.read_meta_graph_file(file_name).graph_def - v0 = next((n for n in graph_def.node if n.name == "v"), None) - v1 = next((n for n in graph_def.node if n.name == "v/replica_1"), None) - self.assertIsNotNone(v0) + + # 3. Read the output file and test behavior. + meta_graph_def = meta_graph.read_meta_graph_file(file_name) + object_graph = meta_graph_def.object_graph_def + graph_def = meta_graph_def.graph_def + v = next((n.variable + for n in object_graph.nodes + if n.HasField("variable") and n.variable.name == "v"), None) saved_function = next((f for f in graph_def.library.function if "inference_f_" in f.signature.name), None) self.assertIsNotNone(saved_function) if (expand_strategy == save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES): - self.assertIsNotNone(v1) # experimental_save_variable_devices should have been automatically set. + self.assertIn("CPU:0", v.device) + components = v.experimental_distributed_variable_components + self.assertLen(components, 2) + v0 = next((x for x in components if x.name == "v"), None) + v1 = next((x for x in components if x.name == "v/replica_1"), None) + self.assertIsNotNone(v0) + self.assertIsNotNone(v1) self.assertIn("CPU:0", v0.device) self.assertIn("CPU:1", v1.device) self.assertLen(saved_function.signature.input_arg, 2) else: - self.assertIsNone(v1) - self.assertEmpty(v0.device) + self.assertEmpty(v.device) + self.assertEmpty(v.experimental_distributed_variable_components) self.assertLen(saved_function.signature.input_arg, 1) - def test_expand_distributed_variables_not_allowed(self): - root = tracking.AutoTrackable() - with self.assertRaisesRegex(NotImplementedError, - "not implemented in saved_model.save"): - save.save( - obj=root, - export_dir="", - options=save_options.SaveOptions( - experimental_variable_policy=save_options.VariablePolicy - .EXPAND_DISTRIBUTED_VARIABLES)) - def test_save_uninitialized_variable(self): root = tracking.AutoTrackable() root.uninitialized_variable = resource_variable_ops.UninitializedVariable(