Get namedtuple _make method from instance instead of class.
It is possible that v0 is a _TupleWrapper object wrapping a namedtuple instead of a namedtuple itself. The class _TupleWrapper does not have the class method _make, but it should have the instance method _make which calls the namedtuple class method. Instances of a namedtuple also have the instance method _make which calls the relevant class method. This will allow the code to work for both cases. PiperOrigin-RevId: 325815594 Change-Id: I209f8ec5a8617f72183e4c12937b4429321e7b4f
This commit is contained in:
parent
9173656622
commit
2c6f7e24dd
tensorflow/python/distribute
@ -1197,6 +1197,7 @@ distribute_py_test(
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/saved_model/model_utils:mode_keys",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"@wrapt",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -63,8 +63,8 @@ def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
|
||||
if hasattr(v0, "_fields"):
|
||||
# This tuple is in fact a namedtuple! Create a new namedtuple instance
|
||||
# and initialize it with the regrouped values:
|
||||
assert hasattr(type(v0), "_make")
|
||||
return type(v0)._make(regrouped_tuple)
|
||||
assert hasattr(v0, "_make")
|
||||
return v0._make(regrouped_tuple)
|
||||
else:
|
||||
return regrouped_tuple
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import collections
|
||||
|
||||
from absl.testing import parameterized
|
||||
import wrapt
|
||||
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_utils
|
||||
@ -211,6 +212,15 @@ class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
distribute_utils.select_replica(
|
||||
device_id, merged_estimator_spec))
|
||||
|
||||
def testWrappedNamedTuple(self):
|
||||
Point = collections.namedtuple("Point", ["x", "y"])
|
||||
point1 = Point(x=0, y=2)
|
||||
point2 = Point(x=1, y=3)
|
||||
wrapped1 = wrapt.ObjectProxy(point1)
|
||||
wrapped2 = wrapt.ObjectProxy(point2)
|
||||
result = distribute_utils.regroup([wrapped1, wrapped2])
|
||||
self.assertEqual(result.x.values, (0, 1))
|
||||
self.assertEqual(result.y.values, (2, 3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user