diff --git a/tensorflow/python/distribute/distribute_utils.py b/tensorflow/python/distribute/distribute_utils.py index 1ecf3c065c4..633cc2a3582 100644 --- a/tensorflow/python/distribute/distribute_utils.py +++ b/tensorflow/python/distribute/distribute_utils.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import abc + from tensorflow.python.distribute import tpu_values as tpu_values_lib from tensorflow.python.distribute import values as values_lib from tensorflow.python.eager import context @@ -68,10 +70,10 @@ def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False): else: return regrouped_tuple - if isinstance(v0, dict): + if isinstance(v0, abc.Mapping): v0keys = v0.keys() for v in values[1:]: - assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v)) + assert isinstance(v, abc.Mapping), ("v[0]: %r v[i]: %r" % (v0, v)) assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" % (set(v0keys), set(v.keys()))) # Use the actual type in case it is a class inherited from a dict. diff --git a/tensorflow/python/distribute/distribute_utils_test.py b/tensorflow/python/distribute/distribute_utils_test.py index 22ea6264d07..fd63c4949ef 100644 --- a/tensorflow/python/distribute/distribute_utils_test.py +++ b/tensorflow/python/distribute/distribute_utils_test.py @@ -85,6 +85,29 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase): self._is_per_replica(result["a"], ["a1", "a2"]) self._is_per_replica(result["b"], ["b1", "b2"]) + def testRegroupCollectionsMapping(self): + class CollectionsMappingBasedClass(collections.Mapping): + """Class inherited from collections.Mapping.""" + + def __init__(self, *args, **kwargs): + self._d = dict(*args, **kwargs) + + def __getitem__(self, key): + return self._d.__getitem__(key) + + def __iter__(self): + return iter(self._d) + + def __len__(self): + return len(self._d) + + result = distribute_utils.regroup( + (CollectionsMappingBasedClass(a="a1", b="b1"), + CollectionsMappingBasedClass(a="a2", b="b2"))) + self.assertIsInstance(result, CollectionsMappingBasedClass) + self._is_per_replica(result["a"], ["a1", "a2"]) + self._is_per_replica(result["b"], ["b1", "b2"]) + def testWrapClass(self): # Normally a mirrored value would be the same across devices, but # for a test it is convenient to be able to tell the values apart.