When a pair of namedtuple objects are compared to determine their most specific compatible type, the namedtuple objects can be treated as normal tuples. This can result in failing to create a dataset consisting of nested namedtuples. This CL fixes this problem.

PiperOrigin-RevId: 352606395
Change-Id: I0d097c97ec397ef4efe319277f10b7e0bc943e51
This commit is contained in:
A. Unique TensorFlower 2021-01-19 10:57:33 -08:00 committed by TensorFlower Gardener
parent 465f83be84
commit 9aeefce7f2
3 changed files with 144 additions and 0 deletions

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import warnings import warnings
from absl.testing import parameterized from absl.testing import parameterized
@ -556,6 +557,16 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaisesOpError(""): with self.assertRaisesOpError(""):
self.getDatasetOutput(dataset) self.getDatasetOutput(dataset)
def testNamedTupleStructure(self):
Foo = collections.namedtuple("Foo", ["a", "b"])
x = Foo(a=3, b="test")
dataset = dataset_ops.Dataset.from_tensors(x)
dataset = dataset_ops.Dataset.from_tensor_slices([dataset, dataset])
self.assertEqual(
str(dataset.element_spec),
"DatasetSpec(Foo(a=TensorSpec(shape=(), dtype=tf.int32, name=None), "
"b=TensorSpec(shape=(), dtype=tf.string, name=None)), TensorShape([]))")
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -33,6 +33,7 @@ from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -405,6 +406,13 @@ class TypeSpec(object):
return a.is_compatible_with(b) return a.is_compatible_with(b)
return a == b return a == b
@staticmethod
def __is_named_tuple(t):
"""Returns true if the given tuple t is a namedtuple."""
return (hasattr(t, "_fields") and
isinstance(t._fields, collections_abc.Sequence) and
all(isinstance(f, six.string_types) for f in t._fields))
@staticmethod @staticmethod
def __most_specific_compatible_type_serialization(a, b): def __most_specific_compatible_type_serialization(a, b):
"""Helper for most_specific_compatible_type. """Helper for most_specific_compatible_type.
@ -439,6 +447,13 @@ class TypeSpec(object):
if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
if len(a) != len(b): if len(a) != len(b):
raise ValueError("Types are not compatible: %r vs %r" % (a, b)) raise ValueError("Types are not compatible: %r vs %r" % (a, b))
if TypeSpec.__is_named_tuple(a):
if not hasattr(b, "_fields") or not isinstance(
b._fields, collections_abc.Sequence) or a._fields != b._fields:
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
return type(a)(*[
TypeSpec.__most_specific_compatible_type_serialization(x, y)
for (x, y) in zip(a, b)])
return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y) return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
for (x, y) in zip(a, b)) for (x, y) in zip(a, b))
if isinstance(a, collections.OrderedDict): if isinstance(a, collections.OrderedDict):

View File

@ -19,9 +19,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import six
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -31,6 +34,8 @@ from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec from tensorflow.python.framework import type_spec
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
class TwoTensors(object): class TwoTensors(object):
@ -134,6 +139,57 @@ type_spec.register_type_spec_from_value_converter(
TwoComposites, TwoCompositesSpec.from_value) TwoComposites, TwoCompositesSpec.from_value)
class NestOfTensors(object):
"""CompositeTensor containing a nest of tensors."""
def __init__(self, x):
self.nest = x
@type_spec.register("tf.NestOfTensorsSpec")
class NestOfTensorsSpec(type_spec.TypeSpec):
"""A TypeSpec for the NestOfTensors value type."""
def __init__(self, spec):
self.spec = spec
value_type = property(lambda self: NestOfTensors)
_component_specs = property(lambda self: self.spec)
def _to_components(self, value):
return nest.flatten(value)
def _from_components(self, components):
return nest.pack_sequence_as(self.spec, components)
def _serialize(self):
return self.spec
def __repr__(self):
if hasattr(self.spec, "_fields") and isinstance(
self.spec._fields, collections_abc.Sequence) and all(
isinstance(f, six.string_types) for f in self.spec._fields):
return "%s(%r)" % (type(self).__name__, self._serialize())
return super(type_spec.TypeSpec, self).__repr__()
@classmethod
def from_value(cls, value):
return cls(nest.map_structure(type_spec.type_spec_from_value, value.nest))
@classmethod
def _deserialize(cls, spec):
return cls(spec)
type_spec.register_type_spec_from_value_converter(
NestOfTensors, NestOfTensorsSpec.from_value)
_TestNamedTuple = collections.namedtuple("NamedTuple", ["a", "b"])
_TestNamedTupleSingleField = collections.namedtuple("SingleField", ["a"])
_TestNamedTupleDifferentField = collections.namedtuple("DifferentField",
["a", "c"])
class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase): class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@parameterized.named_parameters( @parameterized.named_parameters(
@ -271,6 +327,16 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
tensor_spec.TensorSpec([4], name="b")), tensor_spec.TensorSpec([4], name="b")),
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool, TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
tensor_spec.TensorSpec([4], name=None))), tensor_spec.TensorSpec([4], name=None))),
("NamedTuple",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32))),
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32))),
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32)))),
) )
def testMostSpecificCompatibleType(self, v1, v2, expected): def testMostSpecificCompatibleType(self, v1, v2, expected):
self.assertEqual(v1.most_specific_compatible_type(v2), expected) self.assertEqual(v1.most_specific_compatible_type(v2), expected)
@ -290,6 +356,58 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
v2.most_specific_compatible_type(v1) v2.most_specific_compatible_type(v1)
def testMostSpecificCompatibleTypeNamedTupleIsNotTuple(self):
named_tuple_spec_a = NestOfTensorsSpec.from_value(NestOfTensors(
_TestNamedTuple(a=1, b="aaa")))
named_tuple_spec_b = NestOfTensorsSpec.from_value(NestOfTensors(
_TestNamedTuple(a=2, b="bbb")))
named_tuple_spec_c = NestOfTensorsSpec.from_value(NestOfTensors(
_TestNamedTuple(a=3, b="ccc")))
normal_tuple_spec = NestOfTensorsSpec.from_value(NestOfTensors((2, "bbb")))
result_a_b = named_tuple_spec_a.most_specific_compatible_type(
named_tuple_spec_b)
result_b_a = named_tuple_spec_b.most_specific_compatible_type(
named_tuple_spec_a)
self.assertEqual(repr(result_a_b), repr(named_tuple_spec_c))
self.assertEqual(repr(result_b_a), repr(named_tuple_spec_c))
# Test that spec of named tuple is not equal to spec of normal tuple.
self.assertNotEqual(repr(result_a_b), repr(normal_tuple_spec))
@parameterized.named_parameters(
("IncompatibleDtype",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.bool))),
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32)))),
("DifferentTupleSize",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.bool))),
NestOfTensorsSpec(_TestNamedTupleSingleField(
a=tensor_spec.TensorSpec((), dtypes.int32)))),
("DifferentFieldName",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32))),
NestOfTensorsSpec(_TestNamedTupleDifferentField(
a=tensor_spec.TensorSpec((), dtypes.int32),
c=tensor_spec.TensorSpec((), dtypes.int32)))),
("NamedTupleAndTuple",
NestOfTensorsSpec(_TestNamedTuple(
a=tensor_spec.TensorSpec((), dtypes.int32),
b=tensor_spec.TensorSpec((), dtypes.int32))),
NestOfTensorsSpec((
tensor_spec.TensorSpec((), dtypes.int32),
tensor_spec.TensorSpec((), dtypes.int32)))),
)
def testMostSpecificCompatibleTypeForNamedTuplesException(self, v1, v2):
with self.assertRaises(ValueError):
v1.most_specific_compatible_type(v2)
with self.assertRaises(ValueError):
v2.most_specific_compatible_type(v1)
def toTensorList(self): def toTensorList(self):
value = TwoTensors([1, 2, 3], [1.0, 2.0], "red") value = TwoTensors([1, 2, 3], [1.0, 2.0], "red")
spec = TwoTensorsSpec.from_value(value) spec = TwoTensorsSpec.from_value(value)