Fix typo in the run_local call of Experiment. Added test for run_local.
Change: 125458571
This commit is contained in:
parent
002e1854f6
commit
1602ac6a91
@ -121,9 +121,10 @@ class Experiment(object):
|
||||
Returns:
|
||||
The result of the `evaluate` call to the `Estimator`.
|
||||
"""
|
||||
self._train_monitors = self._train_monitors or []
|
||||
if self._local_eval_frequency:
|
||||
self._train_monitors += [monitors.ValidationMonitor(
|
||||
input_fn=self._eval_input_fn, steps=self._eval_steps,
|
||||
input_fn=self._eval_input_fn, eval_steps=self._eval_steps,
|
||||
metrics=self._eval_metrics, every_n_steps=self._local_eval_frequency
|
||||
)]
|
||||
self.train()
|
||||
|
@ -27,6 +27,7 @@ class TestEstimator(object):
|
||||
def __init__(self):
|
||||
self.eval_count = 0
|
||||
self.fit_count = 0
|
||||
self.monitors = []
|
||||
|
||||
def evaluate(self, **kwargs):
|
||||
tf.logging.info('evaluate called with args: %s' % kwargs)
|
||||
@ -39,6 +40,8 @@ class TestEstimator(object):
|
||||
def fit(self, **kwargs):
|
||||
tf.logging.info('fit called with args: %s' % kwargs)
|
||||
self.fit_count += 1
|
||||
if 'monitors' in kwargs:
|
||||
self.monitors = kwargs['monitors']
|
||||
return [(key, kwargs[key]) for key in sorted(kwargs.keys())]
|
||||
|
||||
|
||||
@ -115,6 +118,22 @@ class ExperimentTest(tf.test.TestCase):
|
||||
tf.logging.info('eval duration (expected %f): %f', expected, duration)
|
||||
self.assertTrue(duration > expected - 0.5 and duration < expected + 0.5)
|
||||
|
||||
def test_run_local(self):
|
||||
est = TestEstimator()
|
||||
ex = tf.contrib.learn.Experiment(est,
|
||||
train_input_fn='train_input',
|
||||
eval_input_fn='eval_input',
|
||||
eval_metrics='eval_metrics',
|
||||
train_steps=100,
|
||||
eval_steps=100,
|
||||
local_eval_frequency=10)
|
||||
ex.local_run()
|
||||
self.assertEquals(1, est.fit_count)
|
||||
self.assertEquals(1, est.eval_count)
|
||||
self.assertEquals(1, len(est.monitors))
|
||||
self.assertTrue(isinstance(est.monitors[0],
|
||||
tf.contrib.learn.monitors.ValidationMonitor))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user