Add basic support for variables in object-based saved model.

PiperOrigin-RevId: 225539883
This commit is contained in:
A. Unique TensorFlower 2018-12-14 07:40:48 -08:00 committed by TensorFlower Gardener
parent 3aeb925272
commit 13187e1566
4 changed files with 47 additions and 1 deletions

View File

@ -22,12 +22,15 @@ import os
from tensorflow.python.framework import function as function_lib
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.util import compat
@ -47,6 +50,7 @@ class _Loader(object):
defined_function.add_to_graph(None)
self._defined_functions[defined_function.name] = defined_function
self._load_all()
self._restore_checkpoint()
def _load_all(self):
self._nodes = [self._recreate(proto) for proto in self._proto.nodes]
@ -55,14 +59,21 @@ class _Loader(object):
for reference in object_proto.children:
setattr(obj, reference.local_name, self._nodes[reference.node_id])
def _restore_checkpoint(self):
variables_path = saved_model_utils.get_variables_path(self._export_dir)
saver = util.CheckpointableSaver(self.get(0))
saver.restore(variables_path).assert_consumed()
def get(self, node_id):
return self._nodes[node_id]
def _recreate(self, proto):
"""Creates a Python object from a SavedObject protocol buffer."""
factory = {
"user_object": lambda: self._recreate_user_object(proto.user_object),
"asset": lambda: self._recreate_asset(proto.asset),
"function": lambda: self._recreate_function(proto.function)
"function": lambda: self._recreate_function(proto.function),
"variable": lambda: self._recreate_variable(proto.variable),
}
kind = proto.WhichOneof("kind")
if kind not in factory:
@ -83,6 +94,11 @@ class _Loader(object):
return function_deserialization.recreate_polymorphic_function(
proto, self._defined_functions)
def _recreate_variable(self, proto):
# TODO(andresp): Can we use the checkpointed value as initializer?
dummy_value = init_ops.Zeros(dtype=proto.dtype)(shape=proto.shape)
return variables.Variable(dummy_value)
def _load_saved_object_graph_proto(filename):
with file_io.FileIO(filename, "rb") as f:

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training.checkpointable import tracking
@ -50,6 +51,19 @@ class LoadTest(test.TestCase):
self.assertIsNot(imported.dep_one, imported.dep_two)
self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
def test_variables(self):
root = tracking.Checkpointable()
root.f = def_function.function(
lambda x: 2. * x,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
root.v1 = variables.Variable(1.)
root.v2 = variables.Variable(2.)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(root, save_dir)
imported = load.load(save_dir)
self.assertEquals(imported.v1.numpy(), 1.0)
self.assertEquals(imported.v2.numpy(), 2.0)
def _make_asset(self, contents):
filename = tempfile.mktemp(prefix=self.get_temp_dir())
with open(filename, "w") as f:

View File

@ -541,6 +541,10 @@ def _write_object_proto(obj, proto, asset_file_def_index):
if isinstance(obj, tracking.TrackableAsset):
proto.asset.SetInParent()
proto.asset.asset_file_def_index = asset_file_def_index[obj]
elif resource_variable_ops.is_resource_variable(obj):
proto.variable.SetInParent()
proto.variable.dtype = obj.dtype.as_datatype_enum
proto.variable.shape.CopyFrom(obj.shape.as_proto())
else:
proto.user_object.SetInParent()

View File

@ -1,6 +1,8 @@
syntax = "proto3";
import "tensorflow/core/protobuf/checkpointable_object_graph.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
option cc_enable_arenas = true;
@ -49,6 +51,7 @@ message SavedObject {
SavedUserObject user_object = 4;
SavedAsset asset = 5;
SavedPolymorphicFunction function = 6;
SavedVariable variable = 7;
}
}
@ -82,3 +85,12 @@ message SavedMonomorphicFunction {
// A reference to a TensorFlow function in the MetaGraph's FunctionDefLibrary
string concrete_function = 1;
}
// Represents a Variable that is initialized by loading the contents from the
// SavedModel checkpoint.
message SavedVariable {
DataType dtype = 1;
TensorShapeProto shape = 2;
// TODO(andresp): Add "trainable" and save_slice_info_def.
}