Merge pull request #18565 from alexpantyukhin/ganhead_constructor_validate
add checking for input values in GANHead constructor
This commit is contained in:
commit
85cb8d48b7
@ -103,9 +103,20 @@ class GANHead(head._Head): # pylint: disable=protected-access
|
||||
name: name of the head. If provided, summary and metrics keys will be
|
||||
suffixed by `"/" + name`.
|
||||
"""
|
||||
|
||||
if not callable(generator_loss_fn):
|
||||
raise TypeError('generator_loss_fn must be callable.')
|
||||
if not callable(discriminator_loss_fn):
|
||||
raise TypeError('discriminator_loss_fn must be callable.')
|
||||
if not use_loss_summaries in [True, False, None]:
|
||||
raise ValueError('use_loss_summaries must be True, False or None.')
|
||||
if get_hooks_fn is not None and not callable(get_hooks_fn):
|
||||
raise TypeError('get_hooks_fn must be callable.')
|
||||
if name is not None and not isinstance(name, str):
|
||||
raise TypeError('name must be string.')
|
||||
|
||||
if get_hooks_fn is None:
|
||||
get_hooks_fn = tfgan_train.get_sequential_train_hooks()
|
||||
# TODO(joelshor): Validate inputs.
|
||||
|
||||
if use_loss_summaries in [True, False]:
|
||||
generator_loss_fn = functools.partial(
|
||||
|
Loading…
Reference in New Issue
Block a user