Optionally save MirroredVariable components.

Write MirroredVariable components to the newly introduced
`experimental_distributed_variable_components` protobuf field when the
EXPAND_DISTRIBUTED_VARIABLES SaveOption is set. This is currently not
supported by any loader.

PiperOrigin-RevId: 335670847
Change-Id: I1c38ae132e4b2cda52adafa819c1779488031f20
This commit is contained in:
Cesar Crusius 2020-10-06 10:36:01 -07:00 committed by TensorFlower Gardener
parent e665554b90
commit a4f4855c82
4 changed files with 58 additions and 29 deletions

View File

@ -1061,6 +1061,30 @@ class MirroredVariable(DistributedVariable, Mirrored):
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
def _write_object_proto(self, proto, options):
"""Update a SavedObject proto for this object.
If an object defines this method, it will be called when saving with a
pre-built `SavedObject` proto representing the object, plus an instance of
`SaveOptions`. This method is then free to modify that proto instance.
`MirroredVariables` optionally write out information about their components
to the `experimental_distributed_variable_components` field of a
`SavedVariable` (depending on the `SaveOptions` variable policy).
Args:
proto: A pre-built `SavedObject` proto for this object. It is assumed this
will be a `SavedVariable` instance.
options: A `SaveOptions` instance.
"""
if options.experimental_variable_policy._expand_distributed_variables( # pylint: disable=protected-access
):
for var in self.values:
var_proto = (
proto.variable.experimental_distributed_variable_components.add())
var_proto.name = var.name.split(":")[0]
var_proto.device = var.device
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
"""Converts a variable to a tensor."""
# TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ

View File

@ -786,6 +786,16 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
# pylint:enable=protected-access
proto.user_object.CopyFrom(registered_type_proto)
# Give the object a chance to modify the SavedObject proto.
# This is currently used by MirroredVariables to optionally write their
# component variables to the proto.
#
# This is not yet an official Trackable method, the only current use case
# being MirroredVariables. See the method implementation there for more
# documentation.
if hasattr(obj, "_write_object_proto"):
obj._write_object_proto(proto, options) # pylint: disable=protected-access
def _export_debug_info(exported_graph):
"""Exports debug information from a graph.
@ -991,10 +1001,6 @@ def save(obj, export_dir, signatures=None, options=None):
@end_compatibility
"""
options = options or save_options.SaveOptions()
if options.experimental_variable_policy._expand_distributed_variables(): # pylint:disable=protected-access
raise NotImplementedError(
"The VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES option is "
"not implemented in saved_model.save.")
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
# compatible (no sessions) and share it with this export API rather than
# making a SavedModel proto and writing it directly.

View File

@ -56,13 +56,11 @@ class VariablePolicy(enum.Enum):
Distributed variables are still saved as one variable under this policy.
EXPAND_DISTRIBUTED_VARIABLES
Distributed variables will be explicitly expanded into their respective
distributed replicas, and their assigned devices will be saved. This is
useful when one wants to use the model for training in environments where
the original distribution strategy is not available. Checkpoints are
currently incompatible with this option, so it is not implemented in
`saved_model.save` (only the internal `saved_model.export_meta_graph` API
supports it for now).
Distributed variables will be saved with information about their components,
allowing for their restoration on load. Also, the saved graph will contain
references to those variables. This is useful when one wants to use the
model for training in environments where the original distribution strategy
is not available.
"""
NONE = None

View File

@ -556,6 +556,7 @@ class SaveTest(test.TestCase, parameterized.TestCase):
save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES),
("_DiscardDistributedVariables", save_options.VariablePolicy.NONE))
def test_expand_distributed_variables(self, expand_strategy):
# 1. Create a context with both CPU:0 and CPU:1.
context._reset_context()
cpus = context.context().list_physical_devices("CPU")
if len(cpus) == 1:
@ -566,6 +567,7 @@ class SaveTest(test.TestCase, parameterized.TestCase):
])
context.ensure_initialized()
# 2. Create and save a model under a mirrored strategy.
file_name = os.path.join(self.get_temp_dir(), "saved_model.pb")
with mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]).scope():
root = tracking.AutoTrackable()
@ -582,36 +584,35 @@ class SaveTest(test.TestCase, parameterized.TestCase):
filename=file_name,
options=save_options.SaveOptions(
experimental_variable_policy=expand_strategy))
graph_def = meta_graph.read_meta_graph_file(file_name).graph_def
v0 = next((n for n in graph_def.node if n.name == "v"), None)
v1 = next((n for n in graph_def.node if n.name == "v/replica_1"), None)
self.assertIsNotNone(v0)
# 3. Read the output file and test behavior.
meta_graph_def = meta_graph.read_meta_graph_file(file_name)
object_graph = meta_graph_def.object_graph_def
graph_def = meta_graph_def.graph_def
v = next((n.variable
for n in object_graph.nodes
if n.HasField("variable") and n.variable.name == "v"), None)
saved_function = next((f for f in graph_def.library.function
if "inference_f_" in f.signature.name), None)
self.assertIsNotNone(saved_function)
if (expand_strategy ==
save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES):
self.assertIsNotNone(v1)
# experimental_save_variable_devices should have been automatically set.
self.assertIn("CPU:0", v.device)
components = v.experimental_distributed_variable_components
self.assertLen(components, 2)
v0 = next((x for x in components if x.name == "v"), None)
v1 = next((x for x in components if x.name == "v/replica_1"), None)
self.assertIsNotNone(v0)
self.assertIsNotNone(v1)
self.assertIn("CPU:0", v0.device)
self.assertIn("CPU:1", v1.device)
self.assertLen(saved_function.signature.input_arg, 2)
else:
self.assertIsNone(v1)
self.assertEmpty(v0.device)
self.assertEmpty(v.device)
self.assertEmpty(v.experimental_distributed_variable_components)
self.assertLen(saved_function.signature.input_arg, 1)
def test_expand_distributed_variables_not_allowed(self):
root = tracking.AutoTrackable()
with self.assertRaisesRegex(NotImplementedError,
"not implemented in saved_model.save"):
save.save(
obj=root,
export_dir="",
options=save_options.SaveOptions(
experimental_variable_policy=save_options.VariablePolicy
.EXPAND_DISTRIBUTED_VARIABLES))
def test_save_uninitialized_variable(self):
root = tracking.AutoTrackable()
root.uninitialized_variable = resource_variable_ops.UninitializedVariable(