add checking for input values in GANHead constructor

This commit is contained in:
apantykhin 2018-04-16 20:40:51 +04:00
parent d17de3d27f
commit 8dc3b3c453

View File

@ -25,6 +25,7 @@ from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head
from tensorflow.python.framework import ops
from tensorflow.python.training import optimizer
__all__ = [
'GANHead',
@ -90,9 +91,24 @@ class GANHead(head._Head): # pylint: disable=protected-access
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
"""
if not callable(generator_loss_fn):
raise TypeError('generator_loss_fn must be callable.')
if not callable(discriminator_loss_fn):
raise TypeError('discriminator_loss_fn must be callable.')
if not isinstance(generator_optimizer, optimizer.Optimizer):
raise TypeError('generator_optimizer must be Optimizer.')
if not isinstance(discriminator_optimizer, optimizer.Optimizer):
raise TypeError('discriminator_optimizer must be Optimizer.')
if not use_loss_summaries in [True, False, None]:
raise ValueError('use_loss_summaries must be True, False or None.')
if get_hooks_fn is not None and not callable(get_hooks_fn):
raise TypeError('get_hooks_fn must be callable.')
if name is not None and not isinstance(name, str):
raise TypeError('name must be string.')
if get_hooks_fn is None:
get_hooks_fn = tfgan_train.get_sequential_train_hooks()
# TODO(joelshor): Validate inputs.
if use_loss_summaries in [True, False]:
generator_loss_fn = functools.partial(