Update the TensorInfo protobuf message with an encoding for composite tensors; and update SavedModel to use this new encoding.

PiperOrigin-RevId: 262639435
This commit is contained in:
Edward Loper 2019-08-09 14:54:19 -07:00 committed by TensorFlower Gardener
parent 23e33f871b
commit b78d23cf92
14 changed files with 255 additions and 14 deletions

View File

@ -42,6 +42,10 @@ void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
tensor_names->insert(coo_sparse.values_tensor_name());
tensor_names->insert(coo_sparse.indices_tensor_name());
tensor_names->insert(coo_sparse.dense_shape_tensor_name());
} else if (tensor_info.has_composite_tensor()) {
for (const auto& component : tensor_info.composite_tensor().components()) {
tensor_names->insert(component.name());
}
} else {
tensor_names->insert(tensor_info.name());
}

View File

@ -425,5 +425,63 @@ TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) {
TestFreezeGraphWithAndWithoutDependentVariables(true);
}
TEST_F(FreezeTest, InputsAndOutputsCompositeTensorSignatureDef) {
// Test that inputs and outputs get correctly populated for a
// SignatureDef containing composite tensor inputs and outputs.
SavedModelBundle saved_model_bundle;
SignatureDef signature_def;
TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
in.mutable_composite_tensor()->add_components()->set_name("input1:0");
in.mutable_composite_tensor()->add_components()->set_name("input2:0");
TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
out.mutable_composite_tensor()->add_components()->set_name("output2:0");
out.mutable_composite_tensor()->add_components()->set_name("output1:0");
AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
&saved_model_bundle);
GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
&outputs));
std::unordered_set<string> expected_inputs = {"input1:0", "input2:0"};
std::unordered_set<string> expected_outputs = {"output1:0", "output2:0"};
EXPECT_EQ(expected_inputs, inputs);
EXPECT_EQ(expected_outputs, outputs);
}
TEST_F(FreezeTest, InputsAndOutputsSparseCooSignatureDef) {
// Test that inputs and outputs get correctly populated for a
// SignatureDef containing composite tensor inputs and outputs.
SavedModelBundle saved_model_bundle;
SignatureDef signature_def;
TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
in.mutable_coo_sparse()->set_values_tensor_name("input1:0");
in.mutable_coo_sparse()->set_indices_tensor_name("input2:0");
in.mutable_coo_sparse()->set_dense_shape_tensor_name("input3:0");
TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
out.mutable_coo_sparse()->set_values_tensor_name("output1:0");
out.mutable_coo_sparse()->set_indices_tensor_name("output2:0");
out.mutable_coo_sparse()->set_dense_shape_tensor_name("output3:0");
AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
&saved_model_bundle);
GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
&outputs));
std::unordered_set<string> expected_inputs = {"input1:0", "input2:0",
"input3:0"};
std::unordered_set<string> expected_outputs = {"output1:0", "output2:0",
"output3:0"};
EXPECT_EQ(expected_inputs, inputs);
EXPECT_EQ(expected_outputs, outputs);
}
} // namespace
} // namespace tensorflow

View File

@ -14,6 +14,7 @@ import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saved_object_graph.proto";
import "tensorflow/core/protobuf/saver.proto";
import "tensorflow/core/protobuf/struct.proto";
// NOTE: This protocol buffer is evolving, and will go through revisions in the
// coming months.
@ -225,6 +226,15 @@ message TensorInfo {
string dense_shape_tensor_name = 3;
}
// Generic encoding for composite tensors.
message CompositeTensor {
// The serialized TypeSpec for the composite tensor.
TypeSpecProto type_spec = 1;
// A TensorInfo for each flattened component tensor.
repeated TensorInfo components = 2;
}
oneof encoding {
// For dense `Tensor`s, the name of the tensor in the graph.
string name = 1;
@ -233,6 +243,8 @@ message TensorInfo {
// uses only the COO encoding. This is supported and documented in the
// SparseTensor Python class.
CooSparse coo_sparse = 4;
// Generic encoding for CompositeTensors.
CompositeTensor composite_tensor = 5;
}
DataType dtype = 2;
// The static shape should be recorded here, to the extent that it can

View File

@ -720,6 +720,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:template",
"//tensorflow/python:variable_scope",
"//tensorflow/python/saved_model:nested_structure_coder",
"//tensorflow/python/training/tracking:base",
],
)

View File

@ -22,9 +22,11 @@ from __future__ import print_function
import weakref
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
@ -34,6 +36,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
@ -104,6 +107,14 @@ def _get_element_from_tensor_info(tensor_info, graph):
graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
graph.get_tensor_by_name(
tensor_info.coo_sparse.dense_shape_tensor_name))
elif encoding == "composite_tensor":
struct_coder = nested_structure_coder.StructureCoder()
spec_proto = struct_pb2.StructuredValue(
type_spec_value=tensor_info.composite_tensor.type_spec)
spec = struct_coder.decode_proto(spec_proto)
components = [graph.get_tensor_by_name(component.name) for component in
tensor_info.composite_tensor.components]
return spec._from_components(components) # pylint: disable=protected-access
else:
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
@ -243,8 +254,8 @@ class WrappedFunction(function.ConcreteFunction):
"""
# TODO(b/129646028): Add support for CompositeTensors.
name = name or "pruned"
feeds = nest.map_structure(self.graph.as_graph_element, feeds)
flat_feeds = nest.flatten(feeds)
flat_feeds = nest.flatten(feeds, expand_composites=True)
flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
for f in flat_feeds:
if not isinstance(f, ops.Tensor):
raise ValueError("Feeds must be tensors.")
@ -278,12 +289,13 @@ class WrappedFunction(function.ConcreteFunction):
elif isinstance(fetch, meta_graph_pb2.TensorInfo):
tensor_infos.append(fetch)
decoded = _get_element_from_tensor_info(fetch, self._func_graph)
if tensor_util.is_tensor(decoded):
if (tensor_util.is_tensor(decoded) or
isinstance(decoded, composite_tensor.CompositeTensor)):
tensor_fetches.append(decoded)
else:
operation_fetches.append(decoded)
return decoded
elif isinstance(fetch, ops.Tensor):
elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
tensor_fetches.append(fetch)
return fetch
else:

View File

@ -36,6 +36,8 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
@ -84,6 +86,31 @@ class WrapFunctionTest(test.TestCase):
f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0])
def testPruneRagged(self):
x_in = []
x_out = []
def f(x, y):
x_in.append(x)
xx = x * x
x_out.append(xx)
return xx, y * y
x_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32)
y_spec = tensor_spec.TensorSpec((), dtypes.float32)
f_wrapped = wrap_function.wrap_function(f, [x_spec, y_spec])
f_pruned = f_wrapped.prune(x_in[0], x_out[0])
rt = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]])
# Note: when we call f_pruned, we must pass the RaggedTensor in using
# its components, since that's the current convention for how concrete
# functions handle structured inputs.
self.assertAllEqual(f_pruned(rt.values, rt.row_splits), expected)
def _assert_single_captured_variable_argument(self, graph_def):
# The single FunctionDef should have one argument, a captured variable
function_def, = graph_def.library.function

View File

@ -191,6 +191,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":constants",
":nested_structure_coder",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:lib",
@ -481,6 +482,12 @@ py_library(
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:optional_ops",
"//tensorflow/python/distribute:values",
"//tensorflow/python/ops/ragged",
"@six_archive//:six",
],
)

View File

@ -155,14 +155,14 @@ class _SavedModelBuilder(object):
def _validate_tensor_info(self, tensor_info):
"""Validates the `TensorInfo` proto.
Checks if the `encoding` (`name` or `coo_sparse`) and `dtype` fields exist
and are non-empty.
Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
`dtype` fields exist and are non-empty.
Args:
tensor_info: `TensorInfo` protocol buffer to validate.
Raises:
AssertionError: If the `name` or `dtype` fields of the supplied
AssertionError: If the `encoding` or `dtype` fields of the supplied
`TensorInfo` proto are not populated.
"""
if tensor_info is None:
@ -175,7 +175,10 @@ class _SavedModelBuilder(object):
"All TensorInfo protos used in the SignatureDefs must have one of "
"the 'encoding' fields (e.g., name or coo_sparse) set: %s"
% tensor_info)
if tensor_info.dtype is types_pb2.DT_INVALID:
if tensor_info.WhichOneof("encoding") == "composite_tensor":
for component in tensor_info.composite_tensor.components:
self._validate_tensor_info(component)
elif tensor_info.dtype == types_pb2.DT_INVALID:
raise AssertionError(
"All TensorInfo protos used in the SignatureDefs must have the dtype "
"field set: %s" % tensor_info)

View File

@ -1709,7 +1709,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
imported = cycle(root, cycles)
self.assertAllClose(2., imported.f(constant_op.constant(1.)))
def test_ragged_no_signature(self, cycles):
def test_ragged(self, cycles):
@def_function.function(input_signature=[
ragged_tensor.RaggedTensorSpec(shape=[None, None], dtype=dtypes.int32)
@ -1720,10 +1720,13 @@ class LoadTest(test.TestCase, parameterized.TestCase):
obj = tracking.AutoTrackable()
obj.f = f
imported = cycle(obj, cycles, signatures={})
imported1 = cycle(obj, cycles, signatures={})
rt = ragged_factory_ops.constant([[1, 2], [3]])
self.assertAllEqual(imported.f(rt), [[2, 3], [4]])
self.assertAllEqual(imported1.f(rt), [[2, 3], [4]])
imported2 = cycle(obj, cycles)
rt = ragged_factory_ops.constant([[1, 2], [3]])
self.assertAllEqual(imported2.f(rt), [[2, 3], [4]])
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
@parameterized.named_parameters(

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
@ -43,6 +44,7 @@ from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import training
from tensorflow.python.util import compat
@ -642,6 +644,19 @@ class SavedModelTest(SavedModelTestBase):
builder = saved_model_builder._SavedModelBuilder(export_dir)
self._validate_outputs_tensor_info_accept(builder, tensor_with_coo)
@test_util.run_deprecated_v1
def testSignatureDefValidationSucceedsWithRagged(self):
ragged_tensor = ragged_factory_ops.constant([[1, 2], [3]])
tensor_with_ragged = utils.build_tensor_info(ragged_tensor)
export_dir = self._get_export_dir("test_signature_def_validation_ragged_1")
builder = saved_model_builder._SavedModelBuilder(export_dir)
self._validate_inputs_tensor_info_accept(builder, tensor_with_ragged)
export_dir = self._get_export_dir("test_signature_def_validation_ragged_2")
builder = saved_model_builder._SavedModelBuilder(export_dir)
self._validate_outputs_tensor_info_accept(builder, tensor_with_ragged)
@test_util.run_deprecated_v1
def testAssets(self):
export_dir = self._get_export_dir("test_assets")

View File

@ -22,15 +22,19 @@ import os
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@ -65,6 +69,10 @@ def build_tensor_info(tensor):
def build_tensor_info_internal(tensor):
"""Utility function to build TensorInfo proto from a Tensor."""
if (isinstance(tensor, composite_tensor.CompositeTensor) and
not isinstance(tensor, sparse_tensor.SparseTensor)):
return _build_composite_tensor_info_internal(tensor)
tensor_info = meta_graph_pb2.TensorInfo(
dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
tensor_shape=tensor.get_shape().as_proto())
@ -77,6 +85,19 @@ def build_tensor_info_internal(tensor):
return tensor_info
def _build_composite_tensor_info_internal(tensor):
"""Utility function to build TensorInfo proto from a CompositeTensor."""
spec = tensor._type_spec # pylint: disable=protected-access
tensor_info = meta_graph_pb2.TensorInfo()
struct_coder = nested_structure_coder.StructureCoder()
spec_proto = struct_coder.encode_structure(spec)
tensor_info.composite_tensor.type_spec.CopyFrom(spec_proto.type_spec_value)
for component in nest.flatten(tensor, expand_composites=True):
tensor_info.composite_tensor.components.add().CopyFrom(
build_tensor_info_internal(component))
return tensor_info
def build_tensor_info_from_op(op):
"""Utility function to build TensorInfo proto from an Op.
@ -120,17 +141,19 @@ def build_tensor_info_from_op(op):
"library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
"tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
"""Returns the Tensor or SparseTensor described by a TensorInfo proto.
"""Returns the Tensor or CompositeTensor described by a TensorInfo proto.
Args:
tensor_info: A TensorInfo proto describing a Tensor or SparseTensor.
tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
CompositeTensor.
graph: The tf.Graph in which tensors are looked up. If None, the
current default graph is used.
import_scope: If not None, names in `tensor_info` are prefixed with this
string before lookup.
Returns:
The Tensor or SparseTensor in `graph` described by `tensor_info`.
The Tensor or SparseTensor or CompositeTensor in `graph` described by
`tensor_info`.
Raises:
KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
@ -148,6 +171,14 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
_get_tensor(tensor_info.coo_sparse.indices_tensor_name),
_get_tensor(tensor_info.coo_sparse.values_tensor_name),
_get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
elif encoding == "composite_tensor":
struct_coder = nested_structure_coder.StructureCoder()
spec_proto = struct_pb2.StructuredValue(
type_spec_value=tensor_info.composite_tensor.type_spec)
spec = struct_coder.decode_proto(spec_proto)
components = [_get_tensor(component.name) for component in
tensor_info.composite_tensor.components]
return spec.from_components(components)
else:
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
@ -28,7 +29,9 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import utils
@ -82,6 +85,26 @@ class UtilsTest(test.TestCase):
self.assertEqual(42, x_tensor_info.tensor_shape.dim[0].size)
self.assertEqual(69, x_tensor_info.tensor_shape.dim[1].size)
@test_util.run_v1_only("b/120545219")
def testBuildTensorInfoRagged(self):
x = ragged_factory_ops.constant([[1, 2], [3]])
x_tensor_info = utils.build_tensor_info(x)
# Check components
self.assertEqual(x.values.name,
x_tensor_info.composite_tensor.components[0].name)
self.assertEqual(types_pb2.DT_INT32,
x_tensor_info.composite_tensor.components[0].dtype)
self.assertEqual(x.row_splits.name,
x_tensor_info.composite_tensor.components[1].name)
self.assertEqual(types_pb2.DT_INT64,
x_tensor_info.composite_tensor.components[1].dtype)
# Check type_spec.
struct_coder = nested_structure_coder.StructureCoder()
spec_proto = struct_pb2.StructuredValue(
type_spec_value=x_tensor_info.composite_tensor.type_spec)
spec = struct_coder.decode_proto(spec_proto)
self.assertEqual(spec, x._type_spec)
def testBuildTensorInfoEager(self):
x = constant_op.constant(1, name="x")
with context.eager_mode(), self.assertRaisesRegexp(

View File

@ -0,0 +1,20 @@
path: "tensorflow.TensorInfo.CompositeTensor"
tf_proto {
descriptor {
name: "CompositeTensor"
field {
name: "type_spec"
number: 1
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.TypeSpecProto"
}
field {
name: "components"
number: 2
label: LABEL_REPEATED
type: TYPE_MESSAGE
type_name: ".tensorflow.TensorInfo"
}
}
}

View File

@ -17,6 +17,14 @@ tf_proto {
type_name: ".tensorflow.TensorInfo.CooSparse"
oneof_index: 0
}
field {
name: "composite_tensor"
number: 5
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.TensorInfo.CompositeTensor"
oneof_index: 0
}
field {
name: "dtype"
number: 2
@ -52,6 +60,23 @@ tf_proto {
type: TYPE_STRING
}
}
nested_type {
name: "CompositeTensor"
field {
name: "type_spec"
number: 1
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.TypeSpecProto"
}
field {
name: "components"
number: 2
label: LABEL_REPEATED
type: TYPE_MESSAGE
type_name: ".tensorflow.TensorInfo"
}
}
oneof_decl {
name: "encoding"
}