When a pair of namedtuple objects are compared to determine their most specific compatible type, the namedtuple objects can be treated as normal tuples. This can result in failing to create a dataset consisting of nested namedtuples. This CL fixes this problem.
PiperOrigin-RevId: 352606395 Change-Id: I0d097c97ec397ef4efe319277f10b7e0bc943e51
This commit is contained in:
parent
465f83be84
commit
9aeefce7f2
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import warnings
|
||||
|
||||
from absl.testing import parameterized
|
||||
@ -556,6 +557,16 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaisesOpError(""):
|
||||
self.getDatasetOutput(dataset)
|
||||
|
||||
def testNamedTupleStructure(self):
|
||||
Foo = collections.namedtuple("Foo", ["a", "b"])
|
||||
x = Foo(a=3, b="test")
|
||||
dataset = dataset_ops.Dataset.from_tensors(x)
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices([dataset, dataset])
|
||||
self.assertEqual(
|
||||
str(dataset.element_spec),
|
||||
"DatasetSpec(Foo(a=TensorSpec(shape=(), dtype=tf.int32, name=None), "
|
||||
"b=TensorSpec(shape=(), dtype=tf.string, name=None)), TensorShape([]))")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.util import _pywrap_utils
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -405,6 +406,13 @@ class TypeSpec(object):
|
||||
return a.is_compatible_with(b)
|
||||
return a == b
|
||||
|
||||
@staticmethod
|
||||
def __is_named_tuple(t):
|
||||
"""Returns true if the given tuple t is a namedtuple."""
|
||||
return (hasattr(t, "_fields") and
|
||||
isinstance(t._fields, collections_abc.Sequence) and
|
||||
all(isinstance(f, six.string_types) for f in t._fields))
|
||||
|
||||
@staticmethod
|
||||
def __most_specific_compatible_type_serialization(a, b):
|
||||
"""Helper for most_specific_compatible_type.
|
||||
@ -439,6 +447,13 @@ class TypeSpec(object):
|
||||
if isinstance(a, (list, tuple)):
|
||||
if len(a) != len(b):
|
||||
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
|
||||
if TypeSpec.__is_named_tuple(a):
|
||||
if not hasattr(b, "_fields") or not isinstance(
|
||||
b._fields, collections_abc.Sequence) or a._fields != b._fields:
|
||||
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
|
||||
return type(a)(*[
|
||||
TypeSpec.__most_specific_compatible_type_serialization(x, y)
|
||||
for (x, y) in zip(a, b)])
|
||||
return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
|
||||
for (x, y) in zip(a, b))
|
||||
if isinstance(a, collections.OrderedDict):
|
||||
|
@ -19,9 +19,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -31,6 +34,8 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
|
||||
|
||||
class TwoTensors(object):
|
||||
@ -134,6 +139,57 @@ type_spec.register_type_spec_from_value_converter(
|
||||
TwoComposites, TwoCompositesSpec.from_value)
|
||||
|
||||
|
||||
class NestOfTensors(object):
|
||||
"""CompositeTensor containing a nest of tensors."""
|
||||
|
||||
def __init__(self, x):
|
||||
self.nest = x
|
||||
|
||||
|
||||
@type_spec.register("tf.NestOfTensorsSpec")
|
||||
class NestOfTensorsSpec(type_spec.TypeSpec):
|
||||
"""A TypeSpec for the NestOfTensors value type."""
|
||||
|
||||
def __init__(self, spec):
|
||||
self.spec = spec
|
||||
|
||||
value_type = property(lambda self: NestOfTensors)
|
||||
_component_specs = property(lambda self: self.spec)
|
||||
|
||||
def _to_components(self, value):
|
||||
return nest.flatten(value)
|
||||
|
||||
def _from_components(self, components):
|
||||
return nest.pack_sequence_as(self.spec, components)
|
||||
|
||||
def _serialize(self):
|
||||
return self.spec
|
||||
|
||||
def __repr__(self):
|
||||
if hasattr(self.spec, "_fields") and isinstance(
|
||||
self.spec._fields, collections_abc.Sequence) and all(
|
||||
isinstance(f, six.string_types) for f in self.spec._fields):
|
||||
return "%s(%r)" % (type(self).__name__, self._serialize())
|
||||
return super(type_spec.TypeSpec, self).__repr__()
|
||||
|
||||
@classmethod
|
||||
def from_value(cls, value):
|
||||
return cls(nest.map_structure(type_spec.type_spec_from_value, value.nest))
|
||||
|
||||
@classmethod
|
||||
def _deserialize(cls, spec):
|
||||
return cls(spec)
|
||||
|
||||
|
||||
type_spec.register_type_spec_from_value_converter(
|
||||
NestOfTensors, NestOfTensorsSpec.from_value)
|
||||
|
||||
_TestNamedTuple = collections.namedtuple("NamedTuple", ["a", "b"])
|
||||
_TestNamedTupleSingleField = collections.namedtuple("SingleField", ["a"])
|
||||
_TestNamedTupleDifferentField = collections.namedtuple("DifferentField",
|
||||
["a", "c"])
|
||||
|
||||
|
||||
class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -271,6 +327,16 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
tensor_spec.TensorSpec([4], name="b")),
|
||||
TwoTensorsSpec([5, 3], dtypes.int32, [3], dtypes.bool,
|
||||
tensor_spec.TensorSpec([4], name=None))),
|
||||
("NamedTuple",
|
||||
NestOfTensorsSpec(_TestNamedTuple(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
b=tensor_spec.TensorSpec((), dtypes.int32))),
|
||||
NestOfTensorsSpec(_TestNamedTuple(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
b=tensor_spec.TensorSpec((), dtypes.int32))),
|
||||
NestOfTensorsSpec(_TestNamedTuple(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
b=tensor_spec.TensorSpec((), dtypes.int32)))),
|
||||
)
|
||||
def testMostSpecificCompatibleType(self, v1, v2, expected):
|
||||
self.assertEqual(v1.most_specific_compatible_type(v2), expected)
|
||||
@ -290,6 +356,58 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
v2.most_specific_compatible_type(v1)
|
||||
|
||||
def testMostSpecificCompatibleTypeNamedTupleIsNotTuple(self):
|
||||
named_tuple_spec_a = NestOfTensorsSpec.from_value(NestOfTensors(
|
||||
_TestNamedTuple(a=1, b="aaa")))
|
||||
named_tuple_spec_b = NestOfTensorsSpec.from_value(NestOfTensors(
|
||||
_TestNamedTuple(a=2, b="bbb")))
|
||||
named_tuple_spec_c = NestOfTensorsSpec.from_value(NestOfTensors(
|
||||
_TestNamedTuple(a=3, b="ccc")))
|
||||
normal_tuple_spec = NestOfTensorsSpec.from_value(NestOfTensors((2, "bbb")))
|
||||
result_a_b = named_tuple_spec_a.most_specific_compatible_type(
|
||||
named_tuple_spec_b)
|
||||
result_b_a = named_tuple_spec_b.most_specific_compatible_type(
|
||||
named_tuple_spec_a)
|
||||
self.assertEqual(repr(result_a_b), repr(named_tuple_spec_c))
|
||||
self.assertEqual(repr(result_b_a), repr(named_tuple_spec_c))
|
||||
# Test that spec of named tuple is not equal to spec of normal tuple.
|
||||
self.assertNotEqual(repr(result_a_b), repr(normal_tuple_spec))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("IncompatibleDtype",
|
||||
NestOfTensorsSpec(_TestNamedTuple(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
b=tensor_spec.TensorSpec((), dtypes.bool))),
|
||||
NestOfTensorsSpec(_TestNamedTuple(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
b=tensor_spec.TensorSpec((), dtypes.int32)))),
|
||||
("DifferentTupleSize",
|
||||
NestOfTensorsSpec(_TestNamedTuple(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
b=tensor_spec.TensorSpec((), dtypes.bool))),
|
||||
NestOfTensorsSpec(_TestNamedTupleSingleField(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32)))),
|
||||
("DifferentFieldName",
|
||||
NestOfTensorsSpec(_TestNamedTuple(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
b=tensor_spec.TensorSpec((), dtypes.int32))),
|
||||
NestOfTensorsSpec(_TestNamedTupleDifferentField(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
c=tensor_spec.TensorSpec((), dtypes.int32)))),
|
||||
("NamedTupleAndTuple",
|
||||
NestOfTensorsSpec(_TestNamedTuple(
|
||||
a=tensor_spec.TensorSpec((), dtypes.int32),
|
||||
b=tensor_spec.TensorSpec((), dtypes.int32))),
|
||||
NestOfTensorsSpec((
|
||||
tensor_spec.TensorSpec((), dtypes.int32),
|
||||
tensor_spec.TensorSpec((), dtypes.int32)))),
|
||||
)
|
||||
def testMostSpecificCompatibleTypeForNamedTuplesException(self, v1, v2):
|
||||
with self.assertRaises(ValueError):
|
||||
v1.most_specific_compatible_type(v2)
|
||||
with self.assertRaises(ValueError):
|
||||
v2.most_specific_compatible_type(v1)
|
||||
|
||||
def toTensorList(self):
|
||||
value = TwoTensors([1, 2, 3], [1.0, 2.0], "red")
|
||||
spec = TwoTensorsSpec.from_value(value)
|
||||
|
Loading…
Reference in New Issue
Block a user