Add basic support for variables in object-based saved model.
PiperOrigin-RevId: 225539883
This commit is contained in:
parent
3aeb925272
commit
13187e1566
tensorflow/python/saved_model
@ -22,12 +22,15 @@ import os
|
||||
|
||||
from tensorflow.python.framework import function as function_lib
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import function_deserialization
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.training.checkpointable import tracking
|
||||
from tensorflow.python.training.checkpointable import util
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@ -47,6 +50,7 @@ class _Loader(object):
|
||||
defined_function.add_to_graph(None)
|
||||
self._defined_functions[defined_function.name] = defined_function
|
||||
self._load_all()
|
||||
self._restore_checkpoint()
|
||||
|
||||
def _load_all(self):
|
||||
self._nodes = [self._recreate(proto) for proto in self._proto.nodes]
|
||||
@ -55,14 +59,21 @@ class _Loader(object):
|
||||
for reference in object_proto.children:
|
||||
setattr(obj, reference.local_name, self._nodes[reference.node_id])
|
||||
|
||||
def _restore_checkpoint(self):
|
||||
variables_path = saved_model_utils.get_variables_path(self._export_dir)
|
||||
saver = util.CheckpointableSaver(self.get(0))
|
||||
saver.restore(variables_path).assert_consumed()
|
||||
|
||||
def get(self, node_id):
|
||||
return self._nodes[node_id]
|
||||
|
||||
def _recreate(self, proto):
|
||||
"""Creates a Python object from a SavedObject protocol buffer."""
|
||||
factory = {
|
||||
"user_object": lambda: self._recreate_user_object(proto.user_object),
|
||||
"asset": lambda: self._recreate_asset(proto.asset),
|
||||
"function": lambda: self._recreate_function(proto.function)
|
||||
"function": lambda: self._recreate_function(proto.function),
|
||||
"variable": lambda: self._recreate_variable(proto.variable),
|
||||
}
|
||||
kind = proto.WhichOneof("kind")
|
||||
if kind not in factory:
|
||||
@ -83,6 +94,11 @@ class _Loader(object):
|
||||
return function_deserialization.recreate_polymorphic_function(
|
||||
proto, self._defined_functions)
|
||||
|
||||
def _recreate_variable(self, proto):
|
||||
# TODO(andresp): Can we use the checkpointed value as initializer?
|
||||
dummy_value = init_ops.Zeros(dtype=proto.dtype)(shape=proto.shape)
|
||||
return variables.Variable(dummy_value)
|
||||
|
||||
|
||||
def _load_saved_object_graph_proto(filename):
|
||||
with file_io.FileIO(filename, "rb") as f:
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.training.checkpointable import tracking
|
||||
@ -50,6 +51,19 @@ class LoadTest(test.TestCase):
|
||||
self.assertIsNot(imported.dep_one, imported.dep_two)
|
||||
self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
|
||||
|
||||
def test_variables(self):
|
||||
root = tracking.Checkpointable()
|
||||
root.f = def_function.function(
|
||||
lambda x: 2. * x,
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
root.v1 = variables.Variable(1.)
|
||||
root.v2 = variables.Variable(2.)
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(root, save_dir)
|
||||
imported = load.load(save_dir)
|
||||
self.assertEquals(imported.v1.numpy(), 1.0)
|
||||
self.assertEquals(imported.v2.numpy(), 2.0)
|
||||
|
||||
def _make_asset(self, contents):
|
||||
filename = tempfile.mktemp(prefix=self.get_temp_dir())
|
||||
with open(filename, "w") as f:
|
||||
|
@ -541,6 +541,10 @@ def _write_object_proto(obj, proto, asset_file_def_index):
|
||||
if isinstance(obj, tracking.TrackableAsset):
|
||||
proto.asset.SetInParent()
|
||||
proto.asset.asset_file_def_index = asset_file_def_index[obj]
|
||||
elif resource_variable_ops.is_resource_variable(obj):
|
||||
proto.variable.SetInParent()
|
||||
proto.variable.dtype = obj.dtype.as_datatype_enum
|
||||
proto.variable.shape.CopyFrom(obj.shape.as_proto())
|
||||
else:
|
||||
proto.user_object.SetInParent()
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "tensorflow/core/protobuf/checkpointable_object_graph.proto";
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
@ -49,6 +51,7 @@ message SavedObject {
|
||||
SavedUserObject user_object = 4;
|
||||
SavedAsset asset = 5;
|
||||
SavedPolymorphicFunction function = 6;
|
||||
SavedVariable variable = 7;
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,3 +85,12 @@ message SavedMonomorphicFunction {
|
||||
// A reference to a TensorFlow function in the MetaGraph's FunctionDefLibrary
|
||||
string concrete_function = 1;
|
||||
}
|
||||
|
||||
// Represents a Variable that is initialized by loading the contents from the
|
||||
// SavedModel checkpoint.
|
||||
message SavedVariable {
|
||||
DataType dtype = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
|
||||
// TODO(andresp): Add "trainable" and save_slice_info_def.
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user