Some test cleanups.
PiperOrigin-RevId: 225819680
This commit is contained in:
parent
6decf0842b
commit
2a067cb0b1
@ -75,8 +75,8 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase):
|
||||
def test_get_gan_model(self, mode):
|
||||
with ops.Graph().as_default():
|
||||
generator_inputs = {'x': array_ops.ones([3, 4])}
|
||||
real_data = (array_ops.zeros([3, 4]) if
|
||||
mode != model_fn_lib.ModeKeys.PREDICT else None)
|
||||
is_predict = mode == model_fn_lib.ModeKeys.PREDICT
|
||||
real_data = array_ops.zeros([3, 4]) if not is_predict else None
|
||||
gan_model = estimator._get_gan_model(
|
||||
mode, generator_fn, discriminator_fn, real_data, generator_inputs,
|
||||
add_summaries=False)
|
||||
@ -139,6 +139,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(GetEstimatorSpecTest, cls).setUpClass()
|
||||
cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
|
||||
cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
|
||||
|
||||
@ -200,7 +201,6 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
|
||||
|
||||
|
||||
# TODO(joelshor): Add pandas test.
|
||||
class GANEstimatorIntegrationTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -231,11 +231,11 @@ class GANEstimatorIntegrationTest(test.TestCase):
|
||||
get_eval_metric_ops_fn=get_metrics,
|
||||
model_dir=self._model_dir)
|
||||
|
||||
# TRAIN
|
||||
# Train.
|
||||
num_steps = 10
|
||||
est.train(train_input_fn, steps=num_steps)
|
||||
|
||||
# EVALUTE
|
||||
# Evaluate.
|
||||
scores = est.evaluate(eval_input_fn)
|
||||
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
|
||||
self.assertIn('loss', six.iterkeys(scores))
|
||||
@ -243,7 +243,7 @@ class GANEstimatorIntegrationTest(test.TestCase):
|
||||
scores['loss'])
|
||||
self.assertIn('mse_custom_metric', six.iterkeys(scores))
|
||||
|
||||
# PREDICT
|
||||
# Predict.
|
||||
predictions = np.array([x for x in est.predict(predict_input_fn)])
|
||||
|
||||
self.assertAllEqual(prediction_size, predictions.shape)
|
||||
|
Loading…
Reference in New Issue
Block a user