Expose session_creation_timeout_secs in estimator RunConfig
PiperOrigin-RevId: 267290523
This commit is contained in:
parent
132bebc675
commit
4aa71c406c
@ -1088,6 +1088,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
|
||||
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
|
||||
save_checkpoint_secs=0, # Saving is handled by a hook.
|
||||
save_summaries_steps=self._config.save_summary_steps,
|
||||
max_wait_secs=self._config.session_creation_timeout_secs,
|
||||
config=self._session_config) as mon_sess:
|
||||
loss = None
|
||||
while not mon_sess.should_stop():
|
||||
|
@ -243,7 +243,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
|
||||
protocol=None,
|
||||
evaluation_master='',
|
||||
model_dir=None,
|
||||
session_config=None):
|
||||
session_config=None,
|
||||
session_creation_timeout_secs=7200):
|
||||
"""Constructor.
|
||||
|
||||
The superclass `ClusterConfig` may set properties like `cluster_spec`,
|
||||
@ -282,6 +283,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
|
||||
the feature.
|
||||
log_step_count_steps: The frequency, in number of global steps, that the
|
||||
global step/sec will be logged during training.
|
||||
protocol: An optional argument which specifies the protocol used when
|
||||
starting server. None means default to grpc.
|
||||
evaluation_master: the master on which to perform evaluation.
|
||||
model_dir: directory where model parameters, graph etc are saved. If
|
||||
`None`, will use `model_dir` property in `TF_CONFIG` environment
|
||||
@ -290,8 +293,11 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
|
||||
session_config: a ConfigProto used to set session parameters, or None.
|
||||
Note - using this argument, it is easy to provide settings which break
|
||||
otherwise perfectly good models. Use with care.
|
||||
protocol: An optional argument which specifies the protocol used when
|
||||
starting server. None means default to grpc.
|
||||
session_creation_timeout_secs: Max time workers should wait for a session
|
||||
to become available (on initialization or when recovering a session)
|
||||
with MonitoredTrainingSession. Defaults to 7200 seconds, but users may
|
||||
want to set a lower value to detect problems with variable / session
|
||||
(re)-initialization more quickly.
|
||||
"""
|
||||
# Neither parent class calls super().__init__(), so here we have to
|
||||
# manually call their __init__() methods.
|
||||
@ -332,6 +338,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
|
||||
self._keep_checkpoint_max = keep_checkpoint_max
|
||||
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
|
||||
self._model_dir = _get_model_dir(model_dir)
|
||||
self._session_creation_timeout_secs = session_creation_timeout_secs
|
||||
|
||||
@experimental
|
||||
def uid(self, whitelist=None):
|
||||
@ -408,6 +415,10 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
|
||||
def log_step_count_steps(self):
|
||||
return self._log_step_count_steps
|
||||
|
||||
@property
|
||||
def session_creation_timeout_secs(self):
|
||||
return self._session_creation_timeout_secs
|
||||
|
||||
|
||||
def _count_ps(cluster_spec):
|
||||
"""Counts the number of parameter servers in cluster_spec."""
|
||||
|
@ -82,6 +82,10 @@ tf_class {
|
||||
name: "session_config"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "session_creation_timeout_secs"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -100,7 +104,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\', \'session_creation_timeout_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'7200\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "replace"
|
||||
|
@ -87,6 +87,10 @@ tf_class {
|
||||
name: "session_config"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "session_creation_timeout_secs"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -82,6 +82,10 @@ tf_class {
|
||||
name: "session_config"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "session_creation_timeout_secs"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "task_id"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -100,7 +104,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\', \'session_creation_timeout_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'7200\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "replace"
|
||||
|
Loading…
Reference in New Issue
Block a user