Fix tuple_losses bug caused by Python bug.

PiperOrigin-RevId: 168386341
This commit is contained in:
A. Unique TensorFlower 2017-09-12 09:28:51 -07:00 committed by TensorFlower Gardener
parent 7a8c63da36
commit bc6b60f1bc
2 changed files with 30 additions and 1 deletions

View File

@ -73,7 +73,21 @@ def _args_to_gan_model(loss_fn):
default_args_dict = dict(zip(args_with_defaults, defaults)) default_args_dict = dict(zip(args_with_defaults, defaults))
def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring
gan_model_dict = gan_model._asdict() def _asdict(namedtuple):
"""Returns a namedtuple as a dictionary.
This is required because `_asdict()` in Python 3.x.x is broken in classes
that inherit from `collections.namedtuple`. See
https://bugs.python.org/issue24931 for more details.
Args:
namedtuple: An object that inherits from `collections.namedtuple`.
Returns:
A dictionary version of the tuple.
"""
return {k: getattr(namedtuple, k) for k in namedtuple._fields}
gan_model_dict = _asdict(gan_model)
# Make sure non-tuple required args are supplied. # Make sure non-tuple required args are supplied.
args_from_tuple = set(argspec.args).intersection(set(gan_model._fields)) args_from_tuple = set(argspec.args).intersection(set(gan_model._fields))

View File

@ -79,6 +79,21 @@ class ArgsToGanModelTest(test.TestCase):
# If `arg3` were not set properly, this value would be different. # If `arg3` were not set properly, this value would be different.
self.assertEqual(-1 + 2 * 2 + 3 * 4, loss) self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
def test_works_with_child_classes(self):
"""`args_to_gan_model` should work with classes derived from namedtuple."""
tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2'])
class InheritedType(tuple_type):
pass
def args_loss(arg1, arg2, arg3=3):
return arg1 + 2 * arg2 + 3 * arg3
loss_fn = tfgan_losses._args_to_gan_model(args_loss)
loss = loss_fn(InheritedType(arg1=-1, arg2=2), arg3=4)
# If `arg3` were not set properly, this value would be different.
self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
class ConsistentLossesTest(test.TestCase): class ConsistentLossesTest(test.TestCase):