Some test cleanups.
PiperOrigin-RevId: 225819680
This commit is contained in:
parent
6decf0842b
commit
2a067cb0b1
@ -75,8 +75,8 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase):
|
|||||||
def test_get_gan_model(self, mode):
|
def test_get_gan_model(self, mode):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
generator_inputs = {'x': array_ops.ones([3, 4])}
|
generator_inputs = {'x': array_ops.ones([3, 4])}
|
||||||
real_data = (array_ops.zeros([3, 4]) if
|
is_predict = mode == model_fn_lib.ModeKeys.PREDICT
|
||||||
mode != model_fn_lib.ModeKeys.PREDICT else None)
|
real_data = array_ops.zeros([3, 4]) if not is_predict else None
|
||||||
gan_model = estimator._get_gan_model(
|
gan_model = estimator._get_gan_model(
|
||||||
mode, generator_fn, discriminator_fn, real_data, generator_inputs,
|
mode, generator_fn, discriminator_fn, real_data, generator_inputs,
|
||||||
add_summaries=False)
|
add_summaries=False)
|
||||||
@ -139,6 +139,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
super(GetEstimatorSpecTest, cls).setUpClass()
|
||||||
cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
|
cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
|
||||||
cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
|
cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
|
||||||
|
|
||||||
@ -200,7 +201,6 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
|
self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
|
||||||
|
|
||||||
|
|
||||||
# TODO(joelshor): Add pandas test.
|
|
||||||
class GANEstimatorIntegrationTest(test.TestCase):
|
class GANEstimatorIntegrationTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -231,11 +231,11 @@ class GANEstimatorIntegrationTest(test.TestCase):
|
|||||||
get_eval_metric_ops_fn=get_metrics,
|
get_eval_metric_ops_fn=get_metrics,
|
||||||
model_dir=self._model_dir)
|
model_dir=self._model_dir)
|
||||||
|
|
||||||
# TRAIN
|
# Train.
|
||||||
num_steps = 10
|
num_steps = 10
|
||||||
est.train(train_input_fn, steps=num_steps)
|
est.train(train_input_fn, steps=num_steps)
|
||||||
|
|
||||||
# EVALUTE
|
# Evaluate.
|
||||||
scores = est.evaluate(eval_input_fn)
|
scores = est.evaluate(eval_input_fn)
|
||||||
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
|
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
|
||||||
self.assertIn('loss', six.iterkeys(scores))
|
self.assertIn('loss', six.iterkeys(scores))
|
||||||
@ -243,7 +243,7 @@ class GANEstimatorIntegrationTest(test.TestCase):
|
|||||||
scores['loss'])
|
scores['loss'])
|
||||||
self.assertIn('mse_custom_metric', six.iterkeys(scores))
|
self.assertIn('mse_custom_metric', six.iterkeys(scores))
|
||||||
|
|
||||||
# PREDICT
|
# Predict.
|
||||||
predictions = np.array([x for x in est.predict(predict_input_fn)])
|
predictions = np.array([x for x in est.predict(predict_input_fn)])
|
||||||
|
|
||||||
self.assertAllEqual(prediction_size, predictions.shape)
|
self.assertAllEqual(prediction_size, predictions.shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user