Update the TensorInfo protobuf message with an encoding for composite tensors; and update SavedModel to use this new encoding.
PiperOrigin-RevId: 262639435
This commit is contained in:
parent
23e33f871b
commit
b78d23cf92
@ -42,6 +42,10 @@ void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
|
||||
tensor_names->insert(coo_sparse.values_tensor_name());
|
||||
tensor_names->insert(coo_sparse.indices_tensor_name());
|
||||
tensor_names->insert(coo_sparse.dense_shape_tensor_name());
|
||||
} else if (tensor_info.has_composite_tensor()) {
|
||||
for (const auto& component : tensor_info.composite_tensor().components()) {
|
||||
tensor_names->insert(component.name());
|
||||
}
|
||||
} else {
|
||||
tensor_names->insert(tensor_info.name());
|
||||
}
|
||||
|
@ -425,5 +425,63 @@ TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) {
|
||||
TestFreezeGraphWithAndWithoutDependentVariables(true);
|
||||
}
|
||||
|
||||
TEST_F(FreezeTest, InputsAndOutputsCompositeTensorSignatureDef) {
|
||||
// Test that inputs and outputs get correctly populated for a
|
||||
// SignatureDef containing composite tensor inputs and outputs.
|
||||
SavedModelBundle saved_model_bundle;
|
||||
SignatureDef signature_def;
|
||||
|
||||
TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
|
||||
in.mutable_composite_tensor()->add_components()->set_name("input1:0");
|
||||
in.mutable_composite_tensor()->add_components()->set_name("input2:0");
|
||||
|
||||
TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
|
||||
out.mutable_composite_tensor()->add_components()->set_name("output2:0");
|
||||
out.mutable_composite_tensor()->add_components()->set_name("output1:0");
|
||||
|
||||
AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
|
||||
&saved_model_bundle);
|
||||
GraphDef frozen_graph_def;
|
||||
std::unordered_set<string> inputs;
|
||||
std::unordered_set<string> outputs;
|
||||
TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
|
||||
&outputs));
|
||||
std::unordered_set<string> expected_inputs = {"input1:0", "input2:0"};
|
||||
std::unordered_set<string> expected_outputs = {"output1:0", "output2:0"};
|
||||
EXPECT_EQ(expected_inputs, inputs);
|
||||
EXPECT_EQ(expected_outputs, outputs);
|
||||
}
|
||||
|
||||
TEST_F(FreezeTest, InputsAndOutputsSparseCooSignatureDef) {
|
||||
// Test that inputs and outputs get correctly populated for a
|
||||
// SignatureDef containing composite tensor inputs and outputs.
|
||||
SavedModelBundle saved_model_bundle;
|
||||
SignatureDef signature_def;
|
||||
|
||||
TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"];
|
||||
in.mutable_coo_sparse()->set_values_tensor_name("input1:0");
|
||||
in.mutable_coo_sparse()->set_indices_tensor_name("input2:0");
|
||||
in.mutable_coo_sparse()->set_dense_shape_tensor_name("input3:0");
|
||||
|
||||
TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"];
|
||||
out.mutable_coo_sparse()->set_values_tensor_name("output1:0");
|
||||
out.mutable_coo_sparse()->set_indices_tensor_name("output2:0");
|
||||
out.mutable_coo_sparse()->set_dense_shape_tensor_name("output3:0");
|
||||
|
||||
AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
|
||||
&saved_model_bundle);
|
||||
GraphDef frozen_graph_def;
|
||||
std::unordered_set<string> inputs;
|
||||
std::unordered_set<string> outputs;
|
||||
TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
|
||||
&outputs));
|
||||
std::unordered_set<string> expected_inputs = {"input1:0", "input2:0",
|
||||
"input3:0"};
|
||||
std::unordered_set<string> expected_outputs = {"output1:0", "output2:0",
|
||||
"output3:0"};
|
||||
EXPECT_EQ(expected_inputs, inputs);
|
||||
EXPECT_EQ(expected_outputs, outputs);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -14,6 +14,7 @@ import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
import "tensorflow/core/protobuf/saved_object_graph.proto";
|
||||
import "tensorflow/core/protobuf/saver.proto";
|
||||
import "tensorflow/core/protobuf/struct.proto";
|
||||
|
||||
// NOTE: This protocol buffer is evolving, and will go through revisions in the
|
||||
// coming months.
|
||||
@ -225,6 +226,15 @@ message TensorInfo {
|
||||
string dense_shape_tensor_name = 3;
|
||||
}
|
||||
|
||||
// Generic encoding for composite tensors.
|
||||
message CompositeTensor {
|
||||
// The serialized TypeSpec for the composite tensor.
|
||||
TypeSpecProto type_spec = 1;
|
||||
|
||||
// A TensorInfo for each flattened component tensor.
|
||||
repeated TensorInfo components = 2;
|
||||
}
|
||||
|
||||
oneof encoding {
|
||||
// For dense `Tensor`s, the name of the tensor in the graph.
|
||||
string name = 1;
|
||||
@ -233,6 +243,8 @@ message TensorInfo {
|
||||
// uses only the COO encoding. This is supported and documented in the
|
||||
// SparseTensor Python class.
|
||||
CooSparse coo_sparse = 4;
|
||||
// Generic encoding for CompositeTensors.
|
||||
CompositeTensor composite_tensor = 5;
|
||||
}
|
||||
DataType dtype = 2;
|
||||
// The static shape should be recorded here, to the extent that it can
|
||||
|
@ -720,6 +720,7 @@ py_library(
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:template",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/saved_model:nested_structure_coder",
|
||||
"//tensorflow/python/training/tracking:base",
|
||||
],
|
||||
)
|
||||
|
@ -22,9 +22,11 @@ from __future__ import print_function
|
||||
import weakref
|
||||
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.core.protobuf import struct_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import importer
|
||||
from tensorflow.python.framework import ops
|
||||
@ -34,6 +36,7 @@ from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import nested_structure_coder
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
@ -104,6 +107,14 @@ def _get_element_from_tensor_info(tensor_info, graph):
|
||||
graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
|
||||
graph.get_tensor_by_name(
|
||||
tensor_info.coo_sparse.dense_shape_tensor_name))
|
||||
elif encoding == "composite_tensor":
|
||||
struct_coder = nested_structure_coder.StructureCoder()
|
||||
spec_proto = struct_pb2.StructuredValue(
|
||||
type_spec_value=tensor_info.composite_tensor.type_spec)
|
||||
spec = struct_coder.decode_proto(spec_proto)
|
||||
components = [graph.get_tensor_by_name(component.name) for component in
|
||||
tensor_info.composite_tensor.components]
|
||||
return spec._from_components(components) # pylint: disable=protected-access
|
||||
else:
|
||||
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
|
||||
|
||||
@ -243,8 +254,8 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
"""
|
||||
# TODO(b/129646028): Add support for CompositeTensors.
|
||||
name = name or "pruned"
|
||||
feeds = nest.map_structure(self.graph.as_graph_element, feeds)
|
||||
flat_feeds = nest.flatten(feeds)
|
||||
flat_feeds = nest.flatten(feeds, expand_composites=True)
|
||||
flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
|
||||
for f in flat_feeds:
|
||||
if not isinstance(f, ops.Tensor):
|
||||
raise ValueError("Feeds must be tensors.")
|
||||
@ -278,12 +289,13 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
elif isinstance(fetch, meta_graph_pb2.TensorInfo):
|
||||
tensor_infos.append(fetch)
|
||||
decoded = _get_element_from_tensor_info(fetch, self._func_graph)
|
||||
if tensor_util.is_tensor(decoded):
|
||||
if (tensor_util.is_tensor(decoded) or
|
||||
isinstance(decoded, composite_tensor.CompositeTensor)):
|
||||
tensor_fetches.append(decoded)
|
||||
else:
|
||||
operation_fetches.append(decoded)
|
||||
return decoded
|
||||
elif isinstance(fetch, ops.Tensor):
|
||||
elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
|
||||
tensor_fetches.append(fetch)
|
||||
return fetch
|
||||
else:
|
||||
|
@ -36,6 +36,8 @@ from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
@ -84,6 +86,31 @@ class WrapFunctionTest(test.TestCase):
|
||||
f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
|
||||
self.assertAllEqual(f_pruned(ops.convert_to_tensor(2.0)), [4.0])
|
||||
|
||||
def testPruneRagged(self):
|
||||
|
||||
x_in = []
|
||||
x_out = []
|
||||
|
||||
def f(x, y):
|
||||
x_in.append(x)
|
||||
xx = x * x
|
||||
x_out.append(xx)
|
||||
return xx, y * y
|
||||
|
||||
x_spec = ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32)
|
||||
y_spec = tensor_spec.TensorSpec((), dtypes.float32)
|
||||
|
||||
f_wrapped = wrap_function.wrap_function(f, [x_spec, y_spec])
|
||||
|
||||
f_pruned = f_wrapped.prune(x_in[0], x_out[0])
|
||||
rt = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
|
||||
expected = ragged_factory_ops.constant_value([[1.0, 4.0], [9.0]])
|
||||
|
||||
# Note: when we call f_pruned, we must pass the RaggedTensor in using
|
||||
# its components, since that's the current convention for how concrete
|
||||
# functions handle structured inputs.
|
||||
self.assertAllEqual(f_pruned(rt.values, rt.row_splits), expected)
|
||||
|
||||
def _assert_single_captured_variable_argument(self, graph_def):
|
||||
# The single FunctionDef should have one argument, a captured variable
|
||||
function_def, = graph_def.library.function
|
||||
|
@ -191,6 +191,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":constants",
|
||||
":nested_structure_coder",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:lib",
|
||||
@ -481,6 +482,12 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//tensorflow/python/data/ops:optional_ops",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/ops/ragged",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
@ -155,14 +155,14 @@ class _SavedModelBuilder(object):
|
||||
def _validate_tensor_info(self, tensor_info):
|
||||
"""Validates the `TensorInfo` proto.
|
||||
|
||||
Checks if the `encoding` (`name` or `coo_sparse`) and `dtype` fields exist
|
||||
and are non-empty.
|
||||
Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
|
||||
`dtype` fields exist and are non-empty.
|
||||
|
||||
Args:
|
||||
tensor_info: `TensorInfo` protocol buffer to validate.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the `name` or `dtype` fields of the supplied
|
||||
AssertionError: If the `encoding` or `dtype` fields of the supplied
|
||||
`TensorInfo` proto are not populated.
|
||||
"""
|
||||
if tensor_info is None:
|
||||
@ -175,7 +175,10 @@ class _SavedModelBuilder(object):
|
||||
"All TensorInfo protos used in the SignatureDefs must have one of "
|
||||
"the 'encoding' fields (e.g., name or coo_sparse) set: %s"
|
||||
% tensor_info)
|
||||
if tensor_info.dtype is types_pb2.DT_INVALID:
|
||||
if tensor_info.WhichOneof("encoding") == "composite_tensor":
|
||||
for component in tensor_info.composite_tensor.components:
|
||||
self._validate_tensor_info(component)
|
||||
elif tensor_info.dtype == types_pb2.DT_INVALID:
|
||||
raise AssertionError(
|
||||
"All TensorInfo protos used in the SignatureDefs must have the dtype "
|
||||
"field set: %s" % tensor_info)
|
||||
|
@ -1709,7 +1709,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
imported = cycle(root, cycles)
|
||||
self.assertAllClose(2., imported.f(constant_op.constant(1.)))
|
||||
|
||||
def test_ragged_no_signature(self, cycles):
|
||||
def test_ragged(self, cycles):
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
ragged_tensor.RaggedTensorSpec(shape=[None, None], dtype=dtypes.int32)
|
||||
@ -1720,10 +1720,13 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
obj = tracking.AutoTrackable()
|
||||
obj.f = f
|
||||
|
||||
imported = cycle(obj, cycles, signatures={})
|
||||
imported1 = cycle(obj, cycles, signatures={})
|
||||
rt = ragged_factory_ops.constant([[1, 2], [3]])
|
||||
self.assertAllEqual(imported.f(rt), [[2, 3], [4]])
|
||||
self.assertAllEqual(imported1.f(rt), [[2, 3], [4]])
|
||||
|
||||
imported2 = cycle(obj, cycles)
|
||||
rt = ragged_factory_ops.constant([[1, 2], [3]])
|
||||
self.assertAllEqual(imported2.f(rt), [[2, 3], [4]])
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
@parameterized.named_parameters(
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import builder as saved_model_builder
|
||||
from tensorflow.python.saved_model import constants
|
||||
@ -43,6 +44,7 @@ from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import main_op
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.saved_model import utils
|
||||
from tensorflow.python.training import saver_test_utils
|
||||
from tensorflow.python.training import training
|
||||
from tensorflow.python.util import compat
|
||||
@ -642,6 +644,19 @@ class SavedModelTest(SavedModelTestBase):
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
self._validate_outputs_tensor_info_accept(builder, tensor_with_coo)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSignatureDefValidationSucceedsWithRagged(self):
|
||||
ragged_tensor = ragged_factory_ops.constant([[1, 2], [3]])
|
||||
tensor_with_ragged = utils.build_tensor_info(ragged_tensor)
|
||||
|
||||
export_dir = self._get_export_dir("test_signature_def_validation_ragged_1")
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
self._validate_inputs_tensor_info_accept(builder, tensor_with_ragged)
|
||||
|
||||
export_dir = self._get_export_dir("test_signature_def_validation_ragged_2")
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
self._validate_outputs_tensor_info_accept(builder, tensor_with_ragged)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAssets(self):
|
||||
export_dir = self._get_export_dir("test_assets")
|
||||
|
@ -22,15 +22,19 @@ import os
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.core.protobuf import struct_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import nested_structure_coder
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -65,6 +69,10 @@ def build_tensor_info(tensor):
|
||||
|
||||
def build_tensor_info_internal(tensor):
|
||||
"""Utility function to build TensorInfo proto from a Tensor."""
|
||||
if (isinstance(tensor, composite_tensor.CompositeTensor) and
|
||||
not isinstance(tensor, sparse_tensor.SparseTensor)):
|
||||
return _build_composite_tensor_info_internal(tensor)
|
||||
|
||||
tensor_info = meta_graph_pb2.TensorInfo(
|
||||
dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
|
||||
tensor_shape=tensor.get_shape().as_proto())
|
||||
@ -77,6 +85,19 @@ def build_tensor_info_internal(tensor):
|
||||
return tensor_info
|
||||
|
||||
|
||||
def _build_composite_tensor_info_internal(tensor):
|
||||
"""Utility function to build TensorInfo proto from a CompositeTensor."""
|
||||
spec = tensor._type_spec # pylint: disable=protected-access
|
||||
tensor_info = meta_graph_pb2.TensorInfo()
|
||||
struct_coder = nested_structure_coder.StructureCoder()
|
||||
spec_proto = struct_coder.encode_structure(spec)
|
||||
tensor_info.composite_tensor.type_spec.CopyFrom(spec_proto.type_spec_value)
|
||||
for component in nest.flatten(tensor, expand_composites=True):
|
||||
tensor_info.composite_tensor.components.add().CopyFrom(
|
||||
build_tensor_info_internal(component))
|
||||
return tensor_info
|
||||
|
||||
|
||||
def build_tensor_info_from_op(op):
|
||||
"""Utility function to build TensorInfo proto from an Op.
|
||||
|
||||
@ -120,17 +141,19 @@ def build_tensor_info_from_op(op):
|
||||
"library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
|
||||
"tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
|
||||
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
|
||||
"""Returns the Tensor or SparseTensor described by a TensorInfo proto.
|
||||
"""Returns the Tensor or CompositeTensor described by a TensorInfo proto.
|
||||
|
||||
Args:
|
||||
tensor_info: A TensorInfo proto describing a Tensor or SparseTensor.
|
||||
tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
|
||||
CompositeTensor.
|
||||
graph: The tf.Graph in which tensors are looked up. If None, the
|
||||
current default graph is used.
|
||||
import_scope: If not None, names in `tensor_info` are prefixed with this
|
||||
string before lookup.
|
||||
|
||||
Returns:
|
||||
The Tensor or SparseTensor in `graph` described by `tensor_info`.
|
||||
The Tensor or SparseTensor or CompositeTensor in `graph` described by
|
||||
`tensor_info`.
|
||||
|
||||
Raises:
|
||||
KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
|
||||
@ -148,6 +171,14 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
|
||||
_get_tensor(tensor_info.coo_sparse.indices_tensor_name),
|
||||
_get_tensor(tensor_info.coo_sparse.values_tensor_name),
|
||||
_get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
|
||||
elif encoding == "composite_tensor":
|
||||
struct_coder = nested_structure_coder.StructureCoder()
|
||||
spec_proto = struct_pb2.StructuredValue(
|
||||
type_spec_value=tensor_info.composite_tensor.type_spec)
|
||||
spec = struct_coder.decode_proto(spec_proto)
|
||||
components = [_get_tensor(component.name) for component in
|
||||
tensor_info.composite_tensor.components]
|
||||
return spec.from_components(components)
|
||||
else:
|
||||
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
|
||||
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import struct_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -28,7 +29,9 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import nested_structure_coder
|
||||
from tensorflow.python.saved_model import utils
|
||||
|
||||
|
||||
@ -82,6 +85,26 @@ class UtilsTest(test.TestCase):
|
||||
self.assertEqual(42, x_tensor_info.tensor_shape.dim[0].size)
|
||||
self.assertEqual(69, x_tensor_info.tensor_shape.dim[1].size)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testBuildTensorInfoRagged(self):
|
||||
x = ragged_factory_ops.constant([[1, 2], [3]])
|
||||
x_tensor_info = utils.build_tensor_info(x)
|
||||
# Check components
|
||||
self.assertEqual(x.values.name,
|
||||
x_tensor_info.composite_tensor.components[0].name)
|
||||
self.assertEqual(types_pb2.DT_INT32,
|
||||
x_tensor_info.composite_tensor.components[0].dtype)
|
||||
self.assertEqual(x.row_splits.name,
|
||||
x_tensor_info.composite_tensor.components[1].name)
|
||||
self.assertEqual(types_pb2.DT_INT64,
|
||||
x_tensor_info.composite_tensor.components[1].dtype)
|
||||
# Check type_spec.
|
||||
struct_coder = nested_structure_coder.StructureCoder()
|
||||
spec_proto = struct_pb2.StructuredValue(
|
||||
type_spec_value=x_tensor_info.composite_tensor.type_spec)
|
||||
spec = struct_coder.decode_proto(spec_proto)
|
||||
self.assertEqual(spec, x._type_spec)
|
||||
|
||||
def testBuildTensorInfoEager(self):
|
||||
x = constant_op.constant(1, name="x")
|
||||
with context.eager_mode(), self.assertRaisesRegexp(
|
||||
|
@ -0,0 +1,20 @@
|
||||
path: "tensorflow.TensorInfo.CompositeTensor"
|
||||
tf_proto {
|
||||
descriptor {
|
||||
name: "CompositeTensor"
|
||||
field {
|
||||
name: "type_spec"
|
||||
number: 1
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.TypeSpecProto"
|
||||
}
|
||||
field {
|
||||
name: "components"
|
||||
number: 2
|
||||
label: LABEL_REPEATED
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.TensorInfo"
|
||||
}
|
||||
}
|
||||
}
|
@ -17,6 +17,14 @@ tf_proto {
|
||||
type_name: ".tensorflow.TensorInfo.CooSparse"
|
||||
oneof_index: 0
|
||||
}
|
||||
field {
|
||||
name: "composite_tensor"
|
||||
number: 5
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.TensorInfo.CompositeTensor"
|
||||
oneof_index: 0
|
||||
}
|
||||
field {
|
||||
name: "dtype"
|
||||
number: 2
|
||||
@ -52,6 +60,23 @@ tf_proto {
|
||||
type: TYPE_STRING
|
||||
}
|
||||
}
|
||||
nested_type {
|
||||
name: "CompositeTensor"
|
||||
field {
|
||||
name: "type_spec"
|
||||
number: 1
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.TypeSpecProto"
|
||||
}
|
||||
field {
|
||||
name: "components"
|
||||
number: 2
|
||||
label: LABEL_REPEATED
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.TensorInfo"
|
||||
}
|
||||
}
|
||||
oneof_decl {
|
||||
name: "encoding"
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user