Some test cleanups.

PiperOrigin-RevId: 225819680
This commit is contained in:
A. Unique TensorFlower 2018-12-17 07:00:45 -08:00 committed by TensorFlower Gardener
parent 6decf0842b
commit 2a067cb0b1

View File

@ -75,8 +75,8 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase):
def test_get_gan_model(self, mode):
with ops.Graph().as_default():
generator_inputs = {'x': array_ops.ones([3, 4])}
real_data = (array_ops.zeros([3, 4]) if
mode != model_fn_lib.ModeKeys.PREDICT else None)
is_predict = mode == model_fn_lib.ModeKeys.PREDICT
real_data = array_ops.zeros([3, 4]) if not is_predict else None
gan_model = estimator._get_gan_model(
mode, generator_fn, discriminator_fn, real_data, generator_inputs,
add_summaries=False)
@ -139,6 +139,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
@classmethod
def setUpClass(cls):
super(GetEstimatorSpecTest, cls).setUpClass()
cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
@ -200,7 +201,6 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
# TODO(joelshor): Add pandas test.
class GANEstimatorIntegrationTest(test.TestCase):
def setUp(self):
@ -231,11 +231,11 @@ class GANEstimatorIntegrationTest(test.TestCase):
get_eval_metric_ops_fn=get_metrics,
model_dir=self._model_dir)
# TRAIN
# Train.
num_steps = 10
est.train(train_input_fn, steps=num_steps)
# EVALUTE
# Evaluate.
scores = est.evaluate(eval_input_fn)
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
self.assertIn('loss', six.iterkeys(scores))
@ -243,7 +243,7 @@ class GANEstimatorIntegrationTest(test.TestCase):
scores['loss'])
self.assertIn('mse_custom_metric', six.iterkeys(scores))
# PREDICT
# Predict.
predictions = np.array([x for x in est.predict(predict_input_fn)])
self.assertAllEqual(prediction_size, predictions.shape)