From 4aa71c406c7c5c3c2b202c110d0734f2887d88bd Mon Sep 17 00:00:00 2001 From: Amy Skerry-Ryan Date: Wed, 4 Sep 2019 20:54:34 -0700 Subject: [PATCH] Expose session_creation_timeout_secs in estimator RunConfig PiperOrigin-RevId: 267290523 --- .../learn/python/learn/estimators/estimator.py | 1 + .../learn/python/learn/estimators/run_config.py | 17 ++++++++++++++--- .../v1/tensorflow.estimator.-run-config.pbtxt | 6 +++++- .../tensorflow.estimator.tpu.-run-config.pbtxt | 4 ++++ .../v2/tensorflow.estimator.-run-config.pbtxt | 6 +++++- 5 files changed, 29 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 9132b2209bc..8d8f5619a4a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1088,6 +1088,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=self._config.save_summary_steps, + max_wait_secs=self._config.session_creation_timeout_secs, config=self._session_config) as mon_sess: loss = None while not mon_sess.should_stop(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index b51ea30959e..e435fd65702 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -243,7 +243,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): protocol=None, evaluation_master='', model_dir=None, - session_config=None): + session_config=None, + session_creation_timeout_secs=7200): """Constructor. The superclass `ClusterConfig` may set properties like `cluster_spec`, @@ -282,6 +283,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): the feature. log_step_count_steps: The frequency, in number of global steps, that the global step/sec will be logged during training. + protocol: An optional argument which specifies the protocol used when + starting server. None means default to grpc. evaluation_master: the master on which to perform evaluation. model_dir: directory where model parameters, graph etc are saved. If `None`, will use `model_dir` property in `TF_CONFIG` environment @@ -290,8 +293,11 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): session_config: a ConfigProto used to set session parameters, or None. Note - using this argument, it is easy to provide settings which break otherwise perfectly good models. Use with care. - protocol: An optional argument which specifies the protocol used when - starting server. None means default to grpc. + session_creation_timeout_secs: Max time workers should wait for a session + to become available (on initialization or when recovering a session) + with MonitoredTrainingSession. Defaults to 7200 seconds, but users may + want to set a lower value to detect problems with variable / session + (re)-initialization more quickly. """ # Neither parent class calls super().__init__(), so here we have to # manually call their __init__() methods. @@ -332,6 +338,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): self._keep_checkpoint_max = keep_checkpoint_max self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours self._model_dir = _get_model_dir(model_dir) + self._session_creation_timeout_secs = session_creation_timeout_secs @experimental def uid(self, whitelist=None): @@ -408,6 +415,10 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): def log_step_count_steps(self): return self._log_step_count_steps + @property + def session_creation_timeout_secs(self): + return self._session_creation_timeout_secs + def _count_ps(cluster_spec): """Counts the number of parameter servers in cluster_spec.""" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt index 843298f61f8..b730913ca91 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt @@ -82,6 +82,10 @@ tf_class { name: "session_config" mtype: "" } + member { + name: "session_creation_timeout_secs" + mtype: "" + } member { name: "task_id" mtype: "" @@ -100,7 +104,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\', \'session_creation_timeout_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'7200\'], " } member_method { name: "replace" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.-run-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.-run-config.pbtxt index ea95acf18e5..3c94cf708e0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.-run-config.pbtxt @@ -87,6 +87,10 @@ tf_class { name: "session_config" mtype: "" } + member { + name: "session_creation_timeout_secs" + mtype: "" + } member { name: "task_id" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt index 843298f61f8..b730913ca91 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt @@ -82,6 +82,10 @@ tf_class { name: "session_config" mtype: "" } + member { + name: "session_creation_timeout_secs" + mtype: "" + } member { name: "task_id" mtype: "" @@ -100,7 +104,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\', \'experimental_distribute\', \'experimental_max_worker_delay_secs\', \'session_creation_timeout_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'7200\'], " } member_method { name: "replace"