diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index 8805633deeb..fca8063891f 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -73,7 +73,21 @@ def _args_to_gan_model(loss_fn): default_args_dict = dict(zip(args_with_defaults, defaults)) def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring - gan_model_dict = gan_model._asdict() + def _asdict(namedtuple): + """Returns a namedtuple as a dictionary. + + This is required because `_asdict()` in Python 3.x.x is broken in classes + that inherit from `collections.namedtuple`. See + https://bugs.python.org/issue24931 for more details. + + Args: + namedtuple: An object that inherits from `collections.namedtuple`. + + Returns: + A dictionary version of the tuple. + """ + return {k: getattr(namedtuple, k) for k in namedtuple._fields} + gan_model_dict = _asdict(gan_model) # Make sure non-tuple required args are supplied. args_from_tuple = set(argspec.args).intersection(set(gan_model._fields)) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index f65b20d0b57..215b15ef691 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -79,6 +79,21 @@ class ArgsToGanModelTest(test.TestCase): # If `arg3` were not set properly, this value would be different. self.assertEqual(-1 + 2 * 2 + 3 * 4, loss) + def test_works_with_child_classes(self): + """`args_to_gan_model` should work with classes derived from namedtuple.""" + tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2']) + + class InheritedType(tuple_type): + pass + def args_loss(arg1, arg2, arg3=3): + return arg1 + 2 * arg2 + 3 * arg3 + + loss_fn = tfgan_losses._args_to_gan_model(args_loss) + loss = loss_fn(InheritedType(arg1=-1, arg2=2), arg3=4) + + # If `arg3` were not set properly, this value would be different. + self.assertEqual(-1 + 2 * 2 + 3 * 4, loss) + class ConsistentLossesTest(test.TestCase):