Add the SavedObjectGraph used by tf.saved_model.save/load to SavedModel

We reference a bunch of things in the MetaGraph/GraphDef, so it makes sense to add it there rather than to the SavedModel directly.

This is in preparation for non-experimental tf.saved_model.save/load symbols. We don't yet have an exposed symbol for loading object-based SavedModels, so this CL won't break anyone (despite moving around the proto and not checking the old location).

RFC: https://github.com/tensorflow/community/pull/34
PiperOrigin-RevId: 234887195
This commit is contained in:
Allen Lavoie 2019-02-20 15:58:42 -08:00 committed by TensorFlower Gardener
parent 69ab50a9b6
commit 39be9adca1
16 changed files with 78 additions and 74 deletions

View File

@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.cc
tensorflow/core/framework/versions.pb.cc
tensorflow/core/grappler/costs/op_performance_data.pb.cc
tensorflow/core/lib/core/error_codes.pb.cc
tensorflow/core/protobuf/trackable_object_graph.pb.cc
tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc
tensorflow/core/protobuf/eager_service.pb.cc
@ -34,7 +35,9 @@ tensorflow/core/protobuf/meta_graph.pb.cc
tensorflow/core/protobuf/named_tensor.pb.cc
tensorflow/core/protobuf/queue_runner.pb.cc
tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/saved_object_graph.pb.cc
tensorflow/core/protobuf/saver.pb.cc
tensorflow/core/protobuf/struct.pb.cc
tensorflow/core/protobuf/tensorflow_server.pb.cc
tensorflow/core/protobuf/verifier_config.pb.cc
tensorflow/core/util/event.pb.cc

View File

@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.h
tensorflow/core/framework/versions.pb.h
tensorflow/core/grappler/costs/op_performance_data.pb.h
tensorflow/core/lib/core/error_codes.pb.h
tensorflow/core/protobuf/trackable_object_graph.pb.h
tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h
@ -34,7 +35,9 @@ tensorflow/core/protobuf/meta_graph.pb.h
tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h
tensorflow/core/protobuf/saved_object_graph.pb.h
tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/struct.pb.h
tensorflow/core/protobuf/tensor_bundle.pb.h
tensorflow/core/protobuf/tensorflow_server.pb.h
tensorflow/core/protobuf/verifier_config.pb.h

View File

@ -31,6 +31,7 @@ tensorflow/core/framework/versions.proto
tensorflow/core/grappler/costs/op_performance_data.proto
tensorflow/core/kernels/boosted_trees/boosted_trees.proto
tensorflow/core/lib/core/error_codes.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto
@ -40,7 +41,9 @@ tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/rewriter_config.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/tensor_bundle.proto
tensorflow/core/protobuf/tensorflow_server.proto
tensorflow/core/protobuf/verifier_config.proto

View File

@ -236,6 +236,8 @@ ADDITIONAL_CORE_PROTO_SRCS = [
"protobuf/meta_graph.proto",
"protobuf/named_tensor.proto",
"protobuf/saved_model.proto",
"protobuf/saved_object_graph.proto",
"protobuf/struct.proto",
"protobuf/tensorflow_server.proto",
"protobuf/transport_options.proto",
"util/test_log.proto",

View File

@ -12,6 +12,7 @@ import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/op_def.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saved_object_graph.proto";
import "tensorflow/core/protobuf/saver.proto";
// NOTE: This protocol buffer is evolving, and will go through revisions in the
@ -84,6 +85,9 @@ message MetaGraphDef {
// Asset file def to be used with the defined graph.
repeated AssetFileDef asset_file_def = 6;
// Extra information about the structure of functions and stateful objects.
SavedObjectGraph object_graph_def = 7;
}
// CollectionDef should cover most collections.

View File

@ -1,10 +1,10 @@
syntax = "proto3";
import "tensorflow/core/protobuf/trackable_object_graph.proto";
import "tensorflow/core/protobuf/struct.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/framework/versions.proto";
import "tensorflow/python/saved_model/struct.proto";
option cc_enable_arenas = true;
@ -15,14 +15,12 @@ package tensorflow;
// languages) that make up a model, with nodes[0] at the root.
// SavedObjectGraph shares some structure with TrackableObjectGraph, but
// ObjectGraph belongs to the SavedModel and contains pointers to functions and
// type information, while TrackableObjectGraph lives in the checkpoint and
// contains pointers only to variable values.
// NOTE: This protocol buffer format is experimental and subject to change.
// SavedObjectGraph belongs to the MetaGraph and contains pointers to functions
// and type information, while TrackableObjectGraph lives in the checkpoint
// and contains pointers only to variable values.
message SavedObjectGraph {
// List of objects in the SavedModel.
// Flattened list of objects in the object graph.
//
// The position of the object in this list indicates its id.
// Nodes[0] is considered the root node.
@ -37,10 +35,11 @@ message SavedObject {
// Objects which this object depends on: named edges in the dependency
// graph.
//
// Note: only valid if kind == "object".
repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;
// Note: currently only valid if kind == "user_object".
repeated TrackableObjectGraph.TrackableObject.ObjectReference
children = 1;
// Removed when forking from TrackableObjectGraph.
// Removed when forking SavedObject from TrackableObjectGraph.
reserved "attributes";
reserved 2;
@ -48,7 +47,7 @@ message SavedObject {
// (optimizer, variable, slot variable) relationship; none of the three
// depend on the others directly.
//
// Note: only valid if kind == "object".
// Note: currently only valid if kind == "user_object".
repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
slot_variables = 3;
@ -76,7 +75,7 @@ message SavedUserObject {
VersionDef version = 2;
}
// A SavedAsset represents a file in a SavedModel.
// A SavedAsset points to an asset in the MetaGraph.
//
// When bound to a function this object evaluates to a tensor with the absolute
// filename. Users should not depend on a particular part of the filename to
@ -128,13 +127,11 @@ message SavedConstant {
}
// Represents a Variable that is initialized by loading the contents from the
// SavedModel checkpoint.
// checkpoint.
message SavedVariable {
DataType dtype = 1;
TensorShapeProto shape = 2;
bool trainable = 3;
// TODO(andresp): Add save_slice_info_def?
}
// Represents `FunctionSpec` used in `Function`. This represents a

View File

@ -8,6 +8,27 @@ package tensorflow;
// `StructuredValue` represents a dynamically typed value representing various
// data structures that are inspired by Python data structures typically used in
// TensorFlow functions as inputs and outputs.
//
// For example when saving a Layer there may be a `training` argument. If the
// user passes a boolean True/False, that switches between two concrete
// TensorFlow functions. In order to switch between them in the same way after
// loading the SavedModel, we need to represent "True" and "False".
//
// A more advanced example might be a function which takes a list of
// dictionaries mapping from strings to Tensors. In order to map from
// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]`
// after load to the right saved TensorFlow function, we need to represent the
// nested structure and the strings, recording that we have a trace for anything
// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([],
// tf.float64)}]` as an example.
//
// Likewise functions may return nested structures of Tensors, for example
// returning a dictionary mapping from strings to Tensors. In order for the
// loaded function to return the same structure we need to serialize it.
//
// This is an ergonomic aid for working with loaded SavedModels, not a promise
// to serialize all possible function signatures. For example we do not expect
// to pickle generic Python objects, and ideally we'd stay language-agnostic.
message StructuredValue {
// The kind of value.
oneof kind {
@ -29,11 +50,11 @@ message StructuredValue {
// Represents a boolean value.
bool bool_value = 14;
// Represents a tf.TensorShape.
// Represents a TensorShape.
tensorflow.TensorShapeProto tensor_shape_value = 31;
// Represents an enum value for tf.DType.
// Represents an enum value for dtype.
tensorflow.DataType tensor_dtype_value = 32;
// Represents a value for tf.TensorShape.
// Represents a value for tf.TensorSpec.
TensorSpecProto tensor_spec_value = 33;
// Represents a list of `Value`.

View File

@ -291,7 +291,6 @@ py_library(
":function_serialization",
":nested_structure_coder",
":revived_types",
":saved_object_graph_py",
":signature_constants",
":signature_def_utils",
":signature_serialization",
@ -346,8 +345,8 @@ py_library(
":loader",
":nested_structure_coder",
":revived_types",
":saved_object_graph_py",
":utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
@ -429,7 +428,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":saved_object_graph_py",
"//tensorflow/core:protos_all_py",
],
)
@ -438,7 +437,7 @@ tf_py_test(
srcs = ["revived_types_test.py"],
additional_deps = [
":revived_types",
":saved_object_graph_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
],
)
@ -451,7 +450,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":nested_structure_coder",
":saved_object_graph_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
],
@ -469,27 +468,11 @@ py_library(
],
)
tf_proto_library(
name = "struct",
srcs = ["struct.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
visibility = ["//tensorflow:internal"],
)
tf_proto_library(
name = "saved_object_graph",
srcs = ["saved_object_graph.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos() + [":struct"],
visibility = ["//tensorflow:internal"],
)
py_library(
name = "nested_structure_coder",
srcs = ["nested_structure_coder.py"],
deps = [
":struct_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework",
"@six_archive//:six",
],
@ -500,7 +483,7 @@ tf_py_test(
srcs = ["nested_structure_coder_test.py"],
additional_deps = [
":nested_structure_coder",
":struct_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework",
"//tensorflow/python/eager:test",
],

View File

@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import saved_object_graph_pb2
def _serialize_function_spec(function_spec, coder):

View File

@ -24,23 +24,19 @@ import os
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
from tensorflow.python.util import nest
@ -265,12 +261,6 @@ def _call_attribute(instance, *args, **kwargs):
return instance.__call__(*args, **kwargs)
def _load_saved_object_graph_proto(filename):
with file_io.FileIO(filename, "rb") as f:
contents = f.read()
return saved_object_graph_pb2.SavedObjectGraph.FromString(contents)
def load(export_dir, tags=None):
"""Load a SavedModel from `export_dir`.
@ -315,12 +305,8 @@ def load(export_dir, tags=None):
# Supports e.g. tags=SERVING and tags=[SERVING]
tags = nest.flatten(tags)
saved_model_proto = loader_impl.parse_saved_model(export_dir)
object_graph_filename = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
compat.as_bytes("object_graph.pb"))
if (file_io.file_exists(object_graph_filename)
and len(saved_model_proto.meta_graphs) == 1):
if (len(saved_model_proto.meta_graphs) == 1
and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
meta_graph_def = saved_model_proto.meta_graphs[0]
if (tags is not None
and set(tags) != set(meta_graph_def.meta_info_def.tags)):
@ -329,7 +315,7 @@ def load(export_dir, tags=None):
"incompatible argument tags={} to tf.saved_model.load. You may omit "
"it, pass 'None', or pass matching tags.")
.format(export_dir, meta_graph_def.meta_info_def.tags, tags))
object_graph_proto = _load_saved_object_graph_proto(object_graph_filename)
object_graph_proto = meta_graph_def.object_graph_def
with ops.init_scope():
loader = _Loader(object_graph_proto,
saved_model_proto,

View File

@ -34,10 +34,10 @@ import collections
import functools
import six
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.saved_model import struct_pb2
from tensorflow.python.util import compat

View File

@ -20,12 +20,12 @@ from __future__ import print_function
import collections
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.platform import test
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import struct_pb2
class NestedStructureTest(test.TestCase):

View File

@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import versions_pb2
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
class VersionedTypeRegistration(object):

View File

@ -19,9 +19,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.platform import test
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.training.tracking import tracking

View File

@ -24,6 +24,7 @@ import os
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@ -40,7 +41,6 @@ from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_serialization
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import signature_serialization
@ -542,7 +542,7 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
return asset_info, exported_graph
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
def _serialize_object_graph(saveable_view, asset_file_def_index):
"""Save a SavedObjectGraph proto for `root`."""
# SavedObjectGraph is similar to the TrackableObjectGraph proto in the
# checkpoint. It will eventually go into the SavedModel.
@ -559,14 +559,7 @@ def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
_write_object_proto(obj, obj_proto, asset_file_def_index)
extra_asset_dir = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
file_io.recursive_create_dir(extra_asset_dir)
object_graph_filename = os.path.join(
extra_asset_dir, compat.as_bytes("object_graph.pb"))
file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
return proto
def _write_object_proto(obj, proto, asset_file_def_index):
@ -814,8 +807,10 @@ def save(obj, export_dir, signatures=None):
path = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
object_graph_proto = _serialize_object_graph(
saveable_view, asset_info.asset_index)
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
file_io.write_string_to_file(path, saved_model.SerializeToString())
_write_object_graph(saveable_view, export_dir, asset_info.asset_index)
# Clean reference cycles so repeated export()s don't make work for the garbage
# collector. Before this point we need to keep references to captured
# constants in the saved graph.

View File

@ -44,6 +44,13 @@ tf_proto {
type: TYPE_MESSAGE
type_name: ".tensorflow.AssetFileDef"
}
field {
name: "object_graph_def"
number: 7
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.SavedObjectGraph"
}
nested_type {
name: "MetaInfoDef"
field {