Fixed string formatting and valid task listing.
Change: 144656658
This commit is contained in:
parent
dbd7a2c9e3
commit
26201f5938
@ -88,20 +88,20 @@ def run(experiment_fn, output_dir, schedule=None):
|
|||||||
# Execute the schedule
|
# Execute the schedule
|
||||||
if not hasattr(experiment, schedule):
|
if not hasattr(experiment, schedule):
|
||||||
logging.error('Schedule references non-existent task %s', schedule)
|
logging.error('Schedule references non-existent task %s', schedule)
|
||||||
valid_tasks = [x for x in experiment.__dict__
|
valid_tasks = [x for x in dir(experiment)
|
||||||
if callable(getattr(experiment, x))]
|
if not x.startswith('_')
|
||||||
|
and callable(getattr(experiment, x))]
|
||||||
logging.error('Allowed values for this experiment are: %s', valid_tasks)
|
logging.error('Allowed values for this experiment are: %s', valid_tasks)
|
||||||
raise ValueError('Schedule references non-existent task %s', schedule)
|
raise ValueError('Schedule references non-existent task %s' % schedule)
|
||||||
|
|
||||||
task = getattr(experiment, schedule)
|
task = getattr(experiment, schedule)
|
||||||
if not callable(task):
|
if not callable(task):
|
||||||
logging.error('Schedule references non-callable member %s', schedule)
|
logging.error('Schedule references non-callable member %s', schedule)
|
||||||
valid_tasks = [
|
valid_tasks = [x for x in dir(experiment)
|
||||||
x for x in experiment.__dict__
|
if not x.startswith('_')
|
||||||
if callable(getattr(experiment, x)) and not x.startswith('_')
|
and callable(getattr(experiment, x))]
|
||||||
]
|
|
||||||
logging.error('Allowed values for this experiment are: %s', valid_tasks)
|
logging.error('Allowed values for this experiment are: %s', valid_tasks)
|
||||||
raise TypeError('Schedule references non-callable member %s', schedule)
|
raise TypeError('Schedule references non-callable member %s' % schedule)
|
||||||
|
|
||||||
return task()
|
return task()
|
||||||
|
|
||||||
|
@ -27,9 +27,11 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
|
|||||||
import ctypes
|
import ctypes
|
||||||
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
||||||
|
|
||||||
|
from tensorflow.contrib.learn.python.learn import evaluable # pylint: disable=g-import-not-at-top
|
||||||
from tensorflow.contrib.learn.python.learn import experiment
|
from tensorflow.contrib.learn.python.learn import experiment
|
||||||
from tensorflow.contrib.learn.python.learn import learn_runner
|
from tensorflow.contrib.learn.python.learn import learn_runner
|
||||||
from tensorflow.contrib.learn.python.learn import run_config
|
from tensorflow.contrib.learn.python.learn import run_config
|
||||||
|
from tensorflow.contrib.learn.python.learn import trainable
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
@ -43,13 +45,22 @@ class TestExperiment(experiment.Experiment):
|
|||||||
self.default = default
|
self.default = default
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@property
|
class Estimator(evaluable.Evaluable, trainable.Trainable):
|
||||||
def estimator(self):
|
|
||||||
|
|
||||||
class Estimator(object):
|
|
||||||
config = self.config
|
config = self.config
|
||||||
|
|
||||||
return Estimator()
|
def model_dir(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
|
||||||
|
monitors=None, max_steps=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
|
||||||
|
batch_size=None, steps=None, metrics=None, name=None,
|
||||||
|
checkpoint_path=None, hooks=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
super(TestExperiment, self).__init__(Estimator(), None, None)
|
||||||
|
|
||||||
def local_run(self):
|
def local_run(self):
|
||||||
return "local_run"
|
return "local_run"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user