From b78d23cf92656db63bca1f2cbc9636c7caa387ca Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Fri, 9 Aug 2019 14:54:19 -0700 Subject: [PATCH] Update the TensorInfo protobuf message with an encoding for composite tensors; and update SavedModel to use this new encoding. PiperOrigin-RevId: 262639435 --- tensorflow/cc/tools/freeze_saved_model.cc | 4 ++ .../cc/tools/freeze_saved_model_test.cc | 58 +++++++++++++++++++ tensorflow/core/protobuf/meta_graph.proto | 12 ++++ tensorflow/python/eager/BUILD | 1 + tensorflow/python/eager/wrap_function.py | 20 +++++-- tensorflow/python/eager/wrap_function_test.py | 27 +++++++++ tensorflow/python/saved_model/BUILD | 7 +++ tensorflow/python/saved_model/builder_impl.py | 11 ++-- tensorflow/python/saved_model/load_test.py | 9 ++- .../python/saved_model/saved_model_test.py | 15 +++++ tensorflow/python/saved_model/utils_impl.py | 37 +++++++++++- tensorflow/python/saved_model/utils_test.py | 23 ++++++++ ...rflow.-tensor-info.-composite-tensor.pbtxt | 20 +++++++ .../golden/v1/tensorflow.-tensor-info.pbtxt | 25 ++++++++ 14 files changed, 255 insertions(+), 14 deletions(-) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-composite-tensor.pbtxt diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index eeb91017890..0ec48ec9357 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -42,6 +42,10 @@ void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info, tensor_names->insert(coo_sparse.values_tensor_name()); tensor_names->insert(coo_sparse.indices_tensor_name()); tensor_names->insert(coo_sparse.dense_shape_tensor_name()); + } else if (tensor_info.has_composite_tensor()) { + for (const auto& component : tensor_info.composite_tensor().components()) { + tensor_names->insert(component.name()); + } } else { tensor_names->insert(tensor_info.name()); } diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index 979b23c3fc5..274a1630a05 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -425,5 +425,63 @@ TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) { TestFreezeGraphWithAndWithoutDependentVariables(true); } +TEST_F(FreezeTest, InputsAndOutputsCompositeTensorSignatureDef) { + // Test that inputs and outputs get correctly populated for a + // SignatureDef containing composite tensor inputs and outputs. + SavedModelBundle saved_model_bundle; + SignatureDef signature_def; + + TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"]; + in.mutable_composite_tensor()->add_components()->set_name("input1:0"); + in.mutable_composite_tensor()->add_components()->set_name("input2:0"); + + TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"]; + out.mutable_composite_tensor()->add_components()->set_name("output2:0"); + out.mutable_composite_tensor()->add_components()->set_name("output1:0"); + + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + std::unordered_set expected_inputs = {"input1:0", "input2:0"}; + std::unordered_set expected_outputs = {"output1:0", "output2:0"}; + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + +TEST_F(FreezeTest, InputsAndOutputsSparseCooSignatureDef) { + // Test that inputs and outputs get correctly populated for a + // SignatureDef containing composite tensor inputs and outputs. + SavedModelBundle saved_model_bundle; + SignatureDef signature_def; + + TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"]; + in.mutable_coo_sparse()->set_values_tensor_name("input1:0"); + in.mutable_coo_sparse()->set_indices_tensor_name("input2:0"); + in.mutable_coo_sparse()->set_dense_shape_tensor_name("input3:0"); + + TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"]; + out.mutable_coo_sparse()->set_values_tensor_name("output1:0"); + out.mutable_coo_sparse()->set_indices_tensor_name("output2:0"); + out.mutable_coo_sparse()->set_dense_shape_tensor_name("output3:0"); + + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + std::unordered_set expected_inputs = {"input1:0", "input2:0", + "input3:0"}; + std::unordered_set expected_outputs = {"output1:0", "output2:0", + "output3:0"}; + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/protobuf/meta_graph.proto b/tensorflow/core/protobuf/meta_graph.proto index fa0192cf67c..1eb2023f01d 100644 --- a/tensorflow/core/protobuf/meta_graph.proto +++ b/tensorflow/core/protobuf/meta_graph.proto @@ -14,6 +14,7 @@ import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; import "tensorflow/core/protobuf/saved_object_graph.proto"; import "tensorflow/core/protobuf/saver.proto"; +import "tensorflow/core/protobuf/struct.proto"; // NOTE: This protocol buffer is evolving, and will go through revisions in the // coming months. @@ -225,6 +226,15 @@ message TensorInfo { string dense_shape_tensor_name = 3; } + // Generic encoding for composite tensors. + message CompositeTensor { + // The serialized TypeSpec for the composite tensor. + TypeSpecProto type_spec = 1; + + // A TensorInfo for each flattened component tensor. + repeated TensorInfo components = 2; + } + oneof encoding { // For dense `Tensor`s, the name of the tensor in the graph. string name = 1; @@ -233,6 +243,8 @@ message TensorInfo { // uses only the COO encoding. This is supported and documented in the // SparseTensor Python class. CooSparse coo_sparse = 4; + // Generic encoding for CompositeTensors. + CompositeTensor composite_tensor = 5; } DataType dtype = 2; // The static shape should be recorded here, to the extent that it can diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 7d3f62f28d0..91615a9a3f3 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -720,6 +720,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:template", "//tensorflow/python:variable_scope", + "//tensorflow/python/saved_model:nested_structure_coder", "//tensorflow/python/training/tracking:base", ], ) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 8f7a8fea05a..625a7d3c166 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -22,9 +22,11 @@ from __future__ import print_function import weakref from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.core.protobuf import struct_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import lift_to_graph +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import func_graph from tensorflow.python.framework import importer from tensorflow.python.framework import ops @@ -34,6 +36,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.training.tracking import data_structures from tensorflow.python.util import nest from tensorflow.python.util import object_identity @@ -104,6 +107,14 @@ def _get_element_from_tensor_info(tensor_info, graph): graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name), graph.get_tensor_by_name( tensor_info.coo_sparse.dense_shape_tensor_name)) + elif encoding == "composite_tensor": + struct_coder = nested_structure_coder.StructureCoder() + spec_proto = struct_pb2.StructuredValue( + type_spec_value=tensor_info.composite_tensor.type_spec) + spec = struct_coder.decode_proto(spec_proto) + components = [graph.get_tensor_by_name(component.name) for component in + tensor_info.composite_tensor.components] + return spec._from_components(components) # pylint: disable=protected-access else: raise ValueError("Invalid TensorInfo.encoding: %s" % encoding) @@ -243,8 +254,8 @@ class WrappedFunction(function.ConcreteFunction): """ # TODO(b/129646028): Add support for CompositeTensors. name = name or "pruned" - feeds = nest.map_structure(self.graph.as_graph_element, feeds) - flat_feeds = nest.flatten(feeds) + flat_feeds = nest.flatten(feeds, expand_composites=True) + flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds] for f in flat_feeds: if not isinstance(f, ops.Tensor): raise ValueError("Feeds must be tensors.") @@ -278,12 +289,13 @@ class WrappedFunction(function.ConcreteFunction): elif isinstance(fetch, meta_graph_pb2.TensorInfo): tensor_infos.append(fetch) decoded = _get_element_from_tensor_info(fetch, self._func_graph) - if tensor_util.is_tensor(decoded): + if (tensor_util.is_tensor(decoded) or + isinstance(decoded, composite_tensor.CompositeTensor)): tensor_fetches.append(decoded) else: operation_fetches.append(decoded) return decoded - elif isinstance(fetch, ops.Tensor): + elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)): tensor_fetches.append(fetch) return fetch else: diff --git a/tensorflow/python/eager/wrap_function_test.py b/tensorflow/python/eager/wrap_function_test.py index 1a135b3534f..4b592a5f8df 100644 --- a/tensorflow/python/eager/wrap_function_test.py +++ b/tensorflow/python/eager/wrap_function_test.py @@ -36,6 +36,8 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib @@ -84,6 +86,31 @@ class WrapFunctionTest(test.TestCase): f_pruned = f_wrapped.prune(x_in[0], [x_out[0]]) self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0]) + def testPruneRagged(self): + + x_in = [] + x_out = [] + + def f(x, y): + x_in.append(x) + xx = x * x + x_out.append(xx) + return xx, y * y + + x_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32) + y_spec = tensor_spec.TensorSpec((), dtypes.float32) + + f_wrapped = wrap_function.wrap_function(f, [x_spec, y_spec]) + + f_pruned = f_wrapped.prune(x_in[0], x_out[0]) + rt = ragged_factory_ops.constant([[1.0, 2.0], [3.0]]) + expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]]) + + # Note: when we call f_pruned, we must pass the RaggedTensor in using + # its components, since that's the current convention for how concrete + # functions handle structured inputs. + self.assertAllEqual(f_pruned(rt.values, rt.row_splits), expected) + def _assert_single_captured_variable_argument(self, graph_def): # The single FunctionDef should have one argument, a captured variable function_def, = graph_def.library.function diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 29ce69ce9a3..1ca3804515e 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -191,6 +191,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", + ":nested_structure_coder", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", @@ -481,6 +482,12 @@ py_library( deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:framework", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:optional_ops", + "//tensorflow/python/distribute:values", + "//tensorflow/python/ops/ragged", "@six_archive//:six", ], ) diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 65cffc624d8..29b62a6566b 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -155,14 +155,14 @@ class _SavedModelBuilder(object): def _validate_tensor_info(self, tensor_info): """Validates the `TensorInfo` proto. - Checks if the `encoding` (`name` or `coo_sparse`) and `dtype` fields exist - and are non-empty. + Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and + `dtype` fields exist and are non-empty. Args: tensor_info: `TensorInfo` protocol buffer to validate. Raises: - AssertionError: If the `name` or `dtype` fields of the supplied + AssertionError: If the `encoding` or `dtype` fields of the supplied `TensorInfo` proto are not populated. """ if tensor_info is None: @@ -175,7 +175,10 @@ class _SavedModelBuilder(object): "All TensorInfo protos used in the SignatureDefs must have one of " "the 'encoding' fields (e.g., name or coo_sparse) set: %s" % tensor_info) - if tensor_info.dtype is types_pb2.DT_INVALID: + if tensor_info.WhichOneof("encoding") == "composite_tensor": + for component in tensor_info.composite_tensor.components: + self._validate_tensor_info(component) + elif tensor_info.dtype == types_pb2.DT_INVALID: raise AssertionError( "All TensorInfo protos used in the SignatureDefs must have the dtype " "field set: %s" % tensor_info) diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 4fa837380a6..102b93e4f3d 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -1709,7 +1709,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): imported = cycle(root, cycles) self.assertAllClose(2., imported.f(constant_op.constant(1.))) - def test_ragged_no_signature(self, cycles): + def test_ragged(self, cycles): @def_function.function(input_signature=[ ragged_tensor.RaggedTensorSpec(shape=[None, None], dtype=dtypes.int32) @@ -1720,10 +1720,13 @@ class LoadTest(test.TestCase, parameterized.TestCase): obj = tracking.AutoTrackable() obj.f = f - imported = cycle(obj, cycles, signatures={}) + imported1 = cycle(obj, cycles, signatures={}) rt = ragged_factory_ops.constant([[1, 2], [3]]) - self.assertAllEqual(imported.f(rt), [[2, 3], [4]]) + self.assertAllEqual(imported1.f(rt), [[2, 3], [4]]) + imported2 = cycle(obj, cycles) + rt = ragged_factory_ops.constant([[1, 2], [3]]) + self.assertAllEqual(imported2.f(rt), [[2, 3], [4]]) @keras_parameterized.run_all_keras_modes(always_skip_v1=True) @parameterized.named_parameters( diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index e36b8b30bf2..7722cd3b14c 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants @@ -43,6 +44,7 @@ from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import main_op from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants +from tensorflow.python.saved_model import utils from tensorflow.python.training import saver_test_utils from tensorflow.python.training import training from tensorflow.python.util import compat @@ -642,6 +644,19 @@ class SavedModelTest(SavedModelTestBase): builder = saved_model_builder._SavedModelBuilder(export_dir) self._validate_outputs_tensor_info_accept(builder, tensor_with_coo) + @test_util.run_deprecated_v1 + def testSignatureDefValidationSucceedsWithRagged(self): + ragged_tensor = ragged_factory_ops.constant([[1, 2], [3]]) + tensor_with_ragged = utils.build_tensor_info(ragged_tensor) + + export_dir = self._get_export_dir("test_signature_def_validation_ragged_1") + builder = saved_model_builder._SavedModelBuilder(export_dir) + self._validate_inputs_tensor_info_accept(builder, tensor_with_ragged) + + export_dir = self._get_export_dir("test_signature_def_validation_ragged_2") + builder = saved_model_builder._SavedModelBuilder(export_dir) + self._validate_outputs_tensor_info_accept(builder, tensor_with_ragged) + @test_util.run_deprecated_v1 def testAssets(self): export_dir = self._get_export_dir("test_assets") diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py index 2e7b2080574..3dd7d6c7ae4 100644 --- a/tensorflow/python/saved_model/utils_impl.py +++ b/tensorflow/python/saved_model/utils_impl.py @@ -22,15 +22,19 @@ import os from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.core.protobuf import struct_pb2 from tensorflow.python.eager import context +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -65,6 +69,10 @@ def build_tensor_info(tensor): def build_tensor_info_internal(tensor): """Utility function to build TensorInfo proto from a Tensor.""" + if (isinstance(tensor, composite_tensor.CompositeTensor) and + not isinstance(tensor, sparse_tensor.SparseTensor)): + return _build_composite_tensor_info_internal(tensor) + tensor_info = meta_graph_pb2.TensorInfo( dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum, tensor_shape=tensor.get_shape().as_proto()) @@ -77,6 +85,19 @@ def build_tensor_info_internal(tensor): return tensor_info +def _build_composite_tensor_info_internal(tensor): + """Utility function to build TensorInfo proto from a CompositeTensor.""" + spec = tensor._type_spec # pylint: disable=protected-access + tensor_info = meta_graph_pb2.TensorInfo() + struct_coder = nested_structure_coder.StructureCoder() + spec_proto = struct_coder.encode_structure(spec) + tensor_info.composite_tensor.type_spec.CopyFrom(spec_proto.type_spec_value) + for component in nest.flatten(tensor, expand_composites=True): + tensor_info.composite_tensor.components.add().CopyFrom( + build_tensor_info_internal(component)) + return tensor_info + + def build_tensor_info_from_op(op): """Utility function to build TensorInfo proto from an Op. @@ -120,17 +141,19 @@ def build_tensor_info_from_op(op): "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or " "tf.compat.v1.saved_model.get_tensor_from_tensor_info.") def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None): - """Returns the Tensor or SparseTensor described by a TensorInfo proto. + """Returns the Tensor or CompositeTensor described by a TensorInfo proto. Args: - tensor_info: A TensorInfo proto describing a Tensor or SparseTensor. + tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or + CompositeTensor. graph: The tf.Graph in which tensors are looked up. If None, the current default graph is used. import_scope: If not None, names in `tensor_info` are prefixed with this string before lookup. Returns: - The Tensor or SparseTensor in `graph` described by `tensor_info`. + The Tensor or SparseTensor or CompositeTensor in `graph` described by + `tensor_info`. Raises: KeyError: If `tensor_info` does not correspond to a tensor in `graph`. @@ -148,6 +171,14 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None): _get_tensor(tensor_info.coo_sparse.indices_tensor_name), _get_tensor(tensor_info.coo_sparse.values_tensor_name), _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name)) + elif encoding == "composite_tensor": + struct_coder = nested_structure_coder.StructureCoder() + spec_proto = struct_pb2.StructuredValue( + type_spec_value=tensor_info.composite_tensor.type_spec) + spec = struct_coder.decode_proto(spec_proto) + components = [_get_tensor(component.name) for component in + tensor_info.composite_tensor.components] + return spec.from_components(components) else: raise ValueError("Invalid TensorInfo.encoding: %s" % encoding) diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py index 1e12de91b86..d176b91db1e 100644 --- a/tensorflow/python/saved_model/utils_test.py +++ b/tensorflow/python/saved_model/utils_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.core.framework import types_pb2 +from tensorflow.core.protobuf import struct_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import constant_op @@ -28,7 +29,9 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test +from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.saved_model import utils @@ -82,6 +85,26 @@ class UtilsTest(test.TestCase): self.assertEqual(42, x_tensor_info.tensor_shape.dim[0].size) self.assertEqual(69, x_tensor_info.tensor_shape.dim[1].size) + @test_util.run_v1_only("b/120545219") + def testBuildTensorInfoRagged(self): + x = ragged_factory_ops.constant([[1, 2], [3]]) + x_tensor_info = utils.build_tensor_info(x) + # Check components + self.assertEqual(x.values.name, + x_tensor_info.composite_tensor.components[0].name) + self.assertEqual(types_pb2.DT_INT32, + x_tensor_info.composite_tensor.components[0].dtype) + self.assertEqual(x.row_splits.name, + x_tensor_info.composite_tensor.components[1].name) + self.assertEqual(types_pb2.DT_INT64, + x_tensor_info.composite_tensor.components[1].dtype) + # Check type_spec. + struct_coder = nested_structure_coder.StructureCoder() + spec_proto = struct_pb2.StructuredValue( + type_spec_value=x_tensor_info.composite_tensor.type_spec) + spec = struct_coder.decode_proto(spec_proto) + self.assertEqual(spec, x._type_spec) + def testBuildTensorInfoEager(self): x = constant_op.constant(1, name="x") with context.eager_mode(), self.assertRaisesRegexp( diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-composite-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-composite-tensor.pbtxt new file mode 100644 index 00000000000..5fe1b984af2 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-composite-tensor.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.TensorInfo.CompositeTensor" +tf_proto { + descriptor { + name: "CompositeTensor" + field { + name: "type_spec" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TypeSpecProto" + } + field { + name: "components" + number: 2 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorInfo" + } + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt index 63566c808e5..48773ea0dce 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt @@ -17,6 +17,14 @@ tf_proto { type_name: ".tensorflow.TensorInfo.CooSparse" oneof_index: 0 } + field { + name: "composite_tensor" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorInfo.CompositeTensor" + oneof_index: 0 + } field { name: "dtype" number: 2 @@ -52,6 +60,23 @@ tf_proto { type: TYPE_STRING } } + nested_type { + name: "CompositeTensor" + field { + name: "type_spec" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TypeSpecProto" + } + field { + name: "components" + number: 2 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorInfo" + } + } oneof_decl { name: "encoding" }