Fixed string formatting and valid task listing.
Change: 144656658
This commit is contained in:
parent
dbd7a2c9e3
commit
26201f5938
@ -88,20 +88,20 @@ def run(experiment_fn, output_dir, schedule=None):
|
||||
# Execute the schedule
|
||||
if not hasattr(experiment, schedule):
|
||||
logging.error('Schedule references non-existent task %s', schedule)
|
||||
valid_tasks = [x for x in experiment.__dict__
|
||||
if callable(getattr(experiment, x))]
|
||||
valid_tasks = [x for x in dir(experiment)
|
||||
if not x.startswith('_')
|
||||
and callable(getattr(experiment, x))]
|
||||
logging.error('Allowed values for this experiment are: %s', valid_tasks)
|
||||
raise ValueError('Schedule references non-existent task %s', schedule)
|
||||
raise ValueError('Schedule references non-existent task %s' % schedule)
|
||||
|
||||
task = getattr(experiment, schedule)
|
||||
if not callable(task):
|
||||
logging.error('Schedule references non-callable member %s', schedule)
|
||||
valid_tasks = [
|
||||
x for x in experiment.__dict__
|
||||
if callable(getattr(experiment, x)) and not x.startswith('_')
|
||||
]
|
||||
valid_tasks = [x for x in dir(experiment)
|
||||
if not x.startswith('_')
|
||||
and callable(getattr(experiment, x))]
|
||||
logging.error('Allowed values for this experiment are: %s', valid_tasks)
|
||||
raise TypeError('Schedule references non-callable member %s', schedule)
|
||||
raise TypeError('Schedule references non-callable member %s' % schedule)
|
||||
|
||||
return task()
|
||||
|
||||
|
@ -27,9 +27,11 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
|
||||
import ctypes
|
||||
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
||||
|
||||
from tensorflow.contrib.learn.python.learn import evaluable # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.contrib.learn.python.learn import experiment
|
||||
from tensorflow.contrib.learn.python.learn import learn_runner
|
||||
from tensorflow.contrib.learn.python.learn import run_config
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
@ -43,13 +45,22 @@ class TestExperiment(experiment.Experiment):
|
||||
self.default = default
|
||||
self.config = config
|
||||
|
||||
@property
|
||||
def estimator(self):
|
||||
|
||||
class Estimator(object):
|
||||
class Estimator(evaluable.Evaluable, trainable.Trainable):
|
||||
config = self.config
|
||||
|
||||
return Estimator()
|
||||
def model_dir(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
|
||||
monitors=None, max_steps=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
|
||||
batch_size=None, steps=None, metrics=None, name=None,
|
||||
checkpoint_path=None, hooks=None):
|
||||
raise NotImplementedError
|
||||
|
||||
super(TestExperiment, self).__init__(Estimator(), None, None)
|
||||
|
||||
def local_run(self):
|
||||
return "local_run"
|
||||
|
Loading…
Reference in New Issue
Block a user