Optionally save MirroredVariable components.
Write MirroredVariable components to the newly introduced `experimental_distributed_variable_components` protobuf field when the EXPAND_DISTRIBUTED_VARIABLES SaveOption is set. This is currently not supported by any loader.

PiperOrigin-RevId: 335670847
Change-Id: I1c38ae132e4b2cda52adafa819c1779488031f20
This commit is contained in:
parent
e665554b90
commit
a4f4855c82
@ -1061,6 +1061,30 @@ class MirroredVariable(DistributedVariable, Mirrored):
|
||||
|
||||
return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}
|
||||
|
||||
def _write_object_proto(self, proto, options):
  """Optionally record this variable's components on a `SavedObject` proto.

  Saving code calls this hook, on objects that define it, with the
  `SavedObject` proto it has already built for the object and with the
  `SaveOptions` in effect. The hook may then modify that proto in place.

  For a `MirroredVariable`, when the variable policy in `options` asks for
  distributed variables to be expanded, one entry per component variable is
  appended to the `experimental_distributed_variable_components` field of the
  proto's `SavedVariable`, holding that component's name and device. Under
  any other policy the proto is left untouched.

  Args:
    proto: A pre-built `SavedObject` proto for this object. It is assumed
      this will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  policy = options.experimental_variable_policy
  if not policy._expand_distributed_variables():  # pylint: disable=protected-access
    return

  for component in self.values:
    entry = proto.variable.experimental_distributed_variable_components.add()
    # Keep only the part of the name that precedes the first ":".
    entry.name = component.name.partition(":")[0]
    entry.device = component.device
|
||||
|
||||
def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
|
||||
"""Converts a variable to a tensor."""
|
||||
# TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
|
||||
|
@ -786,6 +786,16 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
||||
# pylint:enable=protected-access
|
||||
proto.user_object.CopyFrom(registered_type_proto)
|
||||
|
||||
# Give the object a chance to modify the SavedObject proto.
|
||||
# This is currently used by MirroredVariables to optionally write their
|
||||
# component variables to the proto.
|
||||
#
|
||||
# This is not yet an official Trackable method, the only current use case
|
||||
# being MirroredVariables. See the method implementation there for more
|
||||
# documentation.
|
||||
if hasattr(obj, "_write_object_proto"):
|
||||
obj._write_object_proto(proto, options) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _export_debug_info(exported_graph):
|
||||
"""Exports debug information from a graph.
|
||||
@ -991,10 +1001,6 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||
@end_compatibility
|
||||
"""
|
||||
options = options or save_options.SaveOptions()
|
||||
if options.experimental_variable_policy._expand_distributed_variables(): # pylint:disable=protected-access
|
||||
raise NotImplementedError(
|
||||
"The VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES option is "
|
||||
"not implemented in saved_model.save.")
|
||||
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
|
||||
# compatible (no sessions) and share it with this export API rather than
|
||||
# making a SavedModel proto and writing it directly.
|
||||
|
@ -56,13 +56,11 @@ class VariablePolicy(enum.Enum):
|
||||
Distributed variables are still saved as one variable under this policy.
|
||||
|
||||
EXPAND_DISTRIBUTED_VARIABLES
|
||||
Distributed variables will be explicitly expanded into their respective
|
||||
distributed replicas, and their assigned devices will be saved. This is
|
||||
useful when one wants to use the model for training in environments where
|
||||
the original distribution strategy is not available. Checkpoints are
|
||||
currently incompatible with this option, so it is not implemented in
|
||||
`saved_model.save` (only the internal `saved_model.export_meta_graph` API
|
||||
supports it for now).
|
||||
Distributed variables will be saved with information about their components,
|
||||
allowing for their restoration on load. Also, the saved graph will contain
|
||||
references to those variables. This is useful when one wants to use the
|
||||
model for training in environments where the original distribution strategy
|
||||
is not available.
|
||||
"""
|
||||
|
||||
NONE = None
|
||||
|
@ -556,6 +556,7 @@ class SaveTest(test.TestCase, parameterized.TestCase):
|
||||
save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES),
|
||||
("_DiscardDistributedVariables", save_options.VariablePolicy.NONE))
|
||||
def test_expand_distributed_variables(self, expand_strategy):
|
||||
# 1. Create a context with both CPU:0 and CPU:1.
|
||||
context._reset_context()
|
||||
cpus = context.context().list_physical_devices("CPU")
|
||||
if len(cpus) == 1:
|
||||
@ -566,6 +567,7 @@ class SaveTest(test.TestCase, parameterized.TestCase):
|
||||
])
|
||||
context.ensure_initialized()
|
||||
|
||||
# 2. Create and save a model under a mirrored strategy.
|
||||
file_name = os.path.join(self.get_temp_dir(), "saved_model.pb")
|
||||
with mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]).scope():
|
||||
root = tracking.AutoTrackable()
|
||||
@ -582,36 +584,35 @@ class SaveTest(test.TestCase, parameterized.TestCase):
|
||||
filename=file_name,
|
||||
options=save_options.SaveOptions(
|
||||
experimental_variable_policy=expand_strategy))
|
||||
graph_def = meta_graph.read_meta_graph_file(file_name).graph_def
|
||||
v0 = next((n for n in graph_def.node if n.name == "v"), None)
|
||||
v1 = next((n for n in graph_def.node if n.name == "v/replica_1"), None)
|
||||
self.assertIsNotNone(v0)
|
||||
|
||||
# 3. Read the output file and test behavior.
|
||||
meta_graph_def = meta_graph.read_meta_graph_file(file_name)
|
||||
object_graph = meta_graph_def.object_graph_def
|
||||
graph_def = meta_graph_def.graph_def
|
||||
v = next((n.variable
|
||||
for n in object_graph.nodes
|
||||
if n.HasField("variable") and n.variable.name == "v"), None)
|
||||
saved_function = next((f for f in graph_def.library.function
|
||||
if "inference_f_" in f.signature.name), None)
|
||||
self.assertIsNotNone(saved_function)
|
||||
if (expand_strategy ==
|
||||
save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES):
|
||||
self.assertIsNotNone(v1)
|
||||
# experimental_save_variable_devices should have been automatically set.
|
||||
self.assertIn("CPU:0", v.device)
|
||||
components = v.experimental_distributed_variable_components
|
||||
self.assertLen(components, 2)
|
||||
v0 = next((x for x in components if x.name == "v"), None)
|
||||
v1 = next((x for x in components if x.name == "v/replica_1"), None)
|
||||
self.assertIsNotNone(v0)
|
||||
self.assertIsNotNone(v1)
|
||||
self.assertIn("CPU:0", v0.device)
|
||||
self.assertIn("CPU:1", v1.device)
|
||||
self.assertLen(saved_function.signature.input_arg, 2)
|
||||
else:
|
||||
self.assertIsNone(v1)
|
||||
self.assertEmpty(v0.device)
|
||||
self.assertEmpty(v.device)
|
||||
self.assertEmpty(v.experimental_distributed_variable_components)
|
||||
self.assertLen(saved_function.signature.input_arg, 1)
|
||||
|
||||
def test_expand_distributed_variables_not_allowed(self):
  """Checks `save.save` raises for the EXPAND_DISTRIBUTED_VARIABLES policy."""
  trackable_root = tracking.AutoTrackable()
  expected_message = "not implemented in saved_model.save"
  with self.assertRaisesRegex(NotImplementedError, expected_message):
    expand_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
    expand_options = save_options.SaveOptions(
        experimental_variable_policy=expand_policy)
    save.save(obj=trackable_root, export_dir="", options=expand_options)
|
||||
|
||||
def test_save_uninitialized_variable(self):
|
||||
root = tracking.AutoTrackable()
|
||||
root.uninitialized_variable = resource_variable_ops.UninitializedVariable(
|
||||
|
Loading…
Reference in New Issue
Block a user