Skeleton loading for SavedModels in 2.x

Doesn't do anything useful yet, just a bunch of TODOs.

PiperOrigin-RevId: 223812626
This commit is contained in:
Allen Lavoie 2018-12-03 09:53:46 -08:00 committed by TensorFlower Gardener
parent f63601af63
commit 9c8b5bf6e7
8 changed files with 238 additions and 14 deletions

View File

@ -591,6 +591,7 @@ py_library(
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/eager:test",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras",
"//third_party/py/numpy",

View File

@ -12,6 +12,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos")
py_library(
name = "saved_model",
@ -21,6 +23,7 @@ py_library(
deps = [
":builder",
":constants",
":load",
":loader",
":main_op",
":save",
@ -89,7 +92,7 @@ py_library(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python:saver",
"//tensorflow/python:util",
"//tensorflow/python:variables",
],
@ -168,14 +171,15 @@ py_test(
":signature_def_utils",
":tag_constants",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:test_ops",
"//tensorflow/python:training",
@ -266,6 +270,14 @@ py_test(
],
)
tf_proto_library(
name = "saved_object_graph",
srcs = ["saved_object_graph.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
visibility = ["//tensorflow:internal"],
)
py_library(
name = "save",
srcs = [
@ -273,16 +285,24 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":loader",
":constants",
":saved_object_graph_py",
":signature_constants",
":signature_def_utils",
":tag_constants",
":utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/checkpointable:base",
"//tensorflow/python/training/checkpointable:util",
],
)
@ -291,13 +311,42 @@ py_test(
srcs = ["save_test.py"],
srcs_version = "PY2AND3",
deps = [
":loader",
":save",
":signature_constants",
":tag_constants",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:parameterized",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
py_library(
name = "load",
srcs = [
"load.py",
],
srcs_version = "PY2AND3",
deps = [
":loader",
":saved_object_graph_py",
"//tensorflow/python:lib",
"//tensorflow/python:util",
"//tensorflow/python/training/checkpointable:tracking",
],
)
py_test(
name = "load_test",
srcs = ["load_test.py"],
srcs_version = "PY2AND3",
deps = [
":load",
":save",
"//tensorflow/python:dtypes",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/training/checkpointable:tracking",
],
)

View File

@ -29,6 +29,9 @@ tf_export(
"saved_model.ASSETS_DIRECTORY", "saved_model.constants.ASSETS_DIRECTORY"
]).export_constant(__name__, "ASSETS_DIRECTORY")
# Subdirectory name containing unmanaged files from higher-level APIs.
EXTRA_ASSETS_DIRECTORY = "assets.extra"
# CollectionDef key containing SavedModel assets.
ASSETS_KEY = "saved_model_assets"
tf_export(

View File

@ -0,0 +1,61 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a checkpointable object from a SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.util import compat
def _recreate_object_graph(object_graph_proto):
"""Recreates Python objects from an ObjectGraph proto."""
objects = []
for _ in object_graph_proto.nodes:
# TODO(allenl): re-create variables and other types
objects.append(tracking.Checkpointable())
for obj, object_proto in zip(objects, object_graph_proto.nodes):
for reference in object_proto.children:
setattr(obj, reference.local_name, objects[reference.node_id])
return objects[0]
def load(export_dir):
"""Load a SavedModel from `export_dir`."""
object_graph_filename = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
compat.as_bytes("object_graph.pb"))
if file_io.file_exists(object_graph_filename):
# If there is an object graph associated with the SavedModel, we'll create a
# root object from that.
object_graph_string = file_io.FileIO(object_graph_filename, "rb").read()
object_graph_proto = (
saved_object_graph_pb2.SavedObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
root = _recreate_object_graph(object_graph_proto)
else:
raise NotImplementedError(
"Currently only SavedModels exported with `tf.saved_model.save` may be "
"imported. Other SavedModels may eventually be supported via load().")
# TODO(allenl): load functions from the SavedModel into the eager context
return root

View File

@ -0,0 +1,51 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for checkpointable object SavedModel loading."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training.checkpointable import tracking
class LoadTest(test.TestCase):
def test_structure_import(self):
root = tracking.Checkpointable()
root.f = def_function.function(
lambda x: 2. * x,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
root.dep_one = tracking.Checkpointable()
root.dep_two = tracking.Checkpointable()
root.dep_two.dep = tracking.Checkpointable()
root.dep_three = root.dep_two.dep
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(root, save_dir)
imported = load.load(save_dir)
self.assertIs(imported.dep_three, imported.dep_two.dep)
self.assertIsNot(imported.dep_one, imported.dep_two)
if __name__ == "__main__":
test.main()

View File

@ -32,6 +32,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
@ -400,6 +401,22 @@ def _make_graph_def(root, signature_functions, object_saver):
return graph_def, signatures, saver_def
def _write_object_graph(obj, export_dir):
"""Save a SavedObjectGraph proto for `obj`."""
# SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
# checkpoint. It will eventually go into the SavedModel.
object_proto = util.make_object_graph_without_attributes(
obj, proto=saved_object_graph_pb2.SavedObjectGraph())
extra_asset_dir = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
file_io.recursive_create_dir(extra_asset_dir)
object_graph_filename = os.path.join(
extra_asset_dir, compat.as_bytes("object_graph.pb"))
file_io.write_string_to_file(object_graph_filename,
object_proto.SerializeToString())
@tf_export("saved_model.save", v1=["saved_model.experimental.save"])
def save(obj, export_dir, signatures=None):
# pylint: disable=line-too-long
@ -580,3 +597,4 @@ def save(obj, export_dir, signatures=None):
compat.as_bytes(export_dir),
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
file_io.write_string_to_file(path, saved_model.SerializeToString())
_write_object_graph(obj, export_dir)

View File

@ -0,0 +1,38 @@
syntax = "proto3";
import "tensorflow/core/protobuf/checkpointable_object_graph.proto";
option cc_enable_arenas = true;
package tensorflow;
// A SavedObjectGraph is part of object-based SavedModels in TF 2.0. It
// describes the directed graph of Python objects (or equivalent in other
// languages) that make up a model, with nodes[0] at the root.
// SavedObjectGraph shares some structure with CheckpointableObjectGraph, but
// ObjectGraph belongs to the SavedModel and contains pointers to functions and
// type information, while CheckpointableObjectGraph lives in the checkpoint and
// contains pointers only to variable values.
// NOTE: This protocol buffer format is experimental and subject to change.
message SavedObjectGraph {
message SavedObject {
// Objects which this object depends on: named edges in the dependency
// graph.
repeated CheckpointableObjectGraph.CheckpointableObject.ObjectReference
children = 1;
// Removed when forking from CheckpointableObjectGraph.
reserved "attributes";
reserved 2;
// Slot variables owned by this object. This describes the three-way
// (optimizer, variable, slot variable) relationship; none of the three
// depend on the others directly.
repeated
CheckpointableObjectGraph.CheckpointableObject.SlotVariableReference
slot_variables = 3;
}
repeated SavedObject nodes = 1;
}

View File

@ -648,10 +648,12 @@ def _add_attributes_to_object_graph(
return named_saveable_objects, feed_additions
def _make_object_graph_proto(checkpointable_objects, node_ids, slot_variables):
def _fill_object_graph_proto(checkpointable_objects, node_ids, slot_variables,
object_graph_proto=None):
"""Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
if object_graph_proto is None:
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
@ -676,7 +678,7 @@ def _serialize_gathered_objects(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
object_names=object_names)
object_graph_proto = _make_object_graph_proto(
object_graph_proto = _fill_object_graph_proto(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
slot_variables=slot_variables)
@ -764,12 +766,12 @@ def list_objects(root_checkpointable):
return checkpointable_objects
def make_object_graph_without_attributes(root_checkpointable):
"""Construct a CheckpointableObjectGraph proto with no variable values."""
def make_object_graph_without_attributes(root_checkpointable, proto=None):
"""Fill an object graph proto, ignoring variable values."""
checkpointable_objects, node_ids, slot_variables = _find_objects(
root_checkpointable)
return _make_object_graph_proto(
checkpointable_objects, node_ids, slot_variables)
return _fill_object_graph_proto(
checkpointable_objects, node_ids, slot_variables, proto)
def gather_initializers(root_checkpointable):
@ -1924,3 +1926,4 @@ class Checkpoint(tracking.Checkpointable):
# initialization when executing eagerly.
self._maybe_create_save_counter()
return status