Handle distributed variables correctly in tf.saved_model saving. In particular, this allows saving models that contain hub.KerasLayer and were trained with distribution strategies.
PiperOrigin-RevId: 260557099
commit 6801a4b2fb (parent fc2f969a74)
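For context, a minimal sketch of the use case the commit message describes, assuming TensorFlow 2.x with tensorflow_hub installed; the hub handle, layer sizes, and output path below are illustrative and not taken from this commit:

    # Sketch only: the hub handle, model shape, and path are placeholders.
    import tensorflow as tf
    import tensorflow_hub as hub

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      # Variables created here (including the hub layer's weights) are
      # DistributedVariables, which this change teaches the saver to handle.
      model = tf.keras.Sequential([
          hub.KerasLayer("https://tfhub.dev/google/some_feature_extractor/1",
                         trainable=True),
          tf.keras.layers.Dense(10, activation="softmax"),
      ])
      model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

    # ... model.fit(...) on distributed data ...

    tf.saved_model.save(model, "/tmp/exported_model")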
@@ -79,5 +79,13 @@ def main(argv):
    np.testing.assert_allclose(y_lite, y_tf, rtol=0, atol=1e-5,
                               err_msg='Mismatch at test example %d' % i)

  # Test that it loads correctly with v1 load APIs as well.
  with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as session:
    tf.compat.v1.saved_model.load(
        session,
        [tf.compat.v1.saved_model.SERVING],
        FLAGS.saved_model_dir)


if __name__ == '__main__':
  app.run(main)
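The added block exercises the TF1 loading path. For comparison, the TF2 counterpart, sketched below with an assumed signature key, would load and inspect the same export:

    # Sketch only: "serving_default" is the conventional signature key, assumed here.
    loaded = tf.saved_model.load(FLAGS.saved_model_dir)
    serving_fn = loaded.signatures["serving_default"]
    print(serving_fn.structured_input_signature)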
@@ -97,12 +97,7 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
    fast_test_mode = True
    temp_dir = self.get_temp_dir()
    feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")

    # TODO(b/135043074): remove this if-else.
    if named_strategy is None:
      full_model_dir = os.path.join(temp_dir, "full_model")
    else:
      full_model_dir = None

    self.assertCommandSucceeded(
        "export_mnist_cnn",
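The `named_strategy` argument above is supplied by test parameterization. A hypothetical sketch of that pattern with absl's `parameterized` module (the names and cases here are illustrative, not the actual test matrix):

    # Sketch only: illustrative parameterization, not the real test's cases.
    from absl.testing import parameterized
    import tensorflow as tf


    class ExportTest(tf.test.TestCase, parameterized.TestCase):

      @parameterized.named_parameters(
          ("NoStrategy", None),
          ("MirroredStrategy", tf.distribute.MirroredStrategy),
      )
      def test_export(self, named_strategy):
        strategy = named_strategy() if named_strategy else None
        # ... build, train, and export a model under `strategy` ...
        self.assertTrue(strategy is None or
                        isinstance(strategy, tf.distribute.Strategy))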
@@ -776,6 +776,9 @@ class DistributedVariable(DistributedDelegate, variables_lib.AbstractVariable):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _clone_with_new_values(self, new_values):
    raise NotImplementedError("Must be implemented in descendents.")


ops.register_dense_tensor_like_type(DistributedVariable)
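`_clone_with_new_values` is a hook that each concrete subclass overrides to rebuild itself around a new set of component values. A toy, self-contained sketch of the pattern (names are illustrative, not TensorFlow's):

    # Sketch only: a toy illustration of the clone hook, not TensorFlow code.
    class DistributedValue:
      def __init__(self, values):
        self._values = list(values)

      def _clone_with_new_values(self, new_values):
        raise NotImplementedError("Must be implemented in descendents.")


    class MirroredValue(DistributedValue):
      def _clone_with_new_values(self, new_values):
        # Rebuild a wrapper of the same type around freshly created components.
        return type(self)(new_values)


    original = MirroredValue([1, 2])
    clone = original._clone_with_new_values([10, 20])
    assert type(clone) is MirroredValue and clone._values == [10, 20]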
@@ -1069,6 +1072,10 @@ class MirroredVariable(DistributedVariable, Mirrored):
    return ops.internal_convert_to_tensor(
        self.get(), dtype=dtype, name=name, as_ref=as_ref)

  def _clone_with_new_values(self, new_values):
    return type(self)(self._distribute_strategy, self._device_map, new_values,
                      self._aggregation, logical_device=self._logical_device)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
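The conversion function mentioned in the trailing comment is what lets a MirroredVariable be passed anywhere a tensor is expected. The public counterpart of that mechanism is `tf.register_tensor_conversion_function`; a minimal sketch with an illustrative wrapper class:

    # Sketch only: `Wrapped` is an illustrative class, not part of TensorFlow.
    import tensorflow as tf


    class Wrapped:
      def __init__(self, value):
        self.value = value


    def _convert(wrapped, dtype=None, name=None, as_ref=False):
      del as_ref  # Ignored in this sketch.
      return tf.convert_to_tensor(wrapped.value, dtype=dtype, name=name)


    tf.register_tensor_conversion_function(Wrapped, _convert)
    print(tf.add(Wrapped(1.0), 2.0))  # Wrapped now participates in ops like a tensor.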
@@ -1245,6 +1252,10 @@ class SyncOnReadVariable(DistributedVariable):
    return ops.internal_convert_to_tensor(
        self.get(), dtype=dtype, name=name, as_ref=as_ref)

  def _clone_with_new_values(self, new_values):
    return type(self)(self._distribute_strategy, self._device_map, new_values,
                      self._aggregation, logical_device=self._logical_device)


# Register a conversion function for SyncOnReadVariable which allows as_ref to
# be true.
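SyncOnReadVariable gets the same clone hook. These are the variables created with ON_READ synchronization under a strategy (metric-style accumulators); a minimal sketch of how one arises:

    # Sketch only: illustrates the kind of variable SyncOnReadVariable wraps.
    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      total = tf.Variable(
          0.0,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.SUM)
    # Under the strategy, `total` is a sync-on-read distributed variable whose
    # per-replica components are summed when read outside a replica context.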
@@ -25,6 +25,7 @@ from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@@ -240,6 +241,7 @@ class _SaveableView(object):
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        # pylint: disable=protected-access
@@ -248,6 +250,20 @@ class _SaveableView(object):
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif ds_values.is_distributed_variable(obj):
        # Put both the distributed variable and component variable handles in
        # `captured_tensor_node_ids`.
        # Also create a new distributed variable for `object_map` with newly
        # created component variables.
        new_vars = []
        for v in obj.values:
          new_variable = resource_variable_ops.copy_to_graph_uninitialized(v)
          object_map[v] = new_variable
          new_vars.append(new_variable)
          resource_map[v.handle] = new_variable.handle
          self.captured_tensor_node_ids[v.handle] = node_id
        object_map[obj] = obj._clone_with_new_values(new_vars)  # pylint: disable=protected-access
        self.captured_tensor_node_ids[obj] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
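The new `elif` branch is the core of the change: it copies every per-device component of a distributed variable into the export graph uninitialized, records both the wrapper and the component handles as captured tensors, and clones the wrapper around the copies for `object_map`. A sketch of the structure being walked (`.values` is the same internal attribute the loop above iterates, so treat it as illustrative rather than stable API):

    # Sketch only: inspects internals analogous to what the save path iterates.
    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      v = tf.Variable(1.0)          # A MirroredVariable under the strategy.

    for component in v.values:      # One component per device, as in the hunk.
      print(component.device, component.handle.dtype)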