[TypeSpec] Support OrderedDict in TypeSpec.most_specific_compatible_type().

Previously, any instance of `OrderedDict` in a nest would silently be converted to a `dict` (because every `OrderedDict` is also a `dict`), but that would yield a `TypeSpec` that is incompatible with the input spec.

PiperOrigin-RevId: 305493326
Change-Id: Id0235901edcda1fdd8e3ca7db15a7449e1d0b45d
This commit is contained in:
Derek Murray 2020-04-08 09:30:39 -07:00 committed by TensorFlower Gardener
parent 81256beba4
commit f394a76871
2 changed files with 23 additions and 0 deletions

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from absl.testing import parameterized
import numpy as np
@ -63,6 +65,17 @@ class FromTensorSlicesTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = ds.flat_map(lambda x: x)
self.assertDatasetProduces(ds, expected_output=list(range(10)) * 10)
@combinations.generate(test_base.default_test_combinations())
def testFromTensorSlicesDatasetOfOrderedDict(self):
dss = [dataset_ops.Dataset.range(10).map(
lambda x: collections.OrderedDict([("x", x)])) for _ in range(10)]
ds = dataset_ops.Dataset.from_tensor_slices(dss)
ds = ds.flat_map(lambda x: x)
self.assertDatasetProduces(
ds,
expected_output=[collections.OrderedDict([("x", x)])
for x in list(range(10)) * 10])
@combinations.generate(test_base.default_test_combinations())
def testFromTensorSlicesDatasetInFunction(self):
dss = [dataset_ops.Dataset.range(10) for _ in range(10)]

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import abc
import collections
import numpy as np
import six
@ -398,6 +399,15 @@ class TypeSpec(object):
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
for (x, y) in zip(a, b))
if isinstance(a, collections.OrderedDict):
a_keys, b_keys = a.keys(), b.keys()
if len(a) != len(b) or a_keys != b_keys:
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
return collections.OrderedDict([
(k,
TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k]))
for k in a_keys
])
if isinstance(a, dict):
a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
if len(a) != len(b) or a_keys != b_keys: