From ca94990804cf5326c0f6f46d75c96e0f0e240366 Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou <yuefengz@google.com> Date: Fri, 24 Aug 2018 19:14:44 -0700 Subject: [PATCH] Add an option to RunConfig and train_and_evaluate to run distribute coordinator. This is necessary to run multi-worker MirroredStrategy and CollectiveAllReduceStrategy with estimator. PiperOrigin-RevId: 210192378 --- tensorflow/contrib/distribute/BUILD | 1 + tensorflow/contrib/distribute/__init__.py | 2 + tensorflow/contrib/distribute/python/BUILD | 26 + .../python/estimator_training_test.py | 659 ++++++++++++++++++ tensorflow/python/BUILD | 1 + tensorflow/python/distribute/BUILD | 33 + .../python/distribute/distribute_config.py | 45 ++ .../distribute/distribute_coordinator.py | 53 +- .../python/distribute/estimator_training.py | 264 +++++++ tensorflow/python/estimator/run_config.py | 20 +- tensorflow/python/estimator/training.py | 22 +- 11 files changed, 1100 insertions(+), 26 deletions(-) create mode 100644 tensorflow/contrib/distribute/python/estimator_training_test.py create mode 100644 tensorflow/python/distribute/distribute_config.py create mode 100644 tensorflow/python/distribute/estimator_training.py diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index c16f1d6035d..02feeafb60a 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -35,5 +35,6 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:distribute_config", ], ) diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 588a4f2898b..bf763215ba2 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceSt from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy +from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.training.distribute import * from tensorflow.python.training.distribution_strategy_context import * @@ -37,6 +38,7 @@ _allowed_symbols = [ 'AllReduceCrossTowerOps', 'CollectiveAllReduceStrategy', 'CrossTowerOps', + 'DistributeConfig', 'DistributionStrategy', 'MirroredStrategy', 'Monitor', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 8173b5d4baf..f5b236e35f0 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -452,6 +452,32 @@ cuda_py_test( ], ) +cuda_py_test( + name = "estimator_training_test", + size = "large", + srcs = ["estimator_training_test.py"], + additional_deps = [ + ":combinations", + ":mirrored_strategy", + ":multi_worker_test_base", + ":parameter_server_strategy", + "//third_party/py/numpy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/feature_column", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + py_library( name = 
"single_loss_example", srcs = ["single_loss_example.py"], diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py new file mode 100644 index 00000000000..5348512016e --- /dev/null +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -0,0 +1,659 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that show Distribute Coordinator works with Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import os +import sys +import tempfile +import threading +from absl.testing import parameterized +import numpy as np +import six + +_portpicker_import_error = None +try: + import portpicker # pylint: disable=g-import-not-at-top +except ImportError as _error: # pylint: disable=invalid-name + _portpicker_import_error = _error + portpicker = None + +# pylint: disable=g-import-not-at-top +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.contrib.optimizer_v2 import adagrad +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_coordinator as dc +from tensorflow.python.distribute import estimator_training as dc_training +from tensorflow.python.distribute.distribute_config import DistributeConfig +from tensorflow.python.eager import context +from tensorflow.python.estimator import exporter as exporter_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.estimator import training as estimator_training +from tensorflow.python.estimator.canned import dnn_linear_combined +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export as export_lib +from tensorflow.python.feature_column import feature_column +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import server_lib + +BATCH_SIZE = 10 +LABEL_DIMENSION = 2 +DATA = np.linspace( + 0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape( + BATCH_SIZE, LABEL_DIMENSION) +EVAL_NAME = "foo" +EXPORTER_NAME = "saved_model_exporter" +MAX_STEPS = 10 + +CHIEF = dc._TaskType.CHIEF +EVALUATOR = dc._TaskType.EVALUATOR +WORKER = dc._TaskType.WORKER +PS = dc._TaskType.PS + +original_run_distribute_coordinator = dc.run_distribute_coordinator + + +# TODO(yuefengz): merge this method back to test_util. 
+def _create_local_cluster(num_workers, + num_ps, + has_eval=False, + protocol="grpc", + worker_config=None, + ps_config=None): + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] + ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] + + cluster_dict = { + "worker": ["localhost:%s" % port for port in worker_ports], + "ps": ["localhost:%s" % port for port in ps_ports] + } + if has_eval: + cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()] + + cs = server_lib.ClusterSpec(cluster_dict) + + workers = [ + server_lib.Server( + cs, + job_name="worker", + protocol=protocol, + task_index=ix, + config=worker_config, + start=True) for ix in range(num_workers) + ] + ps_servers = [ + server_lib.Server( + cs, + job_name="ps", + protocol=protocol, + task_index=ix, + config=ps_config, + start=True) for ix in range(num_ps) + ] + if has_eval: + evals = [ + server_lib.Server( + cs, + job_name="evaluator", + protocol=protocol, + task_index=0, + config=worker_config, + start=True) + ] + else: + evals = [] + + return workers, ps_servers, evals + + +def _create_in_process_cluster(num_workers, num_ps, has_eval=False): + """Create an in-process cluster that consists of only standard server.""" + # Leave some memory for cuda runtime. + if has_eval: + gpu_mem_frac = 0.7 / (num_workers + 1) + else: + gpu_mem_frac = 0.7 / num_workers + + worker_config = config_pb2.ConfigProto() + worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + + # Enable collective ops which has no impact on non-collective ops. + # TODO(yuefengz, tucker): removing this after we move the initialization of + # collective mgr to the session level. + worker_config.experimental.collective_group_leader = ( + "/job:worker/replica:0/task:0") + + ps_config = config_pb2.ConfigProto() + ps_config.device_count["GPU"] = 0 + + return _create_local_cluster( + num_workers, + num_ps=num_ps, + has_eval=has_eval, + worker_config=worker_config, + ps_config=ps_config, + protocol="grpc") + + +def _create_cluster_spec(has_chief=False, + num_workers=1, + num_ps=0, + has_eval=False): + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + + cluster_spec = {} + if has_chief: + cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()] + if num_workers: + cluster_spec[WORKER] = [ + "localhost:%s" % portpicker.pick_unused_port() + for _ in range(num_workers) + ] + if num_ps: + cluster_spec[PS] = [ + "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps) + ] + if has_eval: + cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()] + return cluster_spec + + +def _bytes_to_str(maybe_bytes): + if isinstance(maybe_bytes, six.string_types): + return maybe_bytes + else: + return str(maybe_bytes, "utf-8") + + +def _strip_protocol(target): + # cluster_spec expects "host:port" strings. 
+ if "//" in target: + return target.split("//")[1] + else: + return target + + +class DistributeCoordinatorIntegrationTest(test.TestCase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps, cls._evals = _create_in_process_cluster( + num_workers=3, num_ps=2, has_eval=True) + cls._cluster_spec = { + "worker": [ + _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers + ], + "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps], + "evaluator": [ + _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals + ] + } + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + self._event = threading.Event() + super(DistributeCoordinatorIntegrationTest, self).setUp() + + def dataset_input_fn(self, x, y, batch_size, shuffle): + + def input_fn(): + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + if shuffle: + dataset = dataset.shuffle(batch_size) + dataset = dataset.repeat(100).batch(batch_size) + return dataset + + return input_fn + + def _get_exporter(self, name, fc): + feature_spec = feature_column.make_parse_example_spec(fc) + serving_input_receiver_fn = ( + export_lib.build_parsing_serving_input_receiver_fn(feature_spec)) + return exporter_lib.LatestExporter( + name, serving_input_receiver_fn=serving_input_receiver_fn) + + def _extract_loss_and_global_step(self, event_folder): + """Returns the loss and global step in last event.""" + event_paths = glob.glob(os.path.join(event_folder, "events*")) + + loss = None + global_step_count = None + + for e in summary_iterator.summary_iterator(event_paths[-1]): + current_loss = None + for v in e.summary.value: + if v.tag == "loss": + current_loss = v.simple_value + + # If loss is not found, global step is meaningless. 
+ if current_loss is None: + continue + + current_global_step = e.step + if global_step_count is None or current_global_step > global_step_count: + global_step_count = current_global_step + loss = current_loss + + return (loss, global_step_count) + + def _get_estimator(self, + train_distribute, + eval_distribute, + remote_cluster=None): + input_dimension = LABEL_DIMENSION + linear_feature_columns = [ + feature_column.numeric_column("x", shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column("x", shape=(input_dimension,)) + ] + + return dnn_linear_combined.DNNLinearCombinedRegressor( + linear_feature_columns=linear_feature_columns, + dnn_hidden_units=(2, 2), + dnn_feature_columns=dnn_feature_columns, + label_dimension=LABEL_DIMENSION, + model_dir=self._model_dir, + dnn_optimizer=adagrad.AdagradOptimizer(0.001), + linear_optimizer=adagrad.AdagradOptimizer(0.001), + config=run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + train_distribute=train_distribute, + eval_distribute=eval_distribute, + remote_cluster=remote_cluster))) + + def _complete_flow(self, + train_distribute, + eval_distribute, + remote_cluster=None): + estimator = self._get_estimator(train_distribute, eval_distribute, + remote_cluster) + + input_dimension = LABEL_DIMENSION + train_input_fn = self.dataset_input_fn( + x={"x": DATA}, + y=DATA, + batch_size=BATCH_SIZE // len(train_distribute.worker_devices), + shuffle=True) + if eval_distribute: + eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices) + else: + eval_batch_size = BATCH_SIZE + eval_input_fn = self.dataset_input_fn( + x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False) + + linear_feature_columns = [ + feature_column.numeric_column("x", shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column("x", shape=(input_dimension,)) + ] + feature_columns = linear_feature_columns + dnn_feature_columns + + estimator_training.train_and_evaluate( + estimator, + estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS), + estimator_training.EvalSpec( + name=EVAL_NAME, + input_fn=eval_input_fn, + steps=None, + exporters=self._get_exporter(EXPORTER_NAME, feature_columns), + start_delay_secs=0, + throttle_secs=1)) + return estimator + + def _inspect_train_and_eval_events(self, estimator): + # Make sure nothing is stuck in limbo. + writer_cache.FileWriterCache.clear() + + # Examine the training events. Use a range to check global step to avoid + # flakyness due to global step race condition. + training_loss, _ = self._extract_loss_and_global_step(self._model_dir) + self.assertIsNotNone(training_loss) + + # Examine the eval events. The global step should be accurate. + eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME) + eval_loss, eval_global_step = self._extract_loss_and_global_step( + event_folder=eval_dir) + self.assertIsNotNone(eval_loss) + self.assertGreaterEqual(eval_global_step, MAX_STEPS) + + # Examine the export folder. + export_dir = os.path.join( + os.path.join(self._model_dir, "export"), EXPORTER_NAME) + self.assertTrue(gfile.Exists(export_dir)) + + # Examine the ckpt for predict. 
+ def predict_input_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "x": DATA + }).batch(BATCH_SIZE) + + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) + + @combinations.generate( + combinations.combine( + mode=["graph"], + train_distribute_cls=[ + mirrored_strategy.MirroredStrategy, + parameter_server_strategy.ParameterServerStrategy + ], + eval_distribute_cls=[ + None, mirrored_strategy.MirroredStrategy, + parameter_server_strategy.ParameterServerStrategy + ], + required_gpus=1)) + def test_complete_flow_standalone_client(self, train_distribute_cls, + eval_distribute_cls): + try: + train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) + except TypeError: + train_distribute = train_distribute_cls(num_gpus_per_worker=2) + + if eval_distribute_cls: + eval_distribute = eval_distribute_cls() + else: + eval_distribute = None + + estimator = self._complete_flow( + train_distribute, eval_distribute, remote_cluster=self._cluster_spec) + self._inspect_train_and_eval_events(estimator) + + def _mock_run_distribute_coordinator( + self, + worker_fn, + strategy, + eval_fn, + eval_strategy, + mode=dc.CoordinatorMode.STANDALONE_CLIENT, + cluster_spec=None, + session_config=None): + # Calls the origial `run_distribute_coordinator` method but gets task config + # from environment variables and then signals the caller. + task_type = None + task_id = None + if not cluster_spec: + cluster_spec = None + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + if not cluster_spec: + cluster_spec = tf_config.get("cluster", {}) + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", task_type) + task_id = int(task_env.get("index", task_id)) + self._event.set() + original_run_distribute_coordinator( + worker_fn, + strategy, + eval_fn, + eval_strategy, + mode=mode, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id, + session_config=session_config) + + def _task_thread(self, train_distribute, eval_distribute): + with test.mock.patch.object(dc, "run_distribute_coordinator", + self._mock_run_distribute_coordinator): + self._complete_flow(train_distribute, eval_distribute) + + def _run_task_in_thread(self, cluster_spec, task_type, task_id, + train_distribute, eval_distribute): + if task_type: + tf_config = { + "cluster": cluster_spec, + "task": { + "type": task_type, + "index": task_id + } + } + else: + tf_config = { + "cluster": cluster_spec, + "task": { + "type": task_type, + "index": task_id + } + } + self._event.clear() + t = threading.Thread( + target=self._task_thread, args=(train_distribute, eval_distribute)) + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(tf_config)}): + t.start() + self._event.wait() + return t + + def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, + eval_distribute): + threads = {} + for task_type in cluster_spec.keys(): + threads[task_type] = [] + for task_id in range(len(cluster_spec[task_type])): + t = self._run_task_in_thread(cluster_spec, task_type, task_id, + train_distribute, eval_distribute) + threads[task_type].append(t) + return threads + + @combinations.generate( + combinations.combine( + mode=["graph"], + train_distribute_cls=[ + parameter_server_strategy.ParameterServerStrategy, + ], + eval_distribute_cls=[ + None, mirrored_strategy.MirroredStrategy, + parameter_server_strategy.ParameterServerStrategy + 
], + required_gpus=1)) + def test_complete_flow_indepedent_worker_between_graph( + self, train_distribute_cls, eval_distribute_cls): + train_distribute = train_distribute_cls( + num_gpus_per_worker=context.num_gpus()) + + if eval_distribute_cls: + eval_distribute = eval_distribute_cls() + else: + eval_distribute = None + + cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True) + threads = self._run_multiple_tasks_in_threads( + cluster_spec, train_distribute, eval_distribute) + for task_type, ts in threads.items(): + if task_type == PS: + continue + for t in ts: + t.join() + + estimator = self._get_estimator(train_distribute, eval_distribute) + self._inspect_train_and_eval_events(estimator) + + @combinations.generate( + combinations.combine( + mode=["graph"], + train_distribute_cls=[mirrored_strategy.MirroredStrategy], + eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy], + required_gpus=1)) + def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, + eval_distribute_cls): + train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) + + if eval_distribute_cls: + eval_distribute = eval_distribute_cls() + else: + eval_distribute = None + + cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True) + threads = self._run_multiple_tasks_in_threads( + cluster_spec, train_distribute, eval_distribute) + threads[WORKER][0].join() + threads[EVALUATOR][0].join() + + estimator = self._get_estimator(train_distribute, eval_distribute) + self._inspect_train_and_eval_events(estimator) + + +TF_CONFIG_WITH_CHIEF = { + "cluster": { + "chief": ["fake_chief"], + }, + "task": { + "type": "chief", + "index": 0 + } +} + +TF_CONFIG_WITH_MASTER = { + "cluster": { + "master": ["fake_master"], + }, + "task": { + "type": "master", + "index": 0 + } +} + +TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}} + + +class RunConfigTest(test.TestCase): + + def test_previously_unexpected_cluster_spec(self): + with test.mock.patch.dict( + "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}): + run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + + def test_should_run_distribute_coordinator(self): + """Tests that should_run_distribute_coordinator return a correct value.""" + # We don't use distribute coordinator for local training. + self.assertFalse( + dc_training.should_run_distribute_coordinator( + run_config_lib.RunConfig())) + + # When `train_distribute` is not specified, don't use distribute + # coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): + self.assertFalse( + dc_training.should_run_distribute_coordinator( + run_config_lib.RunConfig())) + + # When `train_distribute` is specified and TF_CONFIG is detected, use + # distribute coordinator. 
+ with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): + config_with_train_distribute = run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + config_with_eval_distribute = run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + self.assertTrue( + dc_training.should_run_distribute_coordinator( + config_with_train_distribute)) + self.assertFalse( + dc_training.should_run_distribute_coordinator( + config_with_eval_distribute)) + + # With a master in the cluster, don't run distribute coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): + config = run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + self.assertFalse(dc_training.should_run_distribute_coordinator(config)) + + def test_init_run_config_duplicate_distribute(self): + with self.assertRaises(ValueError): + run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy(), + experimental_distribute=DistributeConfig( + train_distribute=mirrored_strategy.MirroredStrategy())) + + with self.assertRaises(ValueError): + run_config_lib.RunConfig( + eval_distribute=mirrored_strategy.MirroredStrategy(), + experimental_distribute=DistributeConfig( + eval_distribute=mirrored_strategy.MirroredStrategy())) + + def test_init_run_config_none_distribute_coordinator_mode(self): + # We don't use distribute coordinator for local training. + config = run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy()) + dc_training.init_run_config(config, {}) + self.assertIsNone(config._distribute_coordinator_mode) + + # With a master in the cluster, don't run distribute coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): + config = run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy()) + self.assertIsNone(config._distribute_coordinator_mode) + + # When `train_distribute` is not specified, don't use distribute + # coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): + config = run_config_lib.RunConfig() + self.assertFalse(hasattr(config, "_distribute_coordinator_mode")) + + def test_init_run_config_independent_worker(self): + # When `train_distribute` is specified and TF_CONFIG is detected, use + # distribute coordinator with INDEPENDENT_WORKER mode. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): + config = run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy()) + self.assertEqual(config._distribute_coordinator_mode, + dc.CoordinatorMode.INDEPENDENT_WORKER) + + def test_init_run_config_standalone_client(self): + # When `train_distribute` is specified, TF_CONFIG is detected and + # `experimental.remote_cluster` is set use distribute coordinator with + # STANDALONE_CLIENT mode. 
+ config = run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy(), + experimental_distribute=DistributeConfig( + remote_cluster={"chief": ["fake_worker"]})) + self.assertEqual(config._distribute_coordinator_mode, + dc.CoordinatorMode.STANDALONE_CLIENT) + + +if __name__ == "__main__": + with test.mock.patch.object(sys, "exit", os._exit): + test.main() diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 40f98474b56..37af3d350e5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -134,6 +134,7 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/compat", "//tensorflow/python/data", + "//tensorflow/python/distribute:estimator_training", "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras", "//tensorflow/python/ops/distributions", diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 98ef9bf4926..ebfcd085e66 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -8,6 +8,25 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") +py_library( + name = "distribute", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":distribute_config", + ":distribute_coordinator", + ":distribute_coordinator_context", + ], +) + +py_library( + name = "distribute_config", + srcs = [ + "distribute_config.py", + ], + deps = [], +) + py_library( name = "distribute_coordinator", srcs = [ @@ -81,3 +100,17 @@ py_test( "@absl_py//absl/testing:parameterized", ], ) + +# Used only by estimator. +py_library( + name = "estimator_training", + srcs = [ + "estimator_training.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":distribute_coordinator", + ":distribute_coordinator_context", + "//tensorflow/python:training", + ], +) diff --git a/tensorflow/python/distribute/distribute_config.py b/tensorflow/python/distribute/distribute_config.py new file mode 100644 index 00000000000..fac35742fe0 --- /dev/null +++ b/tensorflow/python/distribute/distribute_config.py @@ -0,0 +1,45 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A configure tuple for high-level APIs for running distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + + +class DistributeConfig( + collections.namedtuple( + 'DistributeConfig', + ['train_distribute', 'eval_distribute', 'remote_cluster'])): + """A config tuple for distribution strategies. + + Attributes: + train_distribute: a `DistributionStrategy` object for training. + eval_distribute: an optional `DistributionStrategy` object for + evaluation. + remote_cluster: a dict, `ClusterDef` or `ClusterSpec` object specifying + the cluster configurations. 
If this is given, the `train_and_evaluate` + method will be running as a standalone client which connects to the + cluster for training. + """ + + def __new__(cls, + train_distribute=None, + eval_distribute=None, + remote_cluster=None): + return super(DistributeConfig, cls).__new__(cls, train_distribute, + eval_distribute, remote_cluster) diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py index eb081b65fc7..9cf0b3b7a68 100644 --- a/tensorflow/python/distribute/distribute_coordinator.py +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -311,7 +311,11 @@ def _run_single_worker(worker_fn, worker_barrier=None): """Runs a single worker by calling `worker_fn` under context.""" strategy = copy.deepcopy(strategy) - strategy.configure(session_config, cluster_spec, task_type, task_id) + # If there is an EVALUATOR task, we run single-machine eval on that task. + if task_type == _TaskType.EVALUATOR: + strategy.configure(session_config) + else: + strategy.configure(session_config, cluster_spec, task_type, task_id) context = _WorkerContext( strategy, cluster_spec, @@ -340,14 +344,14 @@ def _run_std_server(cluster_spec=None, return server -def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config, - rpc_layer): +def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, + cluster_spec, session_config, rpc_layer): """Runs a standalone client for between-graph replication.""" eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, + args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0, session_config), kwargs={ "rpc_layer": rpc_layer, @@ -378,14 +382,14 @@ def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config, eval_thread.join() -def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, - rpc_layer): +def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, + cluster_spec, session_config, rpc_layer): """Runs a standalone client for in-graph replication.""" eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, + args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0, session_config), kwargs={ "rpc_layer": rpc_layer, @@ -408,6 +412,8 @@ def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, # is the special task when we support cluster_spec propagation. def run_distribute_coordinator(worker_fn, strategy, + eval_fn=None, + eval_strategy=None, mode=CoordinatorMode.STANDALONE_CLIENT, cluster_spec=None, task_type=None, @@ -488,10 +494,12 @@ def run_distribute_coordinator(worker_fn, If `cluster_spec` is not given in any format, it becomes local training and this coordinator will connect to a local session. - For evaluation, if "evaluator" exist in the cluster_spec, a separate thread - will be created with its `task_type` set to "evaluator". If "evaluator" is not - set in the cluster_spec, it entirely depends on the `worker_fn` for how to do - evaluation. + For evaluation, if "evaluator" exists in the cluster_spec, a separate thread + will be created to call `eval_fn` with its `task_type` set to "evaluator". If + `eval_fn` is not defined, fall back to `worker_fn`. 
This implies that + evaluation will be done on a single machine if there is an "evaluator" task. + If "evaluator" doesn't exit in the cluster_spec, it entirely depends on the + `worker_fn` for how to do evaluation. Args: worker_fn: the function to be called. The function should accept a @@ -501,6 +509,8 @@ def run_distribute_coordinator(worker_fn, run between-graph replicated training or not, whether to run init ops, etc. This object will also be configured given `session_config`, `cluster_spc`, `task_type` and `task_id`. + eval_fn: optional function for "evaluator" task. + eval_strategy: optional DistributionStrategy object for "evaluator" task. mode: in which mode this distribute coordinator runs. cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles in a cluster. If not set or empty, fall back to local training. @@ -535,16 +545,22 @@ def run_distribute_coordinator(worker_fn, # `mode` is ignored in the local case. _run_single_worker(worker_fn, strategy, None, None, None, session_config, rpc_layer) + if eval_fn: + _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None, + session_config, rpc_layer) elif mode == CoordinatorMode.STANDALONE_CLIENT: + eval_fn = eval_fn or worker_fn + eval_strategy = eval_strategy or strategy + # The client must know the cluster but servers in the cluster don't have to # know the client. if task_type in [_TaskType.CLIENT, None]: if strategy.between_graph: - _run_between_graph_client(worker_fn, strategy, cluster_spec, - session_config, rpc_layer) + _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, + cluster_spec, session_config, rpc_layer) else: - _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, - rpc_layer) + _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, + cluster_spec, session_config, rpc_layer) else: # If not a client job, run the standard server. server = _run_std_server( @@ -554,6 +570,9 @@ def run_distribute_coordinator(worker_fn, if mode != CoordinatorMode.INDEPENDENT_WORKER: raise ValueError("Unexpected coordinator mode: %r" % mode) + eval_fn = eval_fn or worker_fn + eval_strategy = eval_strategy or strategy + # Every one starts a standard server. server = _run_std_server( cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) @@ -572,8 +591,8 @@ def run_distribute_coordinator(worker_fn, else: server.join() elif task_type == _TaskType.EVALUATOR: - _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id, - session_config, rpc_layer) + _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type, + task_id, session_config, rpc_layer) else: if task_type != _TaskType.PS: raise ValueError("Unexpected task_type: %r" % task_type) diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py new file mode 100644 index 00000000000..202e19c420e --- /dev/null +++ b/tensorflow/python/distribute/estimator_training.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Training utilities for Estimator to use Distribute Coordinator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import six + +from tensorflow.python.distribute import distribute_coordinator as dc +from tensorflow.python.distribute import distribute_coordinator_context as dc_context +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import server_lib + +# pylint: disable=protected-access +CHIEF = dc._TaskType.CHIEF +EVALUATOR = dc._TaskType.EVALUATOR +PS = dc._TaskType.PS +WORKER = dc._TaskType.WORKER + +# pylint: enable=protected-access + + +def _count_ps(cluster_spec): + """Counts the number of parameter servers in cluster_spec.""" + if not cluster_spec: + raise RuntimeError( + 'Internal error: `_count_ps` does not expect empty cluster_spec.') + + return len(cluster_spec.as_dict().get(PS, [])) + + +def _count_worker(cluster_spec, chief_task_type): + """Counts the number of workers (including chief) in cluster_spec.""" + if not cluster_spec: + raise RuntimeError( + 'Internal error: `_count_worker` does not expect empty cluster_spec.') + + return (len(cluster_spec.as_dict().get(WORKER, [])) + len( + cluster_spec.as_dict().get(chief_task_type, []))) + + +def _get_global_id(cluster_spec, task_type, task_id, chief_task_type): + """Returns the global id of the given task type in a cluster.""" + if not task_type: + return 0 + + # Sort task names in cluster by "chief"/"master", "evaluator", "worker" + # and "ps". More details can be found at the documentation of + # @{tf.estimator.RunConfig.global_id_in_cluster}. + task_type_ordered_list = [] + if chief_task_type in cluster_spec.jobs: + task_type_ordered_list = [chief_task_type] + task_type_ordered_list.extend([ + t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS + ]) + if PS in cluster_spec.jobs: + task_type_ordered_list.append(PS) + + # Find the right gloabl_id for current task. + next_global_id = 0 + for t in task_type_ordered_list: + if t == task_type: + return next_global_id + task_id + # `cluster_spec.job_tasks` returns all task addresses of type `t`. + next_global_id += len(cluster_spec.job_tasks(t)) + + # It is unexpected that it passes through all task_types in + # `task_type_ordered_list`. + raise RuntimeError('Internal Error: `task_type` ({}) is not in ' + 'cluster_spec ({}).'.format(task_type, cluster_spec)) + + +def _init_run_config_from_worker_context(config, worker_context): + """Initializes run config from distribute coordinator's worker context.""" + + # pylint: disable=protected-access + config._service = None + config._cluster_spec = worker_context.cluster_spec + config._task_type = worker_context.task_type + config._task_id = worker_context.task_id + config._evaluation_master = worker_context.master_target + config._master = worker_context.master_target + config._is_chief = worker_context.is_chief + + if config._cluster_spec: + # Distributed mode. 
+ if config._task_type != EVALUATOR: + + config._num_ps_replicas = _count_ps(config._cluster_spec) + config._num_worker_replicas = _count_worker( + config._cluster_spec, chief_task_type=CHIEF) + config._global_id_in_cluster = _get_global_id( + config._cluster_spec, + config._task_type, + config._task_id, + chief_task_type=CHIEF) + else: + # Evaluator task should not be aware of the other tasks. + config._cluster_spec = server_lib.ClusterSpec({}) + config._num_ps_replicas = 0 + config._num_worker_replicas = 0 + config._global_id_in_cluster = None # undefined + else: + # Local mode. + config._global_id_in_cluster = 0 + config._num_ps_replicas = 0 + config._num_worker_replicas = 1 + + +def init_run_config(config, tf_config): + """Initializes RunConfig for distribution strategies.""" + # pylint: disable=protected-access + if (config._experimental_distribute and + config._experimental_distribute.train_distribute): + if config._train_distribute: + raise ValueError('Either `train_distribute` or' + '`experimental_distribute.train_distribute` can be set.') + config._train_distribute = config._experimental_distribute.train_distribute + + if (config._experimental_distribute and + config._experimental_distribute.eval_distribute): + if config._eval_distribute: + raise ValueError('Either `eval_distribute` or' + '`experimental_distribute.eval_distribute` can be set.') + config._eval_distribute = config._experimental_distribute.eval_distribute + + cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {})) + config._init_distributed_setting_from_environment_var({}) + + # Use distribute coordinator with STANDALONE_CLIENT mode if + # `experimental_distribute.remote_cluster` is set. + if (config._train_distribute and config._experimental_distribute and + config._experimental_distribute.remote_cluster): + if tf_config: + raise ValueError('Cannot set both TF_CONFIG environment variable and ' + '`experimental_distribute.remote_cluster`') + config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT + config._cluster_spec = config._experimental_distribute.remote_cluster + logging.info('RunConfig initialized for Distribute Coordinator with ' + 'STANDALONE_CLIENT mode') + return + + # Don't use distribute coordinator if it is local training or cluster has a + # MASTER job or `train_distribute` is not specifed. + if (not tf_config or 'master' in cluster_spec.jobs or + not config._train_distribute): + config._distribute_coordinator_mode = None + config._init_distributed_setting_from_environment_var(tf_config) + config._maybe_overwrite_session_config_for_distributed_training() + logging.info('Not using Distribute Coordinator.') + return + + # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise. + assert tf_config + + # Set the cluster_spec only since the distributed setting will come from + # distribute coordinator. 
+ config._cluster_spec = cluster_spec + config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER + logging.info('RunConfig initialized for Distribute Coordinator with ' + 'INDEPENDENT_WORKER mode') + + +def should_run_distribute_coordinator(config): + """Checks the config to see whether to run distribute coordinator.""" + # pylint: disable=protected-access + if (not hasattr(config, '_distribute_coordinator_mode') or + config._distribute_coordinator_mode is None): + return False + if (not isinstance(config._distribute_coordinator_mode, six.string_types) or + config._distribute_coordinator_mode not in [ + dc.CoordinatorMode.STANDALONE_CLIENT, + dc.CoordinatorMode.INDEPENDENT_WORKER + ]): + logging.warning('Unexpected distribute_coordinator_mode: %r', + config._distribute_coordinator_mode) + return False + if not config.cluster_spec: + logging.warning('Running `train_and_evaluate` locally, ignoring ' + '`experimental_distribute_coordinator_mode`.') + return False + return True + + +def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls): + """Run distribute coordinator for Estimator's `train_and_evaluate`. + + Args: + estimator: An `Estimator` instance to train and evaluate. + train_spec: A `TrainSpec` instance to specify the training specification. + eval_spec: A `EvalSpec` instance to specify the evaluation and export + specification. + executor_cls: the evaluation executor class of Estimator. + + Raises: + ValueError: if `distribute_coordinator_mode` is None in RunConfig. + """ + run_config = estimator.config + if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access + raise ValueError( + 'Distribute coordinator mode is not specified in `RunConfig`.') + + def _worker_fn(strategy): + """Function for worker task.""" + local_estimator = copy.deepcopy(estimator) + # pylint: disable=protected-access + local_estimator._config._train_distribute = strategy + _init_run_config_from_worker_context( + local_estimator._config, dc_context.get_current_worker_context()) + local_estimator._train_distribution = strategy + # pylint: enable=protected-access + + local_estimator.train( + input_fn=train_spec.input_fn, + max_steps=train_spec.max_steps, + hooks=list(train_spec.hooks)) + + def _eval_fn(strategy): + """Function for evaluator task.""" + local_estimator = copy.deepcopy(estimator) + # pylint: disable=protected-access + local_estimator._config._eval_distribute = strategy + _init_run_config_from_worker_context( + local_estimator._config, dc_context.get_current_worker_context()) + local_estimator._eval_distribution = strategy + + executor = executor_cls(local_estimator, train_spec, eval_spec) + executor._start_continuous_evaluation() + # pylint: enable=protected-access + + # pylint: disable=protected-access + if (run_config._distribute_coordinator_mode == + dc.CoordinatorMode.STANDALONE_CLIENT): + cluster_spec = run_config.cluster_spec + assert cluster_spec + else: + # The cluster_spec comes from TF_CONFIG environment variable if it is + # INDEPENDENT_WORKER mode. 
+ cluster_spec = None + + dc.run_distribute_coordinator( + _worker_fn, + run_config.train_distribute, + _eval_fn, + run_config.eval_distribute, + mode=run_config._distribute_coordinator_mode, + cluster_spec=cluster_spec, + session_config=run_config.session_config) diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index 12daddb044c..b1ca207b621 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -26,6 +26,7 @@ import six from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.distribute import estimator_training as distribute_coordinator_training from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat_internal @@ -460,7 +461,8 @@ class RunConfig(object): train_distribute: An optional instance of `tf.contrib.distribute.DistributionStrategy`. If specified, then Estimator will distribute the user's model during training, - according to the policy specified by that strategy. + according to the policy specified by that strategy. Setting + `experimental_distribute.train_distribute` is preferred. device_fn: A callable invoked for every `Operation` that takes the `Operation` and returns the device string. If `None`, defaults to the device function returned by `tf.train.replica_device_setter` @@ -470,10 +472,13 @@ class RunConfig(object): eval_distribute: An optional instance of `tf.contrib.distribute.DistributionStrategy`. If specified, then Estimator will distribute the user's model during evaluation, - according to the policy specified by that strategy. + according to the policy specified by that strategy. Setting + `experimental_distribute.eval_distribute` is preferred. experimental_distribute: an optional `tf.contrib.distribute.DistributeConfig` object specifying - DistributionStrategy-related configuration. + DistributionStrategy-related configuration. The `train_distribute` and + `eval_distribute` can be passed as parameters to `RunConfig` or set in + `experimental_distribute` but not both. Raises: ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs` @@ -516,9 +521,12 @@ class RunConfig(object): eval_distribute=eval_distribute, experimental_distribute=experimental_distribute) - self._init_distributed_setting_from_environment_var(tf_config) - - self._maybe_overwrite_session_config_for_distributed_training() + if train_distribute or eval_distribute or experimental_distribute: + logging.info('Initializing RunConfig with distribution strategies.') + distribute_coordinator_training.init_run_config(self, tf_config) + else: + self._init_distributed_setting_from_environment_var(tf_config) + self._maybe_overwrite_session_config_for_distributed_training() def _maybe_overwrite_session_config_for_distributed_training(self): """Overwrites the session_config for distributed training. 
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index e6bd263c80f..240be5dabe8 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -26,6 +26,7 @@ import time import six from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import estimator_training as distribute_coordinator_training from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import exporter as exporter_lib from tensorflow.python.estimator import run_config as run_config_lib @@ -274,8 +275,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec): evaluation `input_fn`, steps, etc. This utility function provides consistent behavior for both local - (non-distributed) and distributed configurations. Currently, the only - supported distributed training configuration is between-graph replication. + (non-distributed) and distributed configurations. The default distribution + configuration is parameter server-based between-graph replication. For other + types of distribution configurations such as all-reduce training, please use + [DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute). # pylint: disable=line-too-long Overfitting: In order to avoid overfitting, it is recommended to set up the training `input_fn` to shuffle the training data properly. @@ -426,6 +429,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec): }' ``` + When `train_distribute` or `experimental_distribute.train_distribute` and + `experimental_distribute.remote_cluster` are set, this method will start a + client running on the current host which connects to the `remote_cluster` for + training and evaluation. + Args: estimator: An `Estimator` instance to train and evaluate. train_spec: A `TrainSpec` instance to specify the training specification. @@ -444,8 +452,16 @@ def train_and_evaluate(estimator, train_spec, eval_spec): executor = _TrainingExecutor( estimator=estimator, train_spec=train_spec, eval_spec=eval_spec) - config = estimator.config + + # If `distribute_coordinator_mode` is set and running in a distributed + # environment, we run `train_and_evaluate` via distribute coordinator. + if distribute_coordinator_training.should_run_distribute_coordinator(config): + logging.info('Running `train_and_evaluate` with Distribute Coordinator.') + distribute_coordinator_training.train_and_evaluate( + estimator, train_spec, eval_spec, _TrainingExecutor) + return + if (config.task_type == run_config_lib.TaskType.EVALUATOR and config.task_id > 0): raise ValueError(
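
To illustrate the user-facing API this patch wires together, the sketch below builds a `RunConfig` with the new `DistributeConfig` so that `train_and_evaluate` is dispatched through the distribute coordinator in STANDALONE_CLIENT mode. This is a minimal sketch, not part of the change itself: the input function, feature column, `model_dir`, worker addresses, and GPU count are hypothetical placeholders, and actually running it requires the corresponding cluster and hardware.

```
import tensorflow as tf


def _input_fn():
  # Hypothetical toy data: a single numeric feature "x" and a scalar label.
  features = {"x": [[1.0], [2.0], [3.0], [4.0]]}
  labels = [[2.0], [4.0], [6.0], [8.0]]
  return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)


# MirroredStrategy for training; `num_gpus` assumes one GPU per worker.
distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=1)

# Setting `remote_cluster` selects STANDALONE_CLIENT mode: this process acts as
# the client and connects to the (hypothetical) remote workers for training.
run_config = tf.estimator.RunConfig(
    experimental_distribute=tf.contrib.distribute.DistributeConfig(
        train_distribute=distribution,
        remote_cluster={"worker": ["host1:2222", "host2:2222"]}))

estimator = tf.estimator.DNNRegressor(
    hidden_units=[8, 8],
    feature_columns=[tf.feature_column.numeric_column("x")],
    model_dir="/tmp/distribute_coordinator_example",  # hypothetical path
    config=run_config)

# Because a distribute coordinator mode is set on the RunConfig,
# train_and_evaluate routes through the distribute coordinator instead of the
# plain _TrainingExecutor path.
tf.estimator.train_and_evaluate(
    estimator,
    tf.estimator.TrainSpec(input_fn=_input_fn, max_steps=100),
    tf.estimator.EvalSpec(input_fn=_input_fn, steps=10, start_delay_secs=0))
```

In INDEPENDENT_WORKER mode the same script would instead omit `remote_cluster` and be launched on every task with a TF_CONFIG environment variable describing the cluster, which `init_run_config` picks up to set the coordinator mode.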