From 13187e1566e74a1f9434f5bb16c0cddc076ac497 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 Dec 2018 07:40:48 -0800 Subject: [PATCH] Add basic support for variables in object-based saved model. PiperOrigin-RevId: 225539883 --- tensorflow/python/saved_model/load.py | 18 +++++++++++++++++- tensorflow/python/saved_model/load_test.py | 14 ++++++++++++++ tensorflow/python/saved_model/save.py | 4 ++++ .../saved_model/saved_object_graph.proto | 12 ++++++++++++ 4 files changed, 47 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index 28c0af2b657..9d9f60c69dd 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -22,12 +22,15 @@ import os from tensorflow.python.framework import function as function_lib from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variables from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import function_deserialization from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import saved_object_graph_pb2 from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util from tensorflow.python.util import compat @@ -47,6 +50,7 @@ class _Loader(object): defined_function.add_to_graph(None) self._defined_functions[defined_function.name] = defined_function self._load_all() + self._restore_checkpoint() def _load_all(self): self._nodes = [self._recreate(proto) for proto in self._proto.nodes] @@ -55,14 +59,21 @@ class _Loader(object): for reference in object_proto.children: setattr(obj, reference.local_name, self._nodes[reference.node_id]) + def _restore_checkpoint(self): + variables_path = saved_model_utils.get_variables_path(self._export_dir) + saver = util.CheckpointableSaver(self.get(0)) + saver.restore(variables_path).assert_consumed() + def get(self, node_id): return self._nodes[node_id] def _recreate(self, proto): + """Creates a Python object from a SavedObject protocol buffer.""" factory = { "user_object": lambda: self._recreate_user_object(proto.user_object), "asset": lambda: self._recreate_asset(proto.asset), - "function": lambda: self._recreate_function(proto.function) + "function": lambda: self._recreate_function(proto.function), + "variable": lambda: self._recreate_variable(proto.variable), } kind = proto.WhichOneof("kind") if kind not in factory: @@ -83,6 +94,11 @@ class _Loader(object): return function_deserialization.recreate_polymorphic_function( proto, self._defined_functions) + def _recreate_variable(self, proto): + # TODO(andresp): Can we use the checkpointed value as initializer? + dummy_value = init_ops.Zeros(dtype=proto.dtype)(shape=proto.shape) + return variables.Variable(dummy_value) + def _load_saved_object_graph_proto(filename): with file_io.FileIO(filename, "rb") as f: diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 303b8f66efc..ba88668f8c7 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_spec from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables from tensorflow.python.saved_model import load from tensorflow.python.saved_model import save from tensorflow.python.training.checkpointable import tracking @@ -50,6 +51,19 @@ class LoadTest(test.TestCase): self.assertIsNot(imported.dep_one, imported.dep_two) self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy()) + def test_variables(self): + root = tracking.Checkpointable() + root.f = def_function.function( + lambda x: 2. * x, + input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) + root.v1 = variables.Variable(1.) + root.v2 = variables.Variable(2.) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + save.save(root, save_dir) + imported = load.load(save_dir) + self.assertEquals(imported.v1.numpy(), 1.0) + self.assertEquals(imported.v2.numpy(), 2.0) + def _make_asset(self, contents): filename = tempfile.mktemp(prefix=self.get_temp_dir()) with open(filename, "w") as f: diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 57c63f8cdac..6c2d5e6f2bb 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -541,6 +541,10 @@ def _write_object_proto(obj, proto, asset_file_def_index): if isinstance(obj, tracking.TrackableAsset): proto.asset.SetInParent() proto.asset.asset_file_def_index = asset_file_def_index[obj] + elif resource_variable_ops.is_resource_variable(obj): + proto.variable.SetInParent() + proto.variable.dtype = obj.dtype.as_datatype_enum + proto.variable.shape.CopyFrom(obj.shape.as_proto()) else: proto.user_object.SetInParent() diff --git a/tensorflow/python/saved_model/saved_object_graph.proto b/tensorflow/python/saved_model/saved_object_graph.proto index ed5c63935ff..b95990ad348 100644 --- a/tensorflow/python/saved_model/saved_object_graph.proto +++ b/tensorflow/python/saved_model/saved_object_graph.proto @@ -1,6 +1,8 @@ syntax = "proto3"; import "tensorflow/core/protobuf/checkpointable_object_graph.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; option cc_enable_arenas = true; @@ -49,6 +51,7 @@ message SavedObject { SavedUserObject user_object = 4; SavedAsset asset = 5; SavedPolymorphicFunction function = 6; + SavedVariable variable = 7; } } @@ -82,3 +85,12 @@ message SavedMonomorphicFunction { // A reference to a TensorFlow function in the MetaGraph's FunctionDefLibrary string concrete_function = 1; } + +// Represents a Variable that is initialized by loading the contents from the +// SavedModel checkpoint. +message SavedVariable { + DataType dtype = 1; + TensorShapeProto shape = 2; + + // TODO(andresp): Add "trainable" and save_slice_info_def. +}