Add the SavedObjectGraph used by tf.saved_model.save/load to SavedModel

We reference a bunch of things in the MetaGraph/GraphDef, so it makes sense to add it there rather than to the SavedModel directly.

This is in preparation for non-experimental tf.saved_model.save/load symbols. We don't yet have an exposed symbol for loading object-based SavedModels, so this CL won't break anyone (despite moving around the proto and not checking the old location).

RFC: https://github.com/tensorflow/community/pull/34
PiperOrigin-RevId: 234887195
This commit is contained in:
Allen Lavoie 2019-02-20 15:58:42 -08:00 committed by TensorFlower Gardener
parent 69ab50a9b6
commit 39be9adca1
16 changed files with 78 additions and 74 deletions

View File

@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.cc
tensorflow/core/framework/versions.pb.cc tensorflow/core/framework/versions.pb.cc
tensorflow/core/grappler/costs/op_performance_data.pb.cc tensorflow/core/grappler/costs/op_performance_data.pb.cc
tensorflow/core/lib/core/error_codes.pb.cc tensorflow/core/lib/core/error_codes.pb.cc
tensorflow/core/protobuf/trackable_object_graph.pb.cc
tensorflow/core/protobuf/cluster.pb.cc tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/config.pb.cc
tensorflow/core/protobuf/eager_service.pb.cc tensorflow/core/protobuf/eager_service.pb.cc
@ -34,7 +35,9 @@ tensorflow/core/protobuf/meta_graph.pb.cc
tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc
tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc
tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/saved_object_graph.pb.cc
tensorflow/core/protobuf/saver.pb.cc tensorflow/core/protobuf/saver.pb.cc
tensorflow/core/protobuf/struct.pb.cc
tensorflow/core/protobuf/tensorflow_server.pb.cc tensorflow/core/protobuf/tensorflow_server.pb.cc
tensorflow/core/protobuf/verifier_config.pb.cc tensorflow/core/protobuf/verifier_config.pb.cc
tensorflow/core/util/event.pb.cc tensorflow/core/util/event.pb.cc

View File

@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.h
tensorflow/core/framework/versions.pb.h tensorflow/core/framework/versions.pb.h
tensorflow/core/grappler/costs/op_performance_data.pb.h tensorflow/core/grappler/costs/op_performance_data.pb.h
tensorflow/core/lib/core/error_codes.pb.h tensorflow/core/lib/core/error_codes.pb.h
tensorflow/core/protobuf/trackable_object_graph.pb.h
tensorflow/core/protobuf/cluster.pb.h tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h tensorflow/core/protobuf/debug.pb.h
@ -34,7 +35,9 @@ tensorflow/core/protobuf/meta_graph.pb.h
tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h tensorflow/core/protobuf/rewriter_config.pb.h
tensorflow/core/protobuf/saved_object_graph.pb.h
tensorflow/core/protobuf/saver.pb.h tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/struct.pb.h
tensorflow/core/protobuf/tensor_bundle.pb.h tensorflow/core/protobuf/tensor_bundle.pb.h
tensorflow/core/protobuf/tensorflow_server.pb.h tensorflow/core/protobuf/tensorflow_server.pb.h
tensorflow/core/protobuf/verifier_config.pb.h tensorflow/core/protobuf/verifier_config.pb.h

View File

@ -31,6 +31,7 @@ tensorflow/core/framework/versions.proto
tensorflow/core/grappler/costs/op_performance_data.proto tensorflow/core/grappler/costs/op_performance_data.proto
tensorflow/core/kernels/boosted_trees/boosted_trees.proto tensorflow/core/kernels/boosted_trees/boosted_trees.proto
tensorflow/core/lib/core/error_codes.proto tensorflow/core/lib/core/error_codes.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/cluster.proto tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto tensorflow/core/protobuf/debug.proto
@ -40,7 +41,9 @@ tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/rewriter_config.proto tensorflow/core/protobuf/rewriter_config.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/tensor_bundle.proto tensorflow/core/protobuf/tensor_bundle.proto
tensorflow/core/protobuf/tensorflow_server.proto tensorflow/core/protobuf/tensorflow_server.proto
tensorflow/core/protobuf/verifier_config.proto tensorflow/core/protobuf/verifier_config.proto

View File

@ -236,6 +236,8 @@ ADDITIONAL_CORE_PROTO_SRCS = [
"protobuf/meta_graph.proto", "protobuf/meta_graph.proto",
"protobuf/named_tensor.proto", "protobuf/named_tensor.proto",
"protobuf/saved_model.proto", "protobuf/saved_model.proto",
"protobuf/saved_object_graph.proto",
"protobuf/struct.proto",
"protobuf/tensorflow_server.proto", "protobuf/tensorflow_server.proto",
"protobuf/transport_options.proto", "protobuf/transport_options.proto",
"util/test_log.proto", "util/test_log.proto",

View File

@ -12,6 +12,7 @@ import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/op_def.proto"; import "tensorflow/core/framework/op_def.proto";
import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto"; import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saved_object_graph.proto";
import "tensorflow/core/protobuf/saver.proto"; import "tensorflow/core/protobuf/saver.proto";
// NOTE: This protocol buffer is evolving, and will go through revisions in the // NOTE: This protocol buffer is evolving, and will go through revisions in the
@ -84,6 +85,9 @@ message MetaGraphDef {
// Asset file def to be used with the defined graph. // Asset file def to be used with the defined graph.
repeated AssetFileDef asset_file_def = 6; repeated AssetFileDef asset_file_def = 6;
// Extra information about the structure of functions and stateful objects.
SavedObjectGraph object_graph_def = 7;
} }
// CollectionDef should cover most collections. // CollectionDef should cover most collections.

View File

@ -1,10 +1,10 @@
syntax = "proto3"; syntax = "proto3";
import "tensorflow/core/protobuf/trackable_object_graph.proto"; import "tensorflow/core/protobuf/trackable_object_graph.proto";
import "tensorflow/core/protobuf/struct.proto";
import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto"; import "tensorflow/core/framework/types.proto";
import "tensorflow/core/framework/versions.proto"; import "tensorflow/core/framework/versions.proto";
import "tensorflow/python/saved_model/struct.proto";
option cc_enable_arenas = true; option cc_enable_arenas = true;
@ -15,14 +15,12 @@ package tensorflow;
// languages) that make up a model, with nodes[0] at the root. // languages) that make up a model, with nodes[0] at the root.
// SavedObjectGraph shares some structure with TrackableObjectGraph, but // SavedObjectGraph shares some structure with TrackableObjectGraph, but
// ObjectGraph belongs to the SavedModel and contains pointers to functions and // SavedObjectGraph belongs to the MetaGraph and contains pointers to functions
// type information, while TrackableObjectGraph lives in the checkpoint and // and type information, while TrackableObjectGraph lives in the checkpoint
// contains pointers only to variable values. // and contains pointers only to variable values.
// NOTE: This protocol buffer format is experimental and subject to change.
message SavedObjectGraph { message SavedObjectGraph {
// List of objects in the SavedModel. // Flattened list of objects in the object graph.
// //
// The position of the object in this list indicates its id. // The position of the object in this list indicates its id.
// Nodes[0] is considered the root node. // Nodes[0] is considered the root node.
@ -37,10 +35,11 @@ message SavedObject {
// Objects which this object depends on: named edges in the dependency // Objects which this object depends on: named edges in the dependency
// graph. // graph.
// //
// Note: only valid if kind == "object". // Note: currently only valid if kind == "user_object".
repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1; repeated TrackableObjectGraph.TrackableObject.ObjectReference
children = 1;
// Removed when forking from TrackableObjectGraph. // Removed when forking SavedObject from TrackableObjectGraph.
reserved "attributes"; reserved "attributes";
reserved 2; reserved 2;
@ -48,7 +47,7 @@ message SavedObject {
// (optimizer, variable, slot variable) relationship; none of the three // (optimizer, variable, slot variable) relationship; none of the three
// depend on the others directly. // depend on the others directly.
// //
// Note: only valid if kind == "object". // Note: currently only valid if kind == "user_object".
repeated TrackableObjectGraph.TrackableObject.SlotVariableReference repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
slot_variables = 3; slot_variables = 3;
@ -76,7 +75,7 @@ message SavedUserObject {
VersionDef version = 2; VersionDef version = 2;
} }
// A SavedAsset represents a file in a SavedModel. // A SavedAsset points to an asset in the MetaGraph.
// //
// When bound to a function this object evaluates to a tensor with the absolute // When bound to a function this object evaluates to a tensor with the absolute
// filename. Users should not depend on a particular part of the filename to // filename. Users should not depend on a particular part of the filename to
@ -128,13 +127,11 @@ message SavedConstant {
} }
// Represents a Variable that is initialized by loading the contents from the // Represents a Variable that is initialized by loading the contents from the
// SavedModel checkpoint. // checkpoint.
message SavedVariable { message SavedVariable {
DataType dtype = 1; DataType dtype = 1;
TensorShapeProto shape = 2; TensorShapeProto shape = 2;
bool trainable = 3; bool trainable = 3;
// TODO(andresp): Add save_slice_info_def?
} }
// Represents `FunctionSpec` used in `Function`. This represents a // Represents `FunctionSpec` used in `Function`. This represents a

View File

@ -8,6 +8,27 @@ package tensorflow;
// `StructuredValue` represents a dynamically typed value representing various // `StructuredValue` represents a dynamically typed value representing various
// data structures that are inspired by Python data structures typically used in // data structures that are inspired by Python data structures typically used in
// TensorFlow functions as inputs and outputs. // TensorFlow functions as inputs and outputs.
//
// For example when saving a Layer there may be a `training` argument. If the
// user passes a boolean True/False, that switches between two concrete
// TensorFlow functions. In order to switch between them in the same way after
// loading the SavedModel, we need to represent "True" and "False".
//
// A more advanced example might be a function which takes a list of
// dictionaries mapping from strings to Tensors. In order to map from
// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]`
// after load to the right saved TensorFlow function, we need to represent the
// nested structure and the strings, recording that we have a trace for anything
// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([],
// tf.float64)}]` as an example.
//
// Likewise functions may return nested structures of Tensors, for example
// returning a dictionary mapping from strings to Tensors. In order for the
// loaded function to return the same structure we need to serialize it.
//
// This is an ergonomic aid for working with loaded SavedModels, not a promise
// to serialize all possible function signatures. For example we do not expect
// to pickle generic Python objects, and ideally we'd stay language-agnostic.
message StructuredValue { message StructuredValue {
// The kind of value. // The kind of value.
oneof kind { oneof kind {
@ -29,11 +50,11 @@ message StructuredValue {
// Represents a boolean value. // Represents a boolean value.
bool bool_value = 14; bool bool_value = 14;
// Represents a tf.TensorShape. // Represents a TensorShape.
tensorflow.TensorShapeProto tensor_shape_value = 31; tensorflow.TensorShapeProto tensor_shape_value = 31;
// Represents an enum value for tf.DType. // Represents an enum value for dtype.
tensorflow.DataType tensor_dtype_value = 32; tensorflow.DataType tensor_dtype_value = 32;
// Represents a value for tf.TensorShape. // Represents a value for tf.TensorSpec.
TensorSpecProto tensor_spec_value = 33; TensorSpecProto tensor_spec_value = 33;
// Represents a list of `Value`. // Represents a list of `Value`.

View File

@ -291,7 +291,6 @@ py_library(
":function_serialization", ":function_serialization",
":nested_structure_coder", ":nested_structure_coder",
":revived_types", ":revived_types",
":saved_object_graph_py",
":signature_constants", ":signature_constants",
":signature_def_utils", ":signature_def_utils",
":signature_serialization", ":signature_serialization",
@ -346,8 +345,8 @@ py_library(
":loader", ":loader",
":nested_structure_coder", ":nested_structure_coder",
":revived_types", ":revived_types",
":saved_object_graph_py",
":utils", ":utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:constant_op", "//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops", "//tensorflow/python:init_ops",
@ -429,7 +428,7 @@ py_library(
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":saved_object_graph_py", "//tensorflow/core:protos_all_py",
], ],
) )
@ -438,7 +437,7 @@ tf_py_test(
srcs = ["revived_types_test.py"], srcs = ["revived_types_test.py"],
additional_deps = [ additional_deps = [
":revived_types", ":revived_types",
":saved_object_graph_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
], ],
) )
@ -451,7 +450,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":nested_structure_coder", ":nested_structure_coder",
":saved_object_graph_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function", "//tensorflow/python/eager:function",
], ],
@ -469,27 +468,11 @@ py_library(
], ],
) )
tf_proto_library(
name = "struct",
srcs = ["struct.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
visibility = ["//tensorflow:internal"],
)
tf_proto_library(
name = "saved_object_graph",
srcs = ["saved_object_graph.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos() + [":struct"],
visibility = ["//tensorflow:internal"],
)
py_library( py_library(
name = "nested_structure_coder", name = "nested_structure_coder",
srcs = ["nested_structure_coder.py"], srcs = ["nested_structure_coder.py"],
deps = [ deps = [
":struct_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"@six_archive//:six", "@six_archive//:six",
], ],
@ -500,7 +483,7 @@ tf_py_test(
srcs = ["nested_structure_coder_test.py"], srcs = ["nested_structure_coder_test.py"],
additional_deps = [ additional_deps = [
":nested_structure_coder", ":nested_structure_coder",
":struct_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",
], ],

View File

@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import saved_object_graph_pb2
def _serialize_function_spec(function_spec, coder): def _serialize_function_spec(function_spec, coder):

View File

@ -24,23 +24,19 @@ import os
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_deserialization from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_v1_in_v2 from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
from tensorflow.python.util import nest from tensorflow.python.util import nest
@ -265,12 +261,6 @@ def _call_attribute(instance, *args, **kwargs):
return instance.__call__(*args, **kwargs) return instance.__call__(*args, **kwargs)
def _load_saved_object_graph_proto(filename):
with file_io.FileIO(filename, "rb") as f:
contents = f.read()
return saved_object_graph_pb2.SavedObjectGraph.FromString(contents)
def load(export_dir, tags=None): def load(export_dir, tags=None):
"""Load a SavedModel from `export_dir`. """Load a SavedModel from `export_dir`.
@ -315,12 +305,8 @@ def load(export_dir, tags=None):
# Supports e.g. tags=SERVING and tags=[SERVING] # Supports e.g. tags=SERVING and tags=[SERVING]
tags = nest.flatten(tags) tags = nest.flatten(tags)
saved_model_proto = loader_impl.parse_saved_model(export_dir) saved_model_proto = loader_impl.parse_saved_model(export_dir)
object_graph_filename = os.path.join( if (len(saved_model_proto.meta_graphs) == 1
compat.as_bytes(export_dir), and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
compat.as_bytes("object_graph.pb"))
if (file_io.file_exists(object_graph_filename)
and len(saved_model_proto.meta_graphs) == 1):
meta_graph_def = saved_model_proto.meta_graphs[0] meta_graph_def = saved_model_proto.meta_graphs[0]
if (tags is not None if (tags is not None
and set(tags) != set(meta_graph_def.meta_info_def.tags)): and set(tags) != set(meta_graph_def.meta_info_def.tags)):
@ -329,7 +315,7 @@ def load(export_dir, tags=None):
"incompatible argument tags={} to tf.saved_model.load. You may omit " "incompatible argument tags={} to tf.saved_model.load. You may omit "
"it, pass 'None', or pass matching tags.") "it, pass 'None', or pass matching tags.")
.format(export_dir, meta_graph_def.meta_info_def.tags, tags)) .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
object_graph_proto = _load_saved_object_graph_proto(object_graph_filename) object_graph_proto = meta_graph_def.object_graph_def
with ops.init_scope(): with ops.init_scope():
loader = _Loader(object_graph_proto, loader = _Loader(object_graph_proto,
saved_model_proto, saved_model_proto,

View File

@ -34,10 +34,10 @@ import collections
import functools import functools
import six import six
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.saved_model import struct_pb2
from tensorflow.python.util import compat from tensorflow.python.util import compat

View File

@ -20,12 +20,12 @@ from __future__ import print_function
import collections import collections
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import struct_pb2
class NestedStructureTest(test.TestCase): class NestedStructureTest(test.TestCase):

View File

@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.core.framework import versions_pb2 from tensorflow.core.framework import versions_pb2
from tensorflow.python.saved_model import saved_object_graph_pb2 from tensorflow.core.protobuf import saved_object_graph_pb2
class VersionedTypeRegistration(object): class VersionedTypeRegistration(object):

View File

@ -19,9 +19,9 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.core.framework import versions_pb2 from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import revived_types from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking

View File

@ -24,6 +24,7 @@ import os
from tensorflow.core.framework import versions_pb2 from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun from tensorflow.python.eager import function as defun
@ -40,7 +41,6 @@ from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_serialization from tensorflow.python.saved_model import function_serialization
from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import signature_serialization from tensorflow.python.saved_model import signature_serialization
@ -542,7 +542,7 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
return asset_info, exported_graph return asset_info, exported_graph
def _write_object_graph(saveable_view, export_dir, asset_file_def_index): def _serialize_object_graph(saveable_view, asset_file_def_index):
"""Save a SavedObjectGraph proto for `root`.""" """Save a SavedObjectGraph proto for `root`."""
# SavedObjectGraph is similar to the TrackableObjectGraph proto in the # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
# checkpoint. It will eventually go into the SavedModel. # checkpoint. It will eventually go into the SavedModel.
@ -559,14 +559,7 @@ def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes): for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
_write_object_proto(obj, obj_proto, asset_file_def_index) _write_object_proto(obj, obj_proto, asset_file_def_index)
return proto
extra_asset_dir = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
file_io.recursive_create_dir(extra_asset_dir)
object_graph_filename = os.path.join(
extra_asset_dir, compat.as_bytes("object_graph.pb"))
file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
def _write_object_proto(obj, proto, asset_file_def_index): def _write_object_proto(obj, proto, asset_file_def_index):
@ -814,8 +807,10 @@ def save(obj, export_dir, signatures=None):
path = os.path.join( path = os.path.join(
compat.as_bytes(export_dir), compat.as_bytes(export_dir),
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
object_graph_proto = _serialize_object_graph(
saveable_view, asset_info.asset_index)
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
file_io.write_string_to_file(path, saved_model.SerializeToString()) file_io.write_string_to_file(path, saved_model.SerializeToString())
_write_object_graph(saveable_view, export_dir, asset_info.asset_index)
# Clean reference cycles so repeated export()s don't make work for the garbage # Clean reference cycles so repeated export()s don't make work for the garbage
# collector. Before this point we need to keep references to captured # collector. Before this point we need to keep references to captured
# constants in the saved graph. # constants in the saved graph.

View File

@ -44,6 +44,13 @@ tf_proto {
type: TYPE_MESSAGE type: TYPE_MESSAGE
type_name: ".tensorflow.AssetFileDef" type_name: ".tensorflow.AssetFileDef"
} }
field {
name: "object_graph_def"
number: 7
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.SavedObjectGraph"
}
nested_type { nested_type {
name: "MetaInfoDef" name: "MetaInfoDef"
field { field {