Get namedtuple _make method from instance instead of class.

It is possible that v0 is a _TupleWrapper object wrapping a namedtuple instead of a namedtuple itself. The class _TupleWrapper does not have the class method _make, but it should have the instance method _make which calls the namedtuple class method.

Instances of a namedtuple also have the instance method _make which calls the relevant class method. This will allow the code to work for both cases.

PiperOrigin-RevId: 325815594
Change-Id: I209f8ec5a8617f72183e4c12937b4429321e7b4f
This commit is contained in:
A. Unique TensorFlower 2020-08-10 09:05:54 -07:00 committed by TensorFlower Gardener
parent 9173656622
commit 2c6f7e24dd
3 changed files with 13 additions and 2 deletions

View File

@ -1197,6 +1197,7 @@ distribute_py_test(
"//tensorflow/python/eager:test",
"//tensorflow/python/saved_model/model_utils:mode_keys",
"@absl_py//absl/testing:parameterized",
"@wrapt",
],
)

View File

@ -63,8 +63,8 @@ def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
if hasattr(v0, "_fields"):
# This tuple is in fact a namedtuple! Create a new namedtuple instance
# and initialize it with the regrouped values:
assert hasattr(type(v0), "_make")
return type(v0)._make(regrouped_tuple)
assert hasattr(v0, "_make")
return v0._make(regrouped_tuple)
else:
return regrouped_tuple

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import collections
from absl.testing import parameterized
import wrapt
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_utils
@ -211,6 +212,15 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
distribute_utils.select_replica(
device_id, merged_estimator_spec))
def testWrappedNamedTuple(self):
Point = collections.namedtuple("Point", ["x", "y"])
point1 = Point(x=0, y=2)
point2 = Point(x=1, y=3)
wrapped1 = wrapt.ObjectProxy(point1)
wrapped2 = wrapt.ObjectProxy(point2)
result = distribute_utils.regroup([wrapped1, wrapped2])
self.assertEqual(result.x.values, (0, 1))
self.assertEqual(result.y.values, (2, 3))
if __name__ == "__main__":
test.main()