[TypeSpec] Support OrderedDict
in TypeSpec.most_specific_compatible_type()
.
Previously, any instance of `OrderedDict` in a nest would silently be converted to a `dict` (because every `OrderedDict` is also a `dict`), but that would yield a `TypeSpec` that is incompatible with the input spec. PiperOrigin-RevId: 305493326 Change-Id: Id0235901edcda1fdd8e3ca7db15a7449e1d0b45d
This commit is contained in:
parent
81256beba4
commit
f394a76871
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
@ -63,6 +65,17 @@ class FromTensorSlicesTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = ds.flat_map(lambda x: x)
|
||||
self.assertDatasetProduces(ds, expected_output=list(range(10)) * 10)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFromTensorSlicesDatasetOfOrderedDict(self):
|
||||
dss = [dataset_ops.Dataset.range(10).map(
|
||||
lambda x: collections.OrderedDict([("x", x)])) for _ in range(10)]
|
||||
ds = dataset_ops.Dataset.from_tensor_slices(dss)
|
||||
ds = ds.flat_map(lambda x: x)
|
||||
self.assertDatasetProduces(
|
||||
ds,
|
||||
expected_output=[collections.OrderedDict([("x", x)])
|
||||
for x in list(range(10)) * 10])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testFromTensorSlicesDatasetInFunction(self):
|
||||
dss = [dataset_ops.Dataset.range(10) for _ in range(10)]
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
@ -398,6 +399,15 @@ class TypeSpec(object):
|
||||
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
|
||||
return tuple(TypeSpec.__most_specific_compatible_type_serialization(x, y)
|
||||
for (x, y) in zip(a, b))
|
||||
if isinstance(a, collections.OrderedDict):
|
||||
a_keys, b_keys = a.keys(), b.keys()
|
||||
if len(a) != len(b) or a_keys != b_keys:
|
||||
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
|
||||
return collections.OrderedDict([
|
||||
(k,
|
||||
TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k]))
|
||||
for k in a_keys
|
||||
])
|
||||
if isinstance(a, dict):
|
||||
a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
|
||||
if len(a) != len(b) or a_keys != b_keys:
|
||||
|
Loading…
Reference in New Issue
Block a user