Skeleton loading for SavedModels in 2.x
Doesn't do anything useful yet, just a bunch of TODOs. PiperOrigin-RevId: 223812626
This commit is contained in:
parent
f63601af63
commit
9c8b5bf6e7
@ -591,6 +591,7 @@ py_library(
|
||||
"//tensorflow/contrib/distribute/python:tpu_strategy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
"//tensorflow/python/keras",
|
||||
"//third_party/py/numpy",
|
||||
|
@ -12,6 +12,8 @@ licenses(["notice"]) # Apache 2.0
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos")
|
||||
|
||||
py_library(
|
||||
name = "saved_model",
|
||||
@ -21,6 +23,7 @@ py_library(
|
||||
deps = [
|
||||
":builder",
|
||||
":constants",
|
||||
":load",
|
||||
":loader",
|
||||
":main_op",
|
||||
":save",
|
||||
@ -89,7 +92,7 @@ py_library(
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:saver",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
@ -168,14 +171,15 @@ py_test(
|
||||
":signature_def_utils",
|
||||
":tag_constants",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:saver_test_utils",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:test_ops",
|
||||
"//tensorflow/python:training",
|
||||
@ -266,6 +270,14 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "saved_object_graph",
|
||||
srcs = ["saved_object_graph.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = tf_additional_all_protos(),
|
||||
visibility = ["//tensorflow:internal"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "save",
|
||||
srcs = [
|
||||
@ -273,16 +285,24 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":loader",
|
||||
":constants",
|
||||
":saved_object_graph_py",
|
||||
":signature_constants",
|
||||
":signature_def_utils",
|
||||
":tag_constants",
|
||||
":utils",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/eager:function",
|
||||
"//tensorflow/python/training/checkpointable:base",
|
||||
"//tensorflow/python/training/checkpointable:util",
|
||||
],
|
||||
)
|
||||
|
||||
@ -291,13 +311,42 @@ py_test(
|
||||
srcs = ["save_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":loader",
|
||||
":save",
|
||||
":signature_constants",
|
||||
":tag_constants",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google-internal targets. These must be at the end for syncrepo.
|
||||
py_library(
|
||||
name = "load",
|
||||
srcs = [
|
||||
"load.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":loader",
|
||||
":saved_object_graph_py",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/training/checkpointable:tracking",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "load_test",
|
||||
srcs = ["load_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":load",
|
||||
":save",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/training/checkpointable:tracking",
|
||||
],
|
||||
)
|
||||
|
@ -29,6 +29,9 @@ tf_export(
|
||||
"saved_model.ASSETS_DIRECTORY", "saved_model.constants.ASSETS_DIRECTORY"
|
||||
]).export_constant(__name__, "ASSETS_DIRECTORY")
|
||||
|
||||
# Subdirectory name containing unmanaged files from higher-level APIs.
|
||||
EXTRA_ASSETS_DIRECTORY = "assets.extra"
|
||||
|
||||
# CollectionDef key containing SavedModel assets.
|
||||
ASSETS_KEY = "saved_model_assets"
|
||||
tf_export(
|
||||
|
61
tensorflow/python/saved_model/load.py
Normal file
61
tensorflow/python/saved_model/load.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Import a checkpointable object from a SavedModel."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||
from tensorflow.python.training.checkpointable import tracking
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
def _recreate_object_graph(object_graph_proto):
|
||||
"""Recreates Python objects from an ObjectGraph proto."""
|
||||
objects = []
|
||||
for _ in object_graph_proto.nodes:
|
||||
# TODO(allenl): re-create variables and other types
|
||||
objects.append(tracking.Checkpointable())
|
||||
for obj, object_proto in zip(objects, object_graph_proto.nodes):
|
||||
for reference in object_proto.children:
|
||||
setattr(obj, reference.local_name, objects[reference.node_id])
|
||||
return objects[0]
|
||||
|
||||
|
||||
def load(export_dir):
|
||||
"""Load a SavedModel from `export_dir`."""
|
||||
object_graph_filename = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
|
||||
compat.as_bytes("object_graph.pb"))
|
||||
if file_io.file_exists(object_graph_filename):
|
||||
# If there is an object graph associated with the SavedModel, we'll create a
|
||||
# root object from that.
|
||||
object_graph_string = file_io.FileIO(object_graph_filename, "rb").read()
|
||||
object_graph_proto = (
|
||||
saved_object_graph_pb2.SavedObjectGraph())
|
||||
object_graph_proto.ParseFromString(object_graph_string)
|
||||
root = _recreate_object_graph(object_graph_proto)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently only SavedModels exported with `tf.saved_model.save` may be "
|
||||
"imported. Other SavedModels may eventually be supported via load().")
|
||||
# TODO(allenl): load functions from the SavedModel into the eager context
|
||||
return root
|
51
tensorflow/python/saved_model/load_test.py
Normal file
51
tensorflow/python/saved_model/load_test.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for checkpointable object SavedModel loading."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.training.checkpointable import tracking
|
||||
|
||||
|
||||
class LoadTest(test.TestCase):
|
||||
|
||||
def test_structure_import(self):
|
||||
root = tracking.Checkpointable()
|
||||
root.f = def_function.function(
|
||||
lambda x: 2. * x,
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
root.dep_one = tracking.Checkpointable()
|
||||
root.dep_two = tracking.Checkpointable()
|
||||
root.dep_two.dep = tracking.Checkpointable()
|
||||
root.dep_three = root.dep_two.dep
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(root, save_dir)
|
||||
imported = load.load(save_dir)
|
||||
self.assertIs(imported.dep_three, imported.dep_two.dep)
|
||||
self.assertIsNot(imported.dep_one, imported.dep_two)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -32,6 +32,7 @@ from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
@ -400,6 +401,22 @@ def _make_graph_def(root, signature_functions, object_saver):
|
||||
return graph_def, signatures, saver_def
|
||||
|
||||
|
||||
def _write_object_graph(obj, export_dir):
|
||||
"""Save a SavedObjectGraph proto for `obj`."""
|
||||
# SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
|
||||
# checkpoint. It will eventually go into the SavedModel.
|
||||
object_proto = util.make_object_graph_without_attributes(
|
||||
obj, proto=saved_object_graph_pb2.SavedObjectGraph())
|
||||
extra_asset_dir = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
|
||||
file_io.recursive_create_dir(extra_asset_dir)
|
||||
object_graph_filename = os.path.join(
|
||||
extra_asset_dir, compat.as_bytes("object_graph.pb"))
|
||||
file_io.write_string_to_file(object_graph_filename,
|
||||
object_proto.SerializeToString())
|
||||
|
||||
|
||||
@tf_export("saved_model.save", v1=["saved_model.experimental.save"])
|
||||
def save(obj, export_dir, signatures=None):
|
||||
# pylint: disable=line-too-long
|
||||
@ -580,3 +597,4 @@ def save(obj, export_dir, signatures=None):
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||
file_io.write_string_to_file(path, saved_model.SerializeToString())
|
||||
_write_object_graph(obj, export_dir)
|
||||
|
38
tensorflow/python/saved_model/saved_object_graph.proto
Normal file
38
tensorflow/python/saved_model/saved_object_graph.proto
Normal file
@ -0,0 +1,38 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "tensorflow/core/protobuf/checkpointable_object_graph.proto";
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
package tensorflow;
|
||||
|
||||
// A SavedObjectGraph is part of object-based SavedModels in TF 2.0. It
|
||||
// describes the directed graph of Python objects (or equivalent in other
|
||||
// languages) that make up a model, with nodes[0] at the root.
|
||||
|
||||
// SavedObjectGraph shares some structure with CheckpointableObjectGraph, but
|
||||
// ObjectGraph belongs to the SavedModel and contains pointers to functions and
|
||||
// type information, while CheckpointableObjectGraph lives in the checkpoint and
|
||||
// contains pointers only to variable values.
|
||||
|
||||
// NOTE: This protocol buffer format is experimental and subject to change.
|
||||
|
||||
message SavedObjectGraph {
|
||||
message SavedObject {
|
||||
// Objects which this object depends on: named edges in the dependency
|
||||
// graph.
|
||||
repeated CheckpointableObjectGraph.CheckpointableObject.ObjectReference
|
||||
children = 1;
|
||||
// Removed when forking from CheckpointableObjectGraph.
|
||||
reserved "attributes";
|
||||
reserved 2;
|
||||
// Slot variables owned by this object. This describes the three-way
|
||||
// (optimizer, variable, slot variable) relationship; none of the three
|
||||
// depend on the others directly.
|
||||
repeated
|
||||
CheckpointableObjectGraph.CheckpointableObject.SlotVariableReference
|
||||
slot_variables = 3;
|
||||
}
|
||||
|
||||
repeated SavedObject nodes = 1;
|
||||
}
|
@ -648,10 +648,12 @@ def _add_attributes_to_object_graph(
|
||||
return named_saveable_objects, feed_additions
|
||||
|
||||
|
||||
def _make_object_graph_proto(checkpointable_objects, node_ids, slot_variables):
|
||||
def _fill_object_graph_proto(checkpointable_objects, node_ids, slot_variables,
|
||||
object_graph_proto=None):
|
||||
"""Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
|
||||
object_graph_proto = (
|
||||
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
|
||||
if object_graph_proto is None:
|
||||
object_graph_proto = (
|
||||
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
|
||||
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
|
||||
assert node_ids[checkpointable] == checkpoint_id
|
||||
object_proto = object_graph_proto.nodes.add()
|
||||
@ -676,7 +678,7 @@ def _serialize_gathered_objects(
|
||||
checkpointable_objects=checkpointable_objects,
|
||||
node_ids=node_ids,
|
||||
object_names=object_names)
|
||||
object_graph_proto = _make_object_graph_proto(
|
||||
object_graph_proto = _fill_object_graph_proto(
|
||||
checkpointable_objects=checkpointable_objects,
|
||||
node_ids=node_ids,
|
||||
slot_variables=slot_variables)
|
||||
@ -764,12 +766,12 @@ def list_objects(root_checkpointable):
|
||||
return checkpointable_objects
|
||||
|
||||
|
||||
def make_object_graph_without_attributes(root_checkpointable):
|
||||
"""Construct a CheckpointableObjectGraph proto with no variable values."""
|
||||
def make_object_graph_without_attributes(root_checkpointable, proto=None):
|
||||
"""Fill an object graph proto, ignoring variable values."""
|
||||
checkpointable_objects, node_ids, slot_variables = _find_objects(
|
||||
root_checkpointable)
|
||||
return _make_object_graph_proto(
|
||||
checkpointable_objects, node_ids, slot_variables)
|
||||
return _fill_object_graph_proto(
|
||||
checkpointable_objects, node_ids, slot_variables, proto)
|
||||
|
||||
|
||||
def gather_initializers(root_checkpointable):
|
||||
@ -1924,3 +1926,4 @@ class Checkpoint(tracking.Checkpointable):
|
||||
# initialization when executing eagerly.
|
||||
self._maybe_create_save_counter()
|
||||
return status
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user