Fix tuple_losses bug caused by Python bug.
PiperOrigin-RevId: 168386341
This commit is contained in:
parent
7a8c63da36
commit
bc6b60f1bc
@ -73,7 +73,21 @@ def _args_to_gan_model(loss_fn):
|
||||
default_args_dict = dict(zip(args_with_defaults, defaults))
|
||||
|
||||
def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring
|
||||
gan_model_dict = gan_model._asdict()
|
||||
def _asdict(namedtuple):
|
||||
"""Returns a namedtuple as a dictionary.
|
||||
|
||||
This is required because `_asdict()` in Python 3.x.x is broken in classes
|
||||
that inherit from `collections.namedtuple`. See
|
||||
https://bugs.python.org/issue24931 for more details.
|
||||
|
||||
Args:
|
||||
namedtuple: An object that inherits from `collections.namedtuple`.
|
||||
|
||||
Returns:
|
||||
A dictionary version of the tuple.
|
||||
"""
|
||||
return {k: getattr(namedtuple, k) for k in namedtuple._fields}
|
||||
gan_model_dict = _asdict(gan_model)
|
||||
|
||||
# Make sure non-tuple required args are supplied.
|
||||
args_from_tuple = set(argspec.args).intersection(set(gan_model._fields))
|
||||
|
@ -79,6 +79,21 @@ class ArgsToGanModelTest(test.TestCase):
|
||||
# If `arg3` were not set properly, this value would be different.
|
||||
self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
|
||||
|
||||
def test_works_with_child_classes(self):
|
||||
"""`args_to_gan_model` should work with classes derived from namedtuple."""
|
||||
tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2'])
|
||||
|
||||
class InheritedType(tuple_type):
|
||||
pass
|
||||
def args_loss(arg1, arg2, arg3=3):
|
||||
return arg1 + 2 * arg2 + 3 * arg3
|
||||
|
||||
loss_fn = tfgan_losses._args_to_gan_model(args_loss)
|
||||
loss = loss_fn(InheritedType(arg1=-1, arg2=2), arg3=4)
|
||||
|
||||
# If `arg3` were not set properly, this value would be different.
|
||||
self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
|
||||
|
||||
|
||||
class ConsistentLossesTest(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user