Expose session_creation_timeout_secs in estimator RunConfig

PiperOrigin-RevId: 267290523
This commit is contained in:
Amy Skerry-Ryan 2019-09-04 20:54:34 -07:00 committed by TensorFlower Gardener
parent 132bebc675
commit 4aa71c406c
5 changed files with 29 additions and 5 deletions

View File

@ -1088,6 +1088,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
max_wait_secs=self._config.session_creation_timeout_secs,
config=self._session_config) as mon_sess:
loss = None
while not mon_sess.should_stop():

View File

@ -243,7 +243,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
protocol=None,
evaluation_master='',
model_dir=None,
session_config=None):
session_config=None,
session_creation_timeout_secs=7200):
"""Constructor.
The superclass `ClusterConfig` may set properties like `cluster_spec`,
@ -282,6 +283,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
the feature.
log_step_count_steps: The frequency, in number of global steps, that the
global step/sec will be logged during training.
protocol: An optional argument which specifies the protocol used when
starting server. None means default to grpc.
evaluation_master: the master on which to perform evaluation.
model_dir: directory where model parameters, graph etc are saved. If
`None`, will use `model_dir` property in `TF_CONFIG` environment
@ -290,8 +293,11 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
session_config: a ConfigProto used to set session parameters, or None.
Note - using this argument, it is easy to provide settings which break
otherwise perfectly good models. Use with care.
protocol: An optional argument which specifies the protocol used when
starting server. None means default to grpc.
session_creation_timeout_secs: Max time, in seconds, that workers should
wait for a session to become available (on initialization or when
recovering a session) with MonitoredTrainingSession. Defaults to 7200
seconds, but users may want to set a lower value to detect problems with
variable / session (re-)initialization more quickly.
"""
# Neither parent class calls super().__init__(), so here we have to
# manually call their __init__() methods.
@ -332,6 +338,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
self._keep_checkpoint_max = keep_checkpoint_max
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
self._model_dir = _get_model_dir(model_dir)
self._session_creation_timeout_secs = session_creation_timeout_secs
@experimental
def uid(self, whitelist=None):
@ -408,6 +415,10 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
def log_step_count_steps(self):
return self._log_step_count_steps
@property
def session_creation_timeout_secs(self):
  """Returns the session creation timeout, in seconds, given to the constructor."""
  return self._session_creation_timeout_secs
def _count_ps(cluster_spec):
"""Counts the number of parameter servers in cluster_spec."""

View File

@ -82,6 +82,10 @@ tf_class {
name: "session_config"
mtype: "<type \'property\'>"
}
member {
name: "session_creation_timeout_secs"
mtype: "<type \'property\'>"
}
member {
name: "task_id"
mtype: "<type \'property\'>"
@ -100,7 +104,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\', \'session_creation_timeout_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'7200\'], "
}
member_method {
name: "replace"

View File

@ -87,6 +87,10 @@ tf_class {
name: "session_config"
mtype: "<type \'property\'>"
}
member {
name: "session_creation_timeout_secs"
mtype: "<type \'property\'>"
}
member {
name: "task_id"
mtype: "<type \'property\'>"

View File

@ -82,6 +82,10 @@ tf_class {
name: "session_config"
mtype: "<type \'property\'>"
}
member {
name: "session_creation_timeout_secs"
mtype: "<type \'property\'>"
}
member {
name: "task_id"
mtype: "<type \'property\'>"
@ -100,7 +104,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\', \'session_creation_timeout_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'7200\'], "
}
member_method {
name: "replace"