Expand distribute_utils.regroup to work with collections.abc.Mapping-derived containers.
Motivation: This enables user-defined dict-like types inheriting from collections.abc.Mapping to work as return values of functions used with DistributionStrategy.run. Without this change, the entire collection is wrapped in a PerReplica which breaks assumptions of downstream code. PiperOrigin-RevId: 352064455 Change-Id: Iefda92654fa73d12ab213abe7ea13e0007201f95
This commit is contained in:
parent
3229483cb1
commit
102e1f9855
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import abc
|
||||
|
||||
from tensorflow.python.distribute import tpu_values as tpu_values_lib
|
||||
from tensorflow.python.distribute import values as values_lib
|
||||
from tensorflow.python.eager import context
|
||||
@ -68,10 +70,10 @@ def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
|
||||
else:
|
||||
return regrouped_tuple
|
||||
|
||||
if isinstance(v0, dict):
|
||||
if isinstance(v0, abc.Mapping):
|
||||
v0keys = v0.keys()
|
||||
for v in values[1:]:
|
||||
assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v))
|
||||
assert isinstance(v, abc.Mapping), ("v[0]: %r v[i]: %r" % (v0, v))
|
||||
assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
|
||||
(set(v0keys), set(v.keys())))
|
||||
# Use the actual type in case it is a class inherited from a dict.
|
||||
|
@ -85,6 +85,29 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
self._is_per_replica(result["a"], ["a1", "a2"])
|
||||
self._is_per_replica(result["b"], ["b1", "b2"])
|
||||
|
||||
def testRegroupCollectionsMapping(self):
|
||||
class CollectionsMappingBasedClass(collections.Mapping):
|
||||
"""Class inherited from collections.Mapping."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._d = dict(*args, **kwargs)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._d.__getitem__(key)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._d)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._d)
|
||||
|
||||
result = distribute_utils.regroup(
|
||||
(CollectionsMappingBasedClass(a="a1", b="b1"),
|
||||
CollectionsMappingBasedClass(a="a2", b="b2")))
|
||||
self.assertIsInstance(result, CollectionsMappingBasedClass)
|
||||
self._is_per_replica(result["a"], ["a1", "a2"])
|
||||
self._is_per_replica(result["b"], ["b1", "b2"])
|
||||
|
||||
def testWrapClass(self):
|
||||
# Normally a mirrored value would be the same across devices, but
|
||||
# for a test it is convenient to be able to tell the values apart.
|
||||
|
Loading…
Reference in New Issue
Block a user