Fixed string formatting and valid task listing.

Change: 144656658
This commit is contained in:
A. Unique TensorFlower 2017-01-16 15:44:31 -08:00 committed by TensorFlower Gardener
parent dbd7a2c9e3
commit 26201f5938
2 changed files with 24 additions and 13 deletions

View File

@ -88,20 +88,20 @@ def run(experiment_fn, output_dir, schedule=None):
# Execute the schedule
if not hasattr(experiment, schedule):
logging.error('Schedule references non-existent task %s', schedule)
valid_tasks = [x for x in experiment.__dict__
if callable(getattr(experiment, x))]
valid_tasks = [x for x in dir(experiment)
if not x.startswith('_')
and callable(getattr(experiment, x))]
logging.error('Allowed values for this experiment are: %s', valid_tasks)
raise ValueError('Schedule references non-existent task %s', schedule)
raise ValueError('Schedule references non-existent task %s' % schedule)
task = getattr(experiment, schedule)
if not callable(task):
logging.error('Schedule references non-callable member %s', schedule)
valid_tasks = [
x for x in experiment.__dict__
if callable(getattr(experiment, x)) and not x.startswith('_')
]
valid_tasks = [x for x in dir(experiment)
if not x.startswith('_')
and callable(getattr(experiment, x))]
logging.error('Allowed values for this experiment are: %s', valid_tasks)
raise TypeError('Schedule references non-callable member %s', schedule)
raise TypeError('Schedule references non-callable member %s' % schedule)
return task()

View File

@ -27,9 +27,11 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
import ctypes
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
from tensorflow.contrib.learn.python.learn import evaluable # pylint: disable=g-import-not-at-top
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@ -43,13 +45,22 @@ class TestExperiment(experiment.Experiment):
self.default = default
self.config = config
@property
def estimator(self):
class Estimator(object):
class Estimator(evaluable.Evaluable, trainable.Trainable):
config = self.config
return Estimator()
def model_dir(self):
raise NotImplementedError
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
monitors=None, max_steps=None):
raise NotImplementedError
def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
batch_size=None, steps=None, metrics=None, name=None,
checkpoint_path=None, hooks=None):
raise NotImplementedError
super(TestExperiment, self).__init__(Estimator(), None, None)
def local_run(self):
return "local_run"