From 8dc3b3c453180211f4be5302f957664004e1ec04 Mon Sep 17 00:00:00 2001 From: apantykhin Date: Mon, 16 Apr 2018 20:40:51 +0400 Subject: [PATCH 1/2] add checking for input values in GANHead constructor --- .../gan/python/estimator/python/head_impl.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index a21358c50bb..652ffee30ac 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -25,6 +25,7 @@ from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head from tensorflow.python.framework import ops +from tensorflow.python.training import optimizer __all__ = [ 'GANHead', @@ -90,9 +91,24 @@ class GANHead(head._Head): # pylint: disable=protected-access name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + + if not callable(generator_loss_fn): + raise TypeError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise TypeError('discriminator_loss_fn must be callable.') + if not isinstance(generator_optimizer, optimizer.Optimizer): + raise TypeError('generator_optimizer must be Optimizer.') + if not isinstance(discriminator_optimizer, optimizer.Optimizer): + raise TypeError('discriminator_optimizer must be Optimizer.') + if not use_loss_summaries in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') + if name is not None and not isinstance(name, str): + raise TypeError('name must be string.') + if get_hooks_fn is None: get_hooks_fn = tfgan_train.get_sequential_train_hooks() - # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: generator_loss_fn = functools.partial( From 558cbd9fc89055f532a9558a276a9e6b438371cf Mon Sep 17 00:00:00 2001 From: apantykhin Date: Wed, 6 Jun 2018 19:55:46 +0400 Subject: [PATCH 2/2] remove optimizer checking. --- tensorflow/contrib/gan/python/estimator/python/head_impl.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 652ffee30ac..4750f94d9a6 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -25,7 +25,6 @@ from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head from tensorflow.python.framework import ops -from tensorflow.python.training import optimizer __all__ = [ 'GANHead', @@ -96,10 +95,6 @@ class GANHead(head._Head): # pylint: disable=protected-access raise TypeError('generator_loss_fn must be callable.') if not callable(discriminator_loss_fn): raise TypeError('discriminator_loss_fn must be callable.') - if not isinstance(generator_optimizer, optimizer.Optimizer): - raise TypeError('generator_optimizer must be Optimizer.') - if not isinstance(discriminator_optimizer, optimizer.Optimizer): - raise TypeError('discriminator_optimizer must be Optimizer.') if not use_loss_summaries in [True, False, None]: raise ValueError('use_loss_summaries must be True, False or None.') if get_hooks_fn is not None and not callable(get_hooks_fn):