Add the SavedObjectGraph used by tf.saved_model.save/load to SavedModel
We reference a bunch of things in the MetaGraph/GraphDef, so it makes sense to add it there rather than to the SavedModel directly. This is in preparation for non-experimental tf.saved_model.save/load symbols. We don't yet have an exposed symbol for loading object-based SavedModels, so this CL won't break anyone (despite moving around the proto and not checking the old location). RFC: https://github.com/tensorflow/community/pull/34 PiperOrigin-RevId: 234887195
This commit is contained in:
parent
69ab50a9b6
commit
39be9adca1
@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.cc
|
|||||||
tensorflow/core/framework/versions.pb.cc
|
tensorflow/core/framework/versions.pb.cc
|
||||||
tensorflow/core/grappler/costs/op_performance_data.pb.cc
|
tensorflow/core/grappler/costs/op_performance_data.pb.cc
|
||||||
tensorflow/core/lib/core/error_codes.pb.cc
|
tensorflow/core/lib/core/error_codes.pb.cc
|
||||||
|
tensorflow/core/protobuf/trackable_object_graph.pb.cc
|
||||||
tensorflow/core/protobuf/cluster.pb.cc
|
tensorflow/core/protobuf/cluster.pb.cc
|
||||||
tensorflow/core/protobuf/config.pb.cc
|
tensorflow/core/protobuf/config.pb.cc
|
||||||
tensorflow/core/protobuf/eager_service.pb.cc
|
tensorflow/core/protobuf/eager_service.pb.cc
|
||||||
@ -34,7 +35,9 @@ tensorflow/core/protobuf/meta_graph.pb.cc
|
|||||||
tensorflow/core/protobuf/named_tensor.pb.cc
|
tensorflow/core/protobuf/named_tensor.pb.cc
|
||||||
tensorflow/core/protobuf/queue_runner.pb.cc
|
tensorflow/core/protobuf/queue_runner.pb.cc
|
||||||
tensorflow/core/protobuf/rewriter_config.pb.cc
|
tensorflow/core/protobuf/rewriter_config.pb.cc
|
||||||
|
tensorflow/core/protobuf/saved_object_graph.pb.cc
|
||||||
tensorflow/core/protobuf/saver.pb.cc
|
tensorflow/core/protobuf/saver.pb.cc
|
||||||
|
tensorflow/core/protobuf/struct.pb.cc
|
||||||
tensorflow/core/protobuf/tensorflow_server.pb.cc
|
tensorflow/core/protobuf/tensorflow_server.pb.cc
|
||||||
tensorflow/core/protobuf/verifier_config.pb.cc
|
tensorflow/core/protobuf/verifier_config.pb.cc
|
||||||
tensorflow/core/util/event.pb.cc
|
tensorflow/core/util/event.pb.cc
|
||||||
|
@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.h
|
|||||||
tensorflow/core/framework/versions.pb.h
|
tensorflow/core/framework/versions.pb.h
|
||||||
tensorflow/core/grappler/costs/op_performance_data.pb.h
|
tensorflow/core/grappler/costs/op_performance_data.pb.h
|
||||||
tensorflow/core/lib/core/error_codes.pb.h
|
tensorflow/core/lib/core/error_codes.pb.h
|
||||||
|
tensorflow/core/protobuf/trackable_object_graph.pb.h
|
||||||
tensorflow/core/protobuf/cluster.pb.h
|
tensorflow/core/protobuf/cluster.pb.h
|
||||||
tensorflow/core/protobuf/config.pb.h
|
tensorflow/core/protobuf/config.pb.h
|
||||||
tensorflow/core/protobuf/debug.pb.h
|
tensorflow/core/protobuf/debug.pb.h
|
||||||
@ -34,7 +35,9 @@ tensorflow/core/protobuf/meta_graph.pb.h
|
|||||||
tensorflow/core/protobuf/named_tensor.pb.h
|
tensorflow/core/protobuf/named_tensor.pb.h
|
||||||
tensorflow/core/protobuf/queue_runner.pb.h
|
tensorflow/core/protobuf/queue_runner.pb.h
|
||||||
tensorflow/core/protobuf/rewriter_config.pb.h
|
tensorflow/core/protobuf/rewriter_config.pb.h
|
||||||
|
tensorflow/core/protobuf/saved_object_graph.pb.h
|
||||||
tensorflow/core/protobuf/saver.pb.h
|
tensorflow/core/protobuf/saver.pb.h
|
||||||
|
tensorflow/core/protobuf/struct.pb.h
|
||||||
tensorflow/core/protobuf/tensor_bundle.pb.h
|
tensorflow/core/protobuf/tensor_bundle.pb.h
|
||||||
tensorflow/core/protobuf/tensorflow_server.pb.h
|
tensorflow/core/protobuf/tensorflow_server.pb.h
|
||||||
tensorflow/core/protobuf/verifier_config.pb.h
|
tensorflow/core/protobuf/verifier_config.pb.h
|
||||||
|
@ -31,6 +31,7 @@ tensorflow/core/framework/versions.proto
|
|||||||
tensorflow/core/grappler/costs/op_performance_data.proto
|
tensorflow/core/grappler/costs/op_performance_data.proto
|
||||||
tensorflow/core/kernels/boosted_trees/boosted_trees.proto
|
tensorflow/core/kernels/boosted_trees/boosted_trees.proto
|
||||||
tensorflow/core/lib/core/error_codes.proto
|
tensorflow/core/lib/core/error_codes.proto
|
||||||
|
tensorflow/core/protobuf/trackable_object_graph.proto
|
||||||
tensorflow/core/protobuf/cluster.proto
|
tensorflow/core/protobuf/cluster.proto
|
||||||
tensorflow/core/protobuf/config.proto
|
tensorflow/core/protobuf/config.proto
|
||||||
tensorflow/core/protobuf/debug.proto
|
tensorflow/core/protobuf/debug.proto
|
||||||
@ -40,7 +41,9 @@ tensorflow/core/protobuf/meta_graph.proto
|
|||||||
tensorflow/core/protobuf/named_tensor.proto
|
tensorflow/core/protobuf/named_tensor.proto
|
||||||
tensorflow/core/protobuf/queue_runner.proto
|
tensorflow/core/protobuf/queue_runner.proto
|
||||||
tensorflow/core/protobuf/rewriter_config.proto
|
tensorflow/core/protobuf/rewriter_config.proto
|
||||||
|
tensorflow/core/protobuf/saved_object_graph.proto
|
||||||
tensorflow/core/protobuf/saver.proto
|
tensorflow/core/protobuf/saver.proto
|
||||||
|
tensorflow/core/protobuf/struct.proto
|
||||||
tensorflow/core/protobuf/tensor_bundle.proto
|
tensorflow/core/protobuf/tensor_bundle.proto
|
||||||
tensorflow/core/protobuf/tensorflow_server.proto
|
tensorflow/core/protobuf/tensorflow_server.proto
|
||||||
tensorflow/core/protobuf/verifier_config.proto
|
tensorflow/core/protobuf/verifier_config.proto
|
||||||
|
@ -236,6 +236,8 @@ ADDITIONAL_CORE_PROTO_SRCS = [
|
|||||||
"protobuf/meta_graph.proto",
|
"protobuf/meta_graph.proto",
|
||||||
"protobuf/named_tensor.proto",
|
"protobuf/named_tensor.proto",
|
||||||
"protobuf/saved_model.proto",
|
"protobuf/saved_model.proto",
|
||||||
|
"protobuf/saved_object_graph.proto",
|
||||||
|
"protobuf/struct.proto",
|
||||||
"protobuf/tensorflow_server.proto",
|
"protobuf/tensorflow_server.proto",
|
||||||
"protobuf/transport_options.proto",
|
"protobuf/transport_options.proto",
|
||||||
"util/test_log.proto",
|
"util/test_log.proto",
|
||||||
|
@ -12,6 +12,7 @@ import "tensorflow/core/framework/graph.proto";
|
|||||||
import "tensorflow/core/framework/op_def.proto";
|
import "tensorflow/core/framework/op_def.proto";
|
||||||
import "tensorflow/core/framework/tensor_shape.proto";
|
import "tensorflow/core/framework/tensor_shape.proto";
|
||||||
import "tensorflow/core/framework/types.proto";
|
import "tensorflow/core/framework/types.proto";
|
||||||
|
import "tensorflow/core/protobuf/saved_object_graph.proto";
|
||||||
import "tensorflow/core/protobuf/saver.proto";
|
import "tensorflow/core/protobuf/saver.proto";
|
||||||
|
|
||||||
// NOTE: This protocol buffer is evolving, and will go through revisions in the
|
// NOTE: This protocol buffer is evolving, and will go through revisions in the
|
||||||
@ -84,6 +85,9 @@ message MetaGraphDef {
|
|||||||
|
|
||||||
// Asset file def to be used with the defined graph.
|
// Asset file def to be used with the defined graph.
|
||||||
repeated AssetFileDef asset_file_def = 6;
|
repeated AssetFileDef asset_file_def = 6;
|
||||||
|
|
||||||
|
// Extra information about the structure of functions and stateful objects.
|
||||||
|
SavedObjectGraph object_graph_def = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
// CollectionDef should cover most collections.
|
// CollectionDef should cover most collections.
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
import "tensorflow/core/protobuf/trackable_object_graph.proto";
|
import "tensorflow/core/protobuf/trackable_object_graph.proto";
|
||||||
|
import "tensorflow/core/protobuf/struct.proto";
|
||||||
import "tensorflow/core/framework/tensor_shape.proto";
|
import "tensorflow/core/framework/tensor_shape.proto";
|
||||||
import "tensorflow/core/framework/types.proto";
|
import "tensorflow/core/framework/types.proto";
|
||||||
import "tensorflow/core/framework/versions.proto";
|
import "tensorflow/core/framework/versions.proto";
|
||||||
import "tensorflow/python/saved_model/struct.proto";
|
|
||||||
|
|
||||||
option cc_enable_arenas = true;
|
option cc_enable_arenas = true;
|
||||||
|
|
||||||
@ -15,14 +15,12 @@ package tensorflow;
|
|||||||
// languages) that make up a model, with nodes[0] at the root.
|
// languages) that make up a model, with nodes[0] at the root.
|
||||||
|
|
||||||
// SavedObjectGraph shares some structure with TrackableObjectGraph, but
|
// SavedObjectGraph shares some structure with TrackableObjectGraph, but
|
||||||
// ObjectGraph belongs to the SavedModel and contains pointers to functions and
|
// SavedObjectGraph belongs to the MetaGraph and contains pointers to functions
|
||||||
// type information, while TrackableObjectGraph lives in the checkpoint and
|
// and type information, while TrackableObjectGraph lives in the checkpoint
|
||||||
// contains pointers only to variable values.
|
// and contains pointers only to variable values.
|
||||||
|
|
||||||
// NOTE: This protocol buffer format is experimental and subject to change.
|
|
||||||
|
|
||||||
message SavedObjectGraph {
|
message SavedObjectGraph {
|
||||||
// List of objects in the SavedModel.
|
// Flattened list of objects in the object graph.
|
||||||
//
|
//
|
||||||
// The position of the object in this list indicates its id.
|
// The position of the object in this list indicates its id.
|
||||||
// Nodes[0] is considered the root node.
|
// Nodes[0] is considered the root node.
|
||||||
@ -37,10 +35,11 @@ message SavedObject {
|
|||||||
// Objects which this object depends on: named edges in the dependency
|
// Objects which this object depends on: named edges in the dependency
|
||||||
// graph.
|
// graph.
|
||||||
//
|
//
|
||||||
// Note: only valid if kind == "object".
|
// Note: currently only valid if kind == "user_object".
|
||||||
repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;
|
repeated TrackableObjectGraph.TrackableObject.ObjectReference
|
||||||
|
children = 1;
|
||||||
|
|
||||||
// Removed when forking from TrackableObjectGraph.
|
// Removed when forking SavedObject from TrackableObjectGraph.
|
||||||
reserved "attributes";
|
reserved "attributes";
|
||||||
reserved 2;
|
reserved 2;
|
||||||
|
|
||||||
@ -48,7 +47,7 @@ message SavedObject {
|
|||||||
// (optimizer, variable, slot variable) relationship; none of the three
|
// (optimizer, variable, slot variable) relationship; none of the three
|
||||||
// depend on the others directly.
|
// depend on the others directly.
|
||||||
//
|
//
|
||||||
// Note: only valid if kind == "object".
|
// Note: currently only valid if kind == "user_object".
|
||||||
repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
|
repeated TrackableObjectGraph.TrackableObject.SlotVariableReference
|
||||||
slot_variables = 3;
|
slot_variables = 3;
|
||||||
|
|
||||||
@ -76,7 +75,7 @@ message SavedUserObject {
|
|||||||
VersionDef version = 2;
|
VersionDef version = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// A SavedAsset represents a file in a SavedModel.
|
// A SavedAsset points to an asset in the MetaGraph.
|
||||||
//
|
//
|
||||||
// When bound to a function this object evaluates to a tensor with the absolute
|
// When bound to a function this object evaluates to a tensor with the absolute
|
||||||
// filename. Users should not depend on a particular part of the filename to
|
// filename. Users should not depend on a particular part of the filename to
|
||||||
@ -128,13 +127,11 @@ message SavedConstant {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Represents a Variable that is initialized by loading the contents from the
|
// Represents a Variable that is initialized by loading the contents from the
|
||||||
// SavedModel checkpoint.
|
// checkpoint.
|
||||||
message SavedVariable {
|
message SavedVariable {
|
||||||
DataType dtype = 1;
|
DataType dtype = 1;
|
||||||
TensorShapeProto shape = 2;
|
TensorShapeProto shape = 2;
|
||||||
bool trainable = 3;
|
bool trainable = 3;
|
||||||
|
|
||||||
// TODO(andresp): Add save_slice_info_def?
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Represents `FunctionSpec` used in `Function`. This represents a
|
// Represents `FunctionSpec` used in `Function`. This represents a
|
@ -8,6 +8,27 @@ package tensorflow;
|
|||||||
// `StructuredValue` represents a dynamically typed value representing various
|
// `StructuredValue` represents a dynamically typed value representing various
|
||||||
// data structures that are inspired by Python data structures typically used in
|
// data structures that are inspired by Python data structures typically used in
|
||||||
// TensorFlow functions as inputs and outputs.
|
// TensorFlow functions as inputs and outputs.
|
||||||
|
//
|
||||||
|
// For example when saving a Layer there may be a `training` argument. If the
|
||||||
|
// user passes a boolean True/False, that switches between two concrete
|
||||||
|
// TensorFlow functions. In order to switch between them in the same way after
|
||||||
|
// loading the SavedModel, we need to represent "True" and "False".
|
||||||
|
//
|
||||||
|
// A more advanced example might be a function which takes a list of
|
||||||
|
// dictionaries mapping from strings to Tensors. In order to map from
|
||||||
|
// user-specified arguments `[{"a": tf.constant(1.)}, {"q": tf.constant(3.)}]`
|
||||||
|
// after load to the right saved TensorFlow function, we need to represent the
|
||||||
|
// nested structure and the strings, recording that we have a trace for anything
|
||||||
|
// matching `[{"a": tf.TensorSpec(None, tf.float32)}, {"q": tf.TensorSpec([],
|
||||||
|
// tf.float64)}]` as an example.
|
||||||
|
//
|
||||||
|
// Likewise functions may return nested structures of Tensors, for example
|
||||||
|
// returning a dictionary mapping from strings to Tensors. In order for the
|
||||||
|
// loaded function to return the same structure we need to serialize it.
|
||||||
|
//
|
||||||
|
// This is an ergonomic aid for working with loaded SavedModels, not a promise
|
||||||
|
// to serialize all possible function signatures. For example we do not expect
|
||||||
|
// to pickle generic Python objects, and ideally we'd stay language-agnostic.
|
||||||
message StructuredValue {
|
message StructuredValue {
|
||||||
// The kind of value.
|
// The kind of value.
|
||||||
oneof kind {
|
oneof kind {
|
||||||
@ -29,11 +50,11 @@ message StructuredValue {
|
|||||||
// Represents a boolean value.
|
// Represents a boolean value.
|
||||||
bool bool_value = 14;
|
bool bool_value = 14;
|
||||||
|
|
||||||
// Represents a tf.TensorShape.
|
// Represents a TensorShape.
|
||||||
tensorflow.TensorShapeProto tensor_shape_value = 31;
|
tensorflow.TensorShapeProto tensor_shape_value = 31;
|
||||||
// Represents an enum value for tf.DType.
|
// Represents an enum value for dtype.
|
||||||
tensorflow.DataType tensor_dtype_value = 32;
|
tensorflow.DataType tensor_dtype_value = 32;
|
||||||
// Represents a value for tf.TensorShape.
|
// Represents a value for tf.TensorSpec.
|
||||||
TensorSpecProto tensor_spec_value = 33;
|
TensorSpecProto tensor_spec_value = 33;
|
||||||
|
|
||||||
// Represents a list of `Value`.
|
// Represents a list of `Value`.
|
@ -291,7 +291,6 @@ py_library(
|
|||||||
":function_serialization",
|
":function_serialization",
|
||||||
":nested_structure_coder",
|
":nested_structure_coder",
|
||||||
":revived_types",
|
":revived_types",
|
||||||
":saved_object_graph_py",
|
|
||||||
":signature_constants",
|
":signature_constants",
|
||||||
":signature_def_utils",
|
":signature_def_utils",
|
||||||
":signature_serialization",
|
":signature_serialization",
|
||||||
@ -346,8 +345,8 @@ py_library(
|
|||||||
":loader",
|
":loader",
|
||||||
":nested_structure_coder",
|
":nested_structure_coder",
|
||||||
":revived_types",
|
":revived_types",
|
||||||
":saved_object_graph_py",
|
|
||||||
":utils",
|
":utils",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:constant_op",
|
"//tensorflow/python:constant_op",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:init_ops",
|
"//tensorflow/python:init_ops",
|
||||||
@ -429,7 +428,7 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":saved_object_graph_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -438,7 +437,7 @@ tf_py_test(
|
|||||||
srcs = ["revived_types_test.py"],
|
srcs = ["revived_types_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":revived_types",
|
":revived_types",
|
||||||
":saved_object_graph_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -451,7 +450,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":nested_structure_coder",
|
":nested_structure_coder",
|
||||||
":saved_object_graph_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python/eager:def_function",
|
"//tensorflow/python/eager:def_function",
|
||||||
"//tensorflow/python/eager:function",
|
"//tensorflow/python/eager:function",
|
||||||
],
|
],
|
||||||
@ -469,27 +468,11 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_proto_library(
|
|
||||||
name = "struct",
|
|
||||||
srcs = ["struct.proto"],
|
|
||||||
cc_api_version = 2,
|
|
||||||
protodeps = tf_additional_all_protos(),
|
|
||||||
visibility = ["//tensorflow:internal"],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_proto_library(
|
|
||||||
name = "saved_object_graph",
|
|
||||||
srcs = ["saved_object_graph.proto"],
|
|
||||||
cc_api_version = 2,
|
|
||||||
protodeps = tf_additional_all_protos() + [":struct"],
|
|
||||||
visibility = ["//tensorflow:internal"],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "nested_structure_coder",
|
name = "nested_structure_coder",
|
||||||
srcs = ["nested_structure_coder.py"],
|
srcs = ["nested_structure_coder.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":struct_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
@ -500,7 +483,7 @@ tf_py_test(
|
|||||||
srcs = ["nested_structure_coder_test.py"],
|
srcs = ["nested_structure_coder_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":nested_structure_coder",
|
":nested_structure_coder",
|
||||||
":struct_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
],
|
],
|
||||||
|
@ -18,9 +18,9 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.core.protobuf import saved_object_graph_pb2
|
||||||
from tensorflow.python.framework import func_graph as func_graph_module
|
from tensorflow.python.framework import func_graph as func_graph_module
|
||||||
from tensorflow.python.saved_model import nested_structure_coder
|
from tensorflow.python.saved_model import nested_structure_coder
|
||||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_function_spec(function_spec, coder):
|
def _serialize_function_spec(function_spec, coder):
|
||||||
|
@ -24,23 +24,19 @@ import os
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.lib.io import file_io
|
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.saved_model import constants
|
|
||||||
from tensorflow.python.saved_model import function_deserialization
|
from tensorflow.python.saved_model import function_deserialization
|
||||||
from tensorflow.python.saved_model import load_v1_in_v2
|
from tensorflow.python.saved_model import load_v1_in_v2
|
||||||
from tensorflow.python.saved_model import loader_impl
|
from tensorflow.python.saved_model import loader_impl
|
||||||
from tensorflow.python.saved_model import nested_structure_coder
|
from tensorflow.python.saved_model import nested_structure_coder
|
||||||
from tensorflow.python.saved_model import revived_types
|
from tensorflow.python.saved_model import revived_types
|
||||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
|
||||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||||
from tensorflow.python.training.tracking import base
|
from tensorflow.python.training.tracking import base
|
||||||
from tensorflow.python.training.tracking import graph_view
|
from tensorflow.python.training.tracking import graph_view
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.training.tracking import util
|
from tensorflow.python.training.tracking import util
|
||||||
from tensorflow.python.util import compat
|
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
@ -265,12 +261,6 @@ def _call_attribute(instance, *args, **kwargs):
|
|||||||
return instance.__call__(*args, **kwargs)
|
return instance.__call__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _load_saved_object_graph_proto(filename):
|
|
||||||
with file_io.FileIO(filename, "rb") as f:
|
|
||||||
contents = f.read()
|
|
||||||
return saved_object_graph_pb2.SavedObjectGraph.FromString(contents)
|
|
||||||
|
|
||||||
|
|
||||||
def load(export_dir, tags=None):
|
def load(export_dir, tags=None):
|
||||||
"""Load a SavedModel from `export_dir`.
|
"""Load a SavedModel from `export_dir`.
|
||||||
|
|
||||||
@ -315,12 +305,8 @@ def load(export_dir, tags=None):
|
|||||||
# Supports e.g. tags=SERVING and tags=[SERVING]
|
# Supports e.g. tags=SERVING and tags=[SERVING]
|
||||||
tags = nest.flatten(tags)
|
tags = nest.flatten(tags)
|
||||||
saved_model_proto = loader_impl.parse_saved_model(export_dir)
|
saved_model_proto = loader_impl.parse_saved_model(export_dir)
|
||||||
object_graph_filename = os.path.join(
|
if (len(saved_model_proto.meta_graphs) == 1
|
||||||
compat.as_bytes(export_dir),
|
and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
|
||||||
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
|
|
||||||
compat.as_bytes("object_graph.pb"))
|
|
||||||
if (file_io.file_exists(object_graph_filename)
|
|
||||||
and len(saved_model_proto.meta_graphs) == 1):
|
|
||||||
meta_graph_def = saved_model_proto.meta_graphs[0]
|
meta_graph_def = saved_model_proto.meta_graphs[0]
|
||||||
if (tags is not None
|
if (tags is not None
|
||||||
and set(tags) != set(meta_graph_def.meta_info_def.tags)):
|
and set(tags) != set(meta_graph_def.meta_info_def.tags)):
|
||||||
@ -329,7 +315,7 @@ def load(export_dir, tags=None):
|
|||||||
"incompatible argument tags={} to tf.saved_model.load. You may omit "
|
"incompatible argument tags={} to tf.saved_model.load. You may omit "
|
||||||
"it, pass 'None', or pass matching tags.")
|
"it, pass 'None', or pass matching tags.")
|
||||||
.format(export_dir, meta_graph_def.meta_info_def.tags, tags))
|
.format(export_dir, meta_graph_def.meta_info_def.tags, tags))
|
||||||
object_graph_proto = _load_saved_object_graph_proto(object_graph_filename)
|
object_graph_proto = meta_graph_def.object_graph_def
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
loader = _Loader(object_graph_proto,
|
loader = _Loader(object_graph_proto,
|
||||||
saved_model_proto,
|
saved_model_proto,
|
||||||
|
@ -34,10 +34,10 @@ import collections
|
|||||||
import functools
|
import functools
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.core.protobuf import struct_pb2
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.saved_model import struct_pb2
|
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,12 +20,12 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
|
from tensorflow.core.protobuf import struct_pb2
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model import nested_structure_coder
|
from tensorflow.python.saved_model import nested_structure_coder
|
||||||
from tensorflow.python.saved_model import struct_pb2
|
|
||||||
|
|
||||||
|
|
||||||
class NestedStructureTest(test.TestCase):
|
class NestedStructureTest(test.TestCase):
|
||||||
|
@ -19,7 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.core.framework import versions_pb2
|
from tensorflow.core.framework import versions_pb2
|
||||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
from tensorflow.core.protobuf import saved_object_graph_pb2
|
||||||
|
|
||||||
|
|
||||||
class VersionedTypeRegistration(object):
|
class VersionedTypeRegistration(object):
|
||||||
|
@ -19,9 +19,9 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.core.framework import versions_pb2
|
from tensorflow.core.framework import versions_pb2
|
||||||
|
from tensorflow.core.protobuf import saved_object_graph_pb2
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model import revived_types
|
from tensorflow.python.saved_model import revived_types
|
||||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ import os
|
|||||||
from tensorflow.core.framework import versions_pb2
|
from tensorflow.core.framework import versions_pb2
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
from tensorflow.core.protobuf import saved_model_pb2
|
from tensorflow.core.protobuf import saved_model_pb2
|
||||||
|
from tensorflow.core.protobuf import saved_object_graph_pb2
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function as defun
|
from tensorflow.python.eager import function as defun
|
||||||
@ -40,7 +41,6 @@ from tensorflow.python.saved_model import constants
|
|||||||
from tensorflow.python.saved_model import function_serialization
|
from tensorflow.python.saved_model import function_serialization
|
||||||
from tensorflow.python.saved_model import nested_structure_coder
|
from tensorflow.python.saved_model import nested_structure_coder
|
||||||
from tensorflow.python.saved_model import revived_types
|
from tensorflow.python.saved_model import revived_types
|
||||||
from tensorflow.python.saved_model import saved_object_graph_pb2
|
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.saved_model import signature_def_utils
|
from tensorflow.python.saved_model import signature_def_utils
|
||||||
from tensorflow.python.saved_model import signature_serialization
|
from tensorflow.python.saved_model import signature_serialization
|
||||||
@ -542,7 +542,7 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
|
|||||||
return asset_info, exported_graph
|
return asset_info, exported_graph
|
||||||
|
|
||||||
|
|
||||||
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
|
def _serialize_object_graph(saveable_view, asset_file_def_index):
|
||||||
"""Save a SavedObjectGraph proto for `root`."""
|
"""Save a SavedObjectGraph proto for `root`."""
|
||||||
# SavedObjectGraph is similar to the TrackableObjectGraph proto in the
|
# SavedObjectGraph is similar to the TrackableObjectGraph proto in the
|
||||||
# checkpoint. It will eventually go into the SavedModel.
|
# checkpoint. It will eventually go into the SavedModel.
|
||||||
@ -559,14 +559,7 @@ def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
|
|||||||
|
|
||||||
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
|
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
|
||||||
_write_object_proto(obj, obj_proto, asset_file_def_index)
|
_write_object_proto(obj, obj_proto, asset_file_def_index)
|
||||||
|
return proto
|
||||||
extra_asset_dir = os.path.join(
|
|
||||||
compat.as_bytes(export_dir),
|
|
||||||
compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
|
|
||||||
file_io.recursive_create_dir(extra_asset_dir)
|
|
||||||
object_graph_filename = os.path.join(
|
|
||||||
extra_asset_dir, compat.as_bytes("object_graph.pb"))
|
|
||||||
file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
|
|
||||||
|
|
||||||
|
|
||||||
def _write_object_proto(obj, proto, asset_file_def_index):
|
def _write_object_proto(obj, proto, asset_file_def_index):
|
||||||
@ -814,8 +807,10 @@ def save(obj, export_dir, signatures=None):
|
|||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
compat.as_bytes(export_dir),
|
compat.as_bytes(export_dir),
|
||||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||||
|
object_graph_proto = _serialize_object_graph(
|
||||||
|
saveable_view, asset_info.asset_index)
|
||||||
|
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
|
||||||
file_io.write_string_to_file(path, saved_model.SerializeToString())
|
file_io.write_string_to_file(path, saved_model.SerializeToString())
|
||||||
_write_object_graph(saveable_view, export_dir, asset_info.asset_index)
|
|
||||||
# Clean reference cycles so repeated export()s don't make work for the garbage
|
# Clean reference cycles so repeated export()s don't make work for the garbage
|
||||||
# collector. Before this point we need to keep references to captured
|
# collector. Before this point we need to keep references to captured
|
||||||
# constants in the saved graph.
|
# constants in the saved graph.
|
||||||
|
@ -44,6 +44,13 @@ tf_proto {
|
|||||||
type: TYPE_MESSAGE
|
type: TYPE_MESSAGE
|
||||||
type_name: ".tensorflow.AssetFileDef"
|
type_name: ".tensorflow.AssetFileDef"
|
||||||
}
|
}
|
||||||
|
field {
|
||||||
|
name: "object_graph_def"
|
||||||
|
number: 7
|
||||||
|
label: LABEL_OPTIONAL
|
||||||
|
type: TYPE_MESSAGE
|
||||||
|
type_name: ".tensorflow.SavedObjectGraph"
|
||||||
|
}
|
||||||
nested_type {
|
nested_type {
|
||||||
name: "MetaInfoDef"
|
name: "MetaInfoDef"
|
||||||
field {
|
field {
|
||||||
|
Loading…
Reference in New Issue
Block a user