From 1602ac6a918aa130c14e8058c9b940792242d995 Mon Sep 17 00:00:00 2001 From: Illia Polosukhin Date: Tue, 21 Jun 2016 08:18:50 -0800 Subject: [PATCH] Fix typo in the run_local call of Experiment. Added test for run_local. Change: 125458571 --- .../contrib/learn/python/learn/experiment.py | 3 ++- .../python/learn/tests/experiment_test.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index e477eb217d5..3d01810bc14 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -121,9 +121,10 @@ class Experiment(object): Returns: The result of the `evaluate` call to the `Estimator`. """ + self._train_monitors = self._train_monitors or [] if self._local_eval_frequency: self._train_monitors += [monitors.ValidationMonitor( - input_fn=self._eval_input_fn, steps=self._eval_steps, + input_fn=self._eval_input_fn, eval_steps=self._eval_steps, metrics=self._eval_metrics, every_n_steps=self._local_eval_frequency )] self.train() diff --git a/tensorflow/contrib/learn/python/learn/tests/experiment_test.py b/tensorflow/contrib/learn/python/learn/tests/experiment_test.py index bc3bd4b7287..8dc54a32414 100644 --- a/tensorflow/contrib/learn/python/learn/tests/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/experiment_test.py @@ -27,6 +27,7 @@ class TestEstimator(object): def __init__(self): self.eval_count = 0 self.fit_count = 0 + self.monitors = [] def evaluate(self, **kwargs): tf.logging.info('evaluate called with args: %s' % kwargs) @@ -39,6 +40,8 @@ class TestEstimator(object): def fit(self, **kwargs): tf.logging.info('fit called with args: %s' % kwargs) self.fit_count += 1 + if 'monitors' in kwargs: + self.monitors = kwargs['monitors'] return [(key, kwargs[key]) for key in sorted(kwargs.keys())] @@ -115,6 +118,22 @@ class ExperimentTest(tf.test.TestCase): tf.logging.info('eval duration (expected %f): %f', expected, duration) self.assertTrue(duration > expected - 0.5 and duration < expected + 0.5) + def test_run_local(self): + est = TestEstimator() + ex = tf.contrib.learn.Experiment(est, + train_input_fn='train_input', + eval_input_fn='eval_input', + eval_metrics='eval_metrics', + train_steps=100, + eval_steps=100, + local_eval_frequency=10) + ex.local_run() + self.assertEquals(1, est.fit_count) + self.assertEquals(1, est.eval_count) + self.assertEquals(1, len(est.monitors)) + self.assertTrue(isinstance(est.monitors[0], + tf.contrib.learn.monitors.ValidationMonitor)) + if __name__ == '__main__': tf.test.main()