Expand distribute_utils.regroup to work with collections.abc.Mapping-derived containers.

Motivation: This enables user-defined dict-like types inheriting from collections.abc.Mapping to work as return values of functions used with DistributionStrategy.run. Without this change, the entire collection is wrapped in a PerReplica, which breaks assumptions of downstream code.
PiperOrigin-RevId: 352064455
Change-Id: Iefda92654fa73d12ab213abe7ea13e0007201f95
This commit is contained in:
RJ Skerry-Ryan 2021-01-15 12:40:10 -08:00 committed by TensorFlower Gardener
parent 3229483cb1
commit 102e1f9855
2 changed files with 27 additions and 2 deletions

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import abc
from tensorflow.python.distribute import tpu_values as tpu_values_lib
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
@ -68,10 +70,10 @@ def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
else:
return regrouped_tuple
if isinstance(v0, dict):
if isinstance(v0, abc.Mapping):
v0keys = v0.keys()
for v in values[1:]:
assert isinstance(v, dict), ("v[0]: %r v[i]: %r" % (v0, v))
assert isinstance(v, abc.Mapping), ("v[0]: %r v[i]: %r" % (v0, v))
assert set(v.keys()) == set(v0keys), ("v[0].keys: %s v[i].keys: %s" %
(set(v0keys), set(v.keys())))
# Use the actual type in case it is a class inherited from a dict or another Mapping.

View File

@ -85,6 +85,29 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
self._is_per_replica(result["a"], ["a1", "a2"])
self._is_per_replica(result["b"], ["b1", "b2"])
def testRegroupCollectionsMapping(self):
  """Checks that regroup keeps a user-defined Mapping subclass as the container.

  Two instances of a minimal `collections.abc.Mapping` subclass with the same
  keys are regrouped. The result must be an instance of that same subclass,
  and each key's entry is checked with `self._is_per_replica` against the
  pair of values supplied for that key.
  """
  # Imported locally because the `collections.Mapping` alias was removed in
  # Python 3.10; the ABC is only available from `collections.abc`.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  class CollectionsMappingBasedClass(collections_abc.Mapping):
    """Class inherited from collections.abc.Mapping."""

    def __init__(self, *args, **kwargs):
      # Accepts the same constructor arguments as `dict`.
      self._d = dict(*args, **kwargs)

    def __getitem__(self, key):
      return self._d[key]

    def __iter__(self):
      return iter(self._d)

    def __len__(self):
      return len(self._d)

  result = distribute_utils.regroup(
      (CollectionsMappingBasedClass(a="a1", b="b1"),
       CollectionsMappingBasedClass(a="a2", b="b2")))

  # The container type must be preserved rather than replaced by a wrapper.
  self.assertIsInstance(result, CollectionsMappingBasedClass)
  self._is_per_replica(result["a"], ["a1", "a2"])
  self._is_per_replica(result["b"], ["b1", "b2"])
def testWrapClass(self):
# Normally a mirrored value would be the same across devices, but
# for a test it is convenient to be able to tell the values apart.