When a pair of namedtuple objects are compared to determine their most specific compatible type, the namedtuple objects can be treated as normal tuples. This can result in failing to create a dataset consisting of nested namedtuples. This CL fixes this problem.

PiperOrigin-RevId: 352606395
Change-Id: I0d097c97ec397ef4efe319277f10b7e0bc943e51
This commit is contained in:
A. Unique TensorFlower 2021-01-19 10:57:33 -08:00 committed by TensorFlower Gardener
parent 465f83be84
commit 9aeefce7f2
3 changed files with 144 additions and 0 deletions

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import warnings
from absl.testing import parameterized
@ -556,6 +557,16 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaisesOpError(""):
self.getDatasetOutput(dataset)
def testNamedTupleStructure(self):
Foo = collections.namedtuple("Foo", ["a", "b"])
x = Foo(a=3, b="test")
dataset = dataset_ops.Dataset.from_tensors(x)
dataset = dataset_ops.Dataset.from_tensor_slices([dataset, dataset])
self.assertEqual(
str(dataset.element_spec),
"DatasetSpec(Foo(a=TensorSpec(shape=(), dtype=tf.int32, name=None), "
"b=TensorSpec(shape=(), dtype=tf.string, name=None)), TensorShape([]))")
if __name__ == "__main__":
test.main()

View File

@ -33,6 +33,7 @@ from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
@ -405,6 +406,13 @@ class TypeSpec(object):
return a.is_compatible_with(b)
return a == b
@staticmethod
def __is_named_tuple(t):
"""Returns true if the given tuple t is a namedtuple."""
return (hasattr(t, "_fields") and
isinstance(t._fields, collections_abc.Sequence) and
all(isinstance(f, six.string_types) for f in t._fields))
@staticmethod
def __most_specific_compatible_type_serialization(a, b):
"""Helper for most_specific_compatible_type.
@ -439,6 +447,13 @@ class TypeSpec(object):
if isinstance(a, (list, tuple)):
if len(a) != len(b):
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
if TypeSpec.__is_named_tuple(a):
if not hasattr(b, "_fields") or not isinstance(
b._fields, collections_abc.Sequence) or a._fields != b._fields:
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
return type(a)(*[
TypeSpec.__most_specific_compatible_type_serialization(x, y)
for (x, y) in zip(a, b)])
return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
for (x, y) in zip(a, b))
if isinstance(a, collections.OrderedDict):

View File

@ -19,9 +19,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from absl.testing import parameterized
import numpy as np
import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -31,6 +34,8 @@ from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
class TwoTensors(object):
@ -134,6 +139,57 @@ type_spec.register_type_spec_from_value_converter(
TwoComposites, TwoCompositesSpec.from_value)
class NestOfTensors(object):
"""CompositeTensor containing a nest of tensors."""
def __init__(self, x):
self.nest = x
@type_spec.register("tf.NestOfTensorsSpec")
class NestOfTensorsSpec(type_spec.TypeSpec):
"""A TypeSpec for the NestOfTensors value type."""
def __init__(self, spec):
self.spec = spec
value_type = property(lambda self: NestOfTensors)
_component_specs = property(lambda self: self.spec)
def _to_components(self, value):
return nest.flatten(value)
def _from_components(self, components):
return nest.pack_sequence_as(self.spec, components)
def _serialize(self):
return self.spec
def __repr__(self):
if hasattr(self.spec, "_fields") and isinstance(
self.spec._fields, collections_abc.Sequence) and all(
isinstance(f, six.string_types) for f in self.spec._fields):
return "%s(%r)" % (type(self).__name__, self._serialize())
return super(type_spec.TypeSpec, self).__repr__()
@classmethod
def from_value(cls, value):
return cls(nest.map_structure(type_spec.type_spec_from_value, value.nest))
@classmethod
def _deserialize(cls, spec):
return cls(spec)
type_spec.register_type_spec_from_value_converter(
NestOfTensors, NestOfTensorsSpec.from_value)
_TestNamedTuple = collections.namedtuple("NamedTuple", ["a", "b"])
_TestNamedTupleSingleField = collections.namedtuple("SingleField", ["a"])
_TestNamedTupleDifferentField = collections.namedtuple("DifferentField",
["a", "c"])
class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.named_parameters(
@ -271,6 +327,16 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
tensor_spec.TensorSpec([4], name="b")),
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
tensor_spec.TensorSpec([4], name=None))),
("NamedTuple",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32))),
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32))),
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32)))),
)
def testMostSpecificCompatibleType(self, v1, v2, expected):
self.assertEqual(v1.most_specific_compatible_type(v2), expected)
@ -290,6 +356,58 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with self.assertRaises(ValueError):
v2.most_specific_compatible_type(v1)
def testMostSpecificCompatibleTypeNamedTupleIsNotTuple(self):
named_tuple_spec_a = NestOfTensorsSpec.from_value(NestOfTensors(
_TestNamedTuple(a=1, b="aaa")))
named_tuple_spec_b = NestOfTensorsSpec.from_value(NestOfTensors(
_TestNamedTuple(a=2, b="bbb")))
named_tuple_spec_c = NestOfTensorsSpec.from_value(NestOfTensors(
_TestNamedTuple(a=3, b="ccc")))
normal_tuple_spec = NestOfTensorsSpec.from_value(NestOfTensors((2, "bbb")))
result_a_b = named_tuple_spec_a.most_specific_compatible_type(
named_tuple_spec_b)
result_b_a = named_tuple_spec_b.most_specific_compatible_type(
named_tuple_spec_a)
self.assertEqual(repr(result_a_b), repr(named_tuple_spec_c))
self.assertEqual(repr(result_b_a), repr(named_tuple_spec_c))
# Test that spec of named tuple is not equal to spec of normal tuple.
self.assertNotEqual(repr(result_a_b), repr(normal_tuple_spec))
@parameterized.named_parameters(
("IncompatibleDtype",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.bool))),
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32)))),
("DifferentTupleSize",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.bool))),
NestOfTensorsSpec(_TestNamedTupleSingleField(
a=tensor_spec.TensorSpec((), dtypes.int32)))),
("DifferentFieldName",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32))),
NestOfTensorsSpec(_TestNamedTupleDifferentField(
a=tensor_spec.TensorSpec((), dtypes.int32),
c=tensor_spec.TensorSpec((), dtypes.int32)))),
("NamedTupleAndTuple",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32))),
NestOfTensorsSpec((
tensor_spec.TensorSpec((), dtypes.int32),
tensor_spec.TensorSpec((), dtypes.int32)))),
)
def testMostSpecificCompatibleTypeForNamedTuplesException(self, v1, v2):
with self.assertRaises(ValueError):
v1.most_specific_compatible_type(v2)
with self.assertRaises(ValueError):
v2.most_specific_compatible_type(v1)
def toTensorList(self):
value = TwoTensors([1, 2, 3], [1.0, 2.0], "red")
spec = TwoTensorsSpec.from_value(value)