Handle distributed variables correctly in tf.saved_model saving. In particular, this allows saving models containing hub.KerasLayer and trained with distribution strategies.

PiperOrigin-RevId: 260557099
This commit is contained in:
Priya Gupta 2019-07-29 12:39:48 -07:00 committed by TensorFlower Gardener
parent fc2f969a74
commit 6801a4b2fb
4 changed files with 36 additions and 6 deletions

View File

@@ -79,5 +79,13 @@ def main(argv):
np.testing.assert_allclose(y_lite, y_tf, rtol=0, atol=1e-5,
err_msg='Mismatch at test example %d' % i)
# Test that it loads correctly with v1 load APIs as well.
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as session:
tf.compat.v1.saved_model.load(
session,
[tf.compat.v1.saved_model.SERVING],
FLAGS.saved_model_dir)
# Script entry point: dispatch to `main` through absl's app runner (parses
# flags, then calls main(argv)).
if __name__ == '__main__':
  app.run(main)

View File

@@ -97,12 +97,7 @@ class SavedModelTest(scripts.TestCase, parameterized.TestCase):
fast_test_mode = True
temp_dir = self.get_temp_dir()
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
# TODO(b/135043074): remove this if-else.
if named_strategy is None:
full_model_dir = os.path.join(temp_dir, "full_model")
else:
full_model_dir = None
self.assertCommandSucceeded(
"export_mnist_cnn",

View File

@@ -776,6 +776,9 @@ class DistributedVariable(DistributedDelegate, variables_lib.AbstractVariable):
"""Pass resource_variable_ops.is_resource_variable check."""
pass
def _clone_with_new_values(self, new_values):
raise NotImplementedError("Must be implemented in descendents.")
# Register DistributedVariable as a dense-tensor-like type with the ops
# module (presumably so instances are accepted where a tensor-convertible
# value is expected — see ops.register_dense_tensor_like_type).
ops.register_dense_tensor_like_type(DistributedVariable)
@@ -1069,6 +1072,10 @@ class MirroredVariable(DistributedVariable, Mirrored):
return ops.internal_convert_to_tensor(
self.get(), dtype=dtype, name=name, as_ref=as_ref)
def _clone_with_new_values(self, new_values):
return type(self)(self._distribute_strategy, self._device_map, new_values,
self._aggregation, logical_device=self._logical_device)
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
@@ -1245,6 +1252,10 @@ class SyncOnReadVariable(DistributedVariable):
return ops.internal_convert_to_tensor(
self.get(), dtype=dtype, name=name, as_ref=as_ref)
def _clone_with_new_values(self, new_values):
return type(self)(self._distribute_strategy, self._device_map, new_values,
self._aggregation, logical_device=self._logical_device)
# Register a conversion function for SyncOnReadVariable which allows as_ref to
# be true.

View File

@@ -25,6 +25,7 @@ from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@@ -240,6 +241,7 @@ class _SaveableView(object):
asset_initializers_by_resource={},
asset_filename_map={},
asset_index={})
for node_id, obj in enumerate(self.nodes):
if isinstance(obj, tracking.CapturableResource):
# pylint: disable=protected-access
@@ -248,6 +250,20 @@ class _SaveableView(object):
# pylint: enable=protected-access
resource_map[obj.resource_handle] = new_resource
self.captured_tensor_node_ids[obj.resource_handle] = node_id
elif ds_values.is_distributed_variable(obj):
# Put both the distributed variable and component variable handles in
# `captured_tensor_node_ids`.
# Also create a new distributed variable for `object_map` with newly
# created component variables.
new_vars = []
for v in obj.values:
new_variable = resource_variable_ops.copy_to_graph_uninitialized(v)
object_map[v] = new_variable
new_vars.append(new_variable)
resource_map[v.handle] = new_variable.handle
self.captured_tensor_node_ids[v.handle] = node_id
object_map[obj] = obj._clone_with_new_values(new_vars) # pylint: disable=protected-access
self.captured_tensor_node_ids[obj] = node_id
elif resource_variable_ops.is_resource_variable(obj):
new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
object_map[obj] = new_variable