Add an option to RunConfig and train_and_evaluate to run distribute coordinator.

This is necessary to run multi-worker MirroredStrategy and CollectiveAllReduceStrategy with estimator.

PiperOrigin-RevId: 210192378
This commit is contained in:
Yuefeng Zhou 2018-08-24 19:14:44 -07:00 committed by TensorFlower Gardener
parent 9599b47303
commit ca94990804
11 changed files with 1100 additions and 26 deletions

View File

@ -35,5 +35,6 @@ py_library(
"//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/distribute:distribute_config",
], ],
) )

View File

@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceSt
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.distribute.distribute_config import DistributeConfig
from tensorflow.python.training.distribute import * from tensorflow.python.training.distribute import *
from tensorflow.python.training.distribution_strategy_context import * from tensorflow.python.training.distribution_strategy_context import *
@ -37,6 +38,7 @@ _allowed_symbols = [
'AllReduceCrossTowerOps', 'AllReduceCrossTowerOps',
'CollectiveAllReduceStrategy', 'CollectiveAllReduceStrategy',
'CrossTowerOps', 'CrossTowerOps',
'DistributeConfig',
'DistributionStrategy', 'DistributionStrategy',
'MirroredStrategy', 'MirroredStrategy',
'Monitor', 'Monitor',

View File

@ -452,6 +452,32 @@ cuda_py_test(
], ],
) )
cuda_py_test(
name = "estimator_training_test",
size = "large",
srcs = ["estimator_training_test.py"],
additional_deps = [
":combinations",
":mirrored_strategy",
":multi_worker_test_base",
":parameter_server_strategy",
"//third_party/py/numpy",
"//tensorflow/contrib/optimizer_v2:training",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute",
"//tensorflow/python/eager:test",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:summary",
],
tags = [
"multi_and_single_gpu",
"no_pip",
],
)
py_library( py_library(
name = "single_loss_example", name = "single_loss_example",
srcs = ["single_loss_example.py"], srcs = ["single_loss_example.py"],

View File

@ -0,0 +1,659 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that show Distribute Coordinator works with Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import json
import os
import sys
import tempfile
import threading
from absl.testing import parameterized
import numpy as np
import six
_portpicker_import_error = None
try:
import portpicker # pylint: disable=g-import-not-at-top
except ImportError as _error: # pylint: disable=invalid-name
_portpicker_import_error = _error
portpicker = None
# pylint: disable=g-import-not-at-top
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.optimizer_v2 import adagrad
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import estimator_training as dc_training
from tensorflow.python.distribute.distribute_config import DistributeConfig
from tensorflow.python.eager import context
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.estimator import training as estimator_training
from tensorflow.python.estimator.canned import dnn_linear_combined
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export as export_lib
from tensorflow.python.feature_column import feature_column
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import server_lib
BATCH_SIZE = 10
LABEL_DIMENSION = 2
DATA = np.linspace(
0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape(
BATCH_SIZE, LABEL_DIMENSION)
EVAL_NAME = "foo"
EXPORTER_NAME = "saved_model_exporter"
MAX_STEPS = 10
CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
WORKER = dc._TaskType.WORKER
PS = dc._TaskType.PS
original_run_distribute_coordinator = dc.run_distribute_coordinator
# TODO(yuefengz): merge this method back to test_util.
def _create_local_cluster(num_workers,
num_ps,
has_eval=False,
protocol="grpc",
worker_config=None,
ps_config=None):
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
cluster_dict = {
"worker": ["localhost:%s" % port for port in worker_ports],
"ps": ["localhost:%s" % port for port in ps_ports]
}
if has_eval:
cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
cs = server_lib.ClusterSpec(cluster_dict)
workers = [
server_lib.Server(
cs,
job_name="worker",
protocol=protocol,
task_index=ix,
config=worker_config,
start=True) for ix in range(num_workers)
]
ps_servers = [
server_lib.Server(
cs,
job_name="ps",
protocol=protocol,
task_index=ix,
config=ps_config,
start=True) for ix in range(num_ps)
]
if has_eval:
evals = [
server_lib.Server(
cs,
job_name="evaluator",
protocol=protocol,
task_index=0,
config=worker_config,
start=True)
]
else:
evals = []
return workers, ps_servers, evals
def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
"""Create an in-process cluster that consists of only standard server."""
# Leave some memory for cuda runtime.
if has_eval:
gpu_mem_frac = 0.7 / (num_workers + 1)
else:
gpu_mem_frac = 0.7 / num_workers
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
# Enable collective ops which has no impact on non-collective ops.
# TODO(yuefengz, tucker): removing this after we move the initialization of
# collective mgr to the session level.
worker_config.experimental.collective_group_leader = (
"/job:worker/replica:0/task:0")
ps_config = config_pb2.ConfigProto()
ps_config.device_count["GPU"] = 0
return _create_local_cluster(
num_workers,
num_ps=num_ps,
has_eval=has_eval,
worker_config=worker_config,
ps_config=ps_config,
protocol="grpc")
def _create_cluster_spec(has_chief=False,
num_workers=1,
num_ps=0,
has_eval=False):
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
cluster_spec = {}
if has_chief:
cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
if num_workers:
cluster_spec[WORKER] = [
"localhost:%s" % portpicker.pick_unused_port()
for _ in range(num_workers)
]
if num_ps:
cluster_spec[PS] = [
"localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
]
if has_eval:
cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
return cluster_spec
def _bytes_to_str(maybe_bytes):
if isinstance(maybe_bytes, six.string_types):
return maybe_bytes
else:
return str(maybe_bytes, "utf-8")
def _strip_protocol(target):
# cluster_spec expects "host:port" strings.
if "//" in target:
return target.split("//")[1]
else:
return target
class DistributeCoordinatorIntegrationTest(test.TestCase,
parameterized.TestCase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
num_workers=3, num_ps=2, has_eval=True)
cls._cluster_spec = {
"worker": [
_strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
],
"ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
"evaluator": [
_strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
]
}
def setUp(self):
self._model_dir = tempfile.mkdtemp()
self._event = threading.Event()
super(DistributeCoordinatorIntegrationTest, self).setUp()
def dataset_input_fn(self, x, y, batch_size, shuffle):
def input_fn():
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
if shuffle:
dataset = dataset.shuffle(batch_size)
dataset = dataset.repeat(100).batch(batch_size)
return dataset
return input_fn
def _get_exporter(self, name, fc):
feature_spec = feature_column.make_parse_example_spec(fc)
serving_input_receiver_fn = (
export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
return exporter_lib.LatestExporter(
name, serving_input_receiver_fn=serving_input_receiver_fn)
def _extract_loss_and_global_step(self, event_folder):
"""Returns the loss and global step in last event."""
event_paths = glob.glob(os.path.join(event_folder, "events*"))
loss = None
global_step_count = None
for e in summary_iterator.summary_iterator(event_paths[-1]):
current_loss = None
for v in e.summary.value:
if v.tag == "loss":
current_loss = v.simple_value
# If loss is not found, global step is meaningless.
if current_loss is None:
continue
current_global_step = e.step
if global_step_count is None or current_global_step > global_step_count:
global_step_count = current_global_step
loss = current_loss
return (loss, global_step_count)
def _get_estimator(self,
train_distribute,
eval_distribute,
remote_cluster=None):
input_dimension = LABEL_DIMENSION
linear_feature_columns = [
feature_column.numeric_column("x", shape=(input_dimension,))
]
dnn_feature_columns = [
feature_column.numeric_column("x", shape=(input_dimension,))
]
return dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=linear_feature_columns,
dnn_hidden_units=(2, 2),
dnn_feature_columns=dnn_feature_columns,
label_dimension=LABEL_DIMENSION,
model_dir=self._model_dir,
dnn_optimizer=adagrad.AdagradOptimizer(0.001),
linear_optimizer=adagrad.AdagradOptimizer(0.001),
config=run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=train_distribute,
eval_distribute=eval_distribute,
remote_cluster=remote_cluster)))
def _complete_flow(self,
train_distribute,
eval_distribute,
remote_cluster=None):
estimator = self._get_estimator(train_distribute, eval_distribute,
remote_cluster)
input_dimension = LABEL_DIMENSION
train_input_fn = self.dataset_input_fn(
x={"x": DATA},
y=DATA,
batch_size=BATCH_SIZE // len(train_distribute.worker_devices),
shuffle=True)
if eval_distribute:
eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices)
else:
eval_batch_size = BATCH_SIZE
eval_input_fn = self.dataset_input_fn(
x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False)
linear_feature_columns = [
feature_column.numeric_column("x", shape=(input_dimension,))
]
dnn_feature_columns = [
feature_column.numeric_column("x", shape=(input_dimension,))
]
feature_columns = linear_feature_columns + dnn_feature_columns
estimator_training.train_and_evaluate(
estimator,
estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS),
estimator_training.EvalSpec(
name=EVAL_NAME,
input_fn=eval_input_fn,
steps=None,
exporters=self._get_exporter(EXPORTER_NAME, feature_columns),
start_delay_secs=0,
throttle_secs=1))
return estimator
def _inspect_train_and_eval_events(self, estimator):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
# Examine the training events. Use a range to check global step to avoid
# flakyness due to global step race condition.
training_loss, _ = self._extract_loss_and_global_step(self._model_dir)
self.assertIsNotNone(training_loss)
# Examine the eval events. The global step should be accurate.
eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME)
eval_loss, eval_global_step = self._extract_loss_and_global_step(
event_folder=eval_dir)
self.assertIsNotNone(eval_loss)
self.assertGreaterEqual(eval_global_step, MAX_STEPS)
# Examine the export folder.
export_dir = os.path.join(
os.path.join(self._model_dir, "export"), EXPORTER_NAME)
self.assertTrue(gfile.Exists(export_dir))
# Examine the ckpt for predict.
def predict_input_fn():
return dataset_ops.Dataset.from_tensor_slices({
"x": DATA
}).batch(BATCH_SIZE)
predicted_proba = np.array([
x[prediction_keys.PredictionKeys.PREDICTIONS]
for x in estimator.predict(predict_input_fn)
])
self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape)
@combinations.generate(
combinations.combine(
mode=["graph"],
train_distribute_cls=[
mirrored_strategy.MirroredStrategy,
parameter_server_strategy.ParameterServerStrategy
],
eval_distribute_cls=[
None, mirrored_strategy.MirroredStrategy,
parameter_server_strategy.ParameterServerStrategy
],
required_gpus=1))
def test_complete_flow_standalone_client(self, train_distribute_cls,
eval_distribute_cls):
try:
train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
except TypeError:
train_distribute = train_distribute_cls(num_gpus_per_worker=2)
if eval_distribute_cls:
eval_distribute = eval_distribute_cls()
else:
eval_distribute = None
estimator = self._complete_flow(
train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
self._inspect_train_and_eval_events(estimator)
def _mock_run_distribute_coordinator(
self,
worker_fn,
strategy,
eval_fn,
eval_strategy,
mode=dc.CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
session_config=None):
# Calls the origial `run_distribute_coordinator` method but gets task config
# from environment variables and then signals the caller.
task_type = None
task_id = None
if not cluster_spec:
cluster_spec = None
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if not cluster_spec:
cluster_spec = tf_config.get("cluster", {})
task_env = tf_config.get("task", {})
if task_env:
task_type = task_env.get("type", task_type)
task_id = int(task_env.get("index", task_id))
self._event.set()
original_run_distribute_coordinator(
worker_fn,
strategy,
eval_fn,
eval_strategy,
mode=mode,
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
session_config=session_config)
def _task_thread(self, train_distribute, eval_distribute):
with test.mock.patch.object(dc, "run_distribute_coordinator",
self._mock_run_distribute_coordinator):
self._complete_flow(train_distribute, eval_distribute)
def _run_task_in_thread(self, cluster_spec, task_type, task_id,
train_distribute, eval_distribute):
if task_type:
tf_config = {
"cluster": cluster_spec,
"task": {
"type": task_type,
"index": task_id
}
}
else:
tf_config = {
"cluster": cluster_spec,
"task": {
"type": task_type,
"index": task_id
}
}
self._event.clear()
t = threading.Thread(
target=self._task_thread, args=(train_distribute, eval_distribute))
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
t.start()
self._event.wait()
return t
def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
eval_distribute):
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_task_in_thread(cluster_spec, task_type, task_id,
train_distribute, eval_distribute)
threads[task_type].append(t)
return threads
@combinations.generate(
combinations.combine(
mode=["graph"],
train_distribute_cls=[
parameter_server_strategy.ParameterServerStrategy,
],
eval_distribute_cls=[
None, mirrored_strategy.MirroredStrategy,
parameter_server_strategy.ParameterServerStrategy
],
required_gpus=1))
def test_complete_flow_indepedent_worker_between_graph(
self, train_distribute_cls, eval_distribute_cls):
train_distribute = train_distribute_cls(
num_gpus_per_worker=context.num_gpus())
if eval_distribute_cls:
eval_distribute = eval_distribute_cls()
else:
eval_distribute = None
cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
for task_type, ts in threads.items():
if task_type == PS:
continue
for t in ts:
t.join()
estimator = self._get_estimator(train_distribute, eval_distribute)
self._inspect_train_and_eval_events(estimator)
@combinations.generate(
combinations.combine(
mode=["graph"],
train_distribute_cls=[mirrored_strategy.MirroredStrategy],
eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy],
required_gpus=1))
def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls,
eval_distribute_cls):
train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
if eval_distribute_cls:
eval_distribute = eval_distribute_cls()
else:
eval_distribute = None
cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
threads[WORKER][0].join()
threads[EVALUATOR][0].join()
estimator = self._get_estimator(train_distribute, eval_distribute)
self._inspect_train_and_eval_events(estimator)
TF_CONFIG_WITH_CHIEF = {
"cluster": {
"chief": ["fake_chief"],
},
"task": {
"type": "chief",
"index": 0
}
}
TF_CONFIG_WITH_MASTER = {
"cluster": {
"master": ["fake_master"],
},
"task": {
"type": "master",
"index": 0
}
}
TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}}
class RunConfigTest(test.TestCase):
def test_previously_unexpected_cluster_spec(self):
with test.mock.patch.dict(
"os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}):
run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
def test_should_run_distribute_coordinator(self):
"""Tests that should_run_distribute_coordinator return a correct value."""
# We don't use distribute coordinator for local training.
self.assertFalse(
dc_training.should_run_distribute_coordinator(
run_config_lib.RunConfig()))
# When `train_distribute` is not specified, don't use distribute
# coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
self.assertFalse(
dc_training.should_run_distribute_coordinator(
run_config_lib.RunConfig()))
# When `train_distribute` is specified and TF_CONFIG is detected, use
# distribute coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
config_with_train_distribute = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
config_with_eval_distribute = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
self.assertTrue(
dc_training.should_run_distribute_coordinator(
config_with_train_distribute))
self.assertFalse(
dc_training.should_run_distribute_coordinator(
config_with_eval_distribute))
# With a master in the cluster, don't run distribute coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
config = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
self.assertFalse(dc_training.should_run_distribute_coordinator(config))
def test_init_run_config_duplicate_distribute(self):
with self.assertRaises(ValueError):
run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy(),
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.MirroredStrategy()))
with self.assertRaises(ValueError):
run_config_lib.RunConfig(
eval_distribute=mirrored_strategy.MirroredStrategy(),
experimental_distribute=DistributeConfig(
eval_distribute=mirrored_strategy.MirroredStrategy()))
def test_init_run_config_none_distribute_coordinator_mode(self):
# We don't use distribute coordinator for local training.
config = run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy())
dc_training.init_run_config(config, {})
self.assertIsNone(config._distribute_coordinator_mode)
# With a master in the cluster, don't run distribute coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
config = run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy())
self.assertIsNone(config._distribute_coordinator_mode)
# When `train_distribute` is not specified, don't use distribute
# coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
config = run_config_lib.RunConfig()
self.assertFalse(hasattr(config, "_distribute_coordinator_mode"))
def test_init_run_config_independent_worker(self):
# When `train_distribute` is specified and TF_CONFIG is detected, use
# distribute coordinator with INDEPENDENT_WORKER mode.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
config = run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy())
self.assertEqual(config._distribute_coordinator_mode,
dc.CoordinatorMode.INDEPENDENT_WORKER)
def test_init_run_config_standalone_client(self):
# When `train_distribute` is specified, TF_CONFIG is detected and
# `experimental.remote_cluster` is set use distribute coordinator with
# STANDALONE_CLIENT mode.
config = run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy(),
experimental_distribute=DistributeConfig(
remote_cluster={"chief": ["fake_worker"]}))
self.assertEqual(config._distribute_coordinator_mode,
dc.CoordinatorMode.STANDALONE_CLIENT)
if __name__ == "__main__":
with test.mock.patch.object(sys, "exit", os._exit):
test.main()

View File

@ -134,6 +134,7 @@ py_library(
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python/compat", "//tensorflow/python/compat",
"//tensorflow/python/data", "//tensorflow/python/data",
"//tensorflow/python/distribute:estimator_training",
"//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/keras", "//tensorflow/python/keras",
"//tensorflow/python/ops/distributions", "//tensorflow/python/ops/distributions",

View File

@ -8,6 +8,25 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "distribute",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":distribute_config",
":distribute_coordinator",
":distribute_coordinator_context",
],
)
py_library(
name = "distribute_config",
srcs = [
"distribute_config.py",
],
deps = [],
)
py_library( py_library(
name = "distribute_coordinator", name = "distribute_coordinator",
srcs = [ srcs = [
@ -81,3 +100,17 @@ py_test(
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",
], ],
) )
# Used only by estimator.
py_library(
name = "estimator_training",
srcs = [
"estimator_training.py",
],
srcs_version = "PY2AND3",
deps = [
":distribute_coordinator",
":distribute_coordinator_context",
"//tensorflow/python:training",
],
)

View File

@ -0,0 +1,45 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A configure tuple for high-level APIs for running distribution strategies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
class DistributeConfig(
collections.namedtuple(
'DistributeConfig',
['train_distribute', 'eval_distribute', 'remote_cluster'])):
"""A config tuple for distribution strategies.
Attributes:
train_distribute: a `DistributionStrategy` object for training.
eval_distribute: an optional `DistributionStrategy` object for
evaluation.
remote_cluster: a dict, `ClusterDef` or `ClusterSpec` object specifying
the cluster configurations. If this is given, the `train_and_evaluate`
method will be running as a standalone client which connects to the
cluster for training.
"""
def __new__(cls,
train_distribute=None,
eval_distribute=None,
remote_cluster=None):
return super(DistributeConfig, cls).__new__(cls, train_distribute,
eval_distribute, remote_cluster)

View File

@ -311,7 +311,11 @@ def _run_single_worker(worker_fn,
worker_barrier=None): worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context.""" """Runs a single worker by calling `worker_fn` under context."""
strategy = copy.deepcopy(strategy) strategy = copy.deepcopy(strategy)
strategy.configure(session_config, cluster_spec, task_type, task_id) # If there is an EVALUATOR task, we run single-machine eval on that task.
if task_type == _TaskType.EVALUATOR:
strategy.configure(session_config)
else:
strategy.configure(session_config, cluster_spec, task_type, task_id)
context = _WorkerContext( context = _WorkerContext(
strategy, strategy,
cluster_spec, cluster_spec,
@ -340,14 +344,14 @@ def _run_std_server(cluster_spec=None,
return server return server
def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config, def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
rpc_layer): cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for between-graph replication.""" """Runs a standalone client for between-graph replication."""
eval_thread = None eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs: if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread( eval_thread = threading.Thread(
target=_run_single_worker, target=_run_single_worker,
args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0,
session_config), session_config),
kwargs={ kwargs={
"rpc_layer": rpc_layer, "rpc_layer": rpc_layer,
@ -378,14 +382,14 @@ def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
eval_thread.join() eval_thread.join()
def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
rpc_layer): cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for in-graph replication.""" """Runs a standalone client for in-graph replication."""
eval_thread = None eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs: if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread( eval_thread = threading.Thread(
target=_run_single_worker, target=_run_single_worker,
args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config), session_config),
kwargs={ kwargs={
"rpc_layer": rpc_layer, "rpc_layer": rpc_layer,
@ -408,6 +412,8 @@ def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
# is the special task when we support cluster_spec propagation. # is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn, def run_distribute_coordinator(worker_fn,
strategy, strategy,
eval_fn=None,
eval_strategy=None,
mode=CoordinatorMode.STANDALONE_CLIENT, mode=CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None, cluster_spec=None,
task_type=None, task_type=None,
@ -488,10 +494,12 @@ def run_distribute_coordinator(worker_fn,
If `cluster_spec` is not given in any format, it becomes local training and If `cluster_spec` is not given in any format, it becomes local training and
this coordinator will connect to a local session. this coordinator will connect to a local session.
For evaluation, if "evaluator" exist in the cluster_spec, a separate thread For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
will be created with its `task_type` set to "evaluator". If "evaluator" is not will be created to call `eval_fn` with its `task_type` set to "evaluator". If
set in the cluster_spec, it entirely depends on the `worker_fn` for how to do `eval_fn` is not defined, fall back to `worker_fn`. This implies that
evaluation. evaluation will be done on a single machine if there is an "evaluator" task.
If "evaluator" doesn't exit in the cluster_spec, it entirely depends on the
`worker_fn` for how to do evaluation.
Args: Args:
worker_fn: the function to be called. The function should accept a worker_fn: the function to be called. The function should accept a
@ -501,6 +509,8 @@ def run_distribute_coordinator(worker_fn,
run between-graph replicated training or not, whether to run init ops, run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`, etc. This object will also be configured given `session_config`,
`cluster_spc`, `task_type` and `task_id`. `cluster_spc`, `task_type` and `task_id`.
eval_fn: optional function for "evaluator" task.
eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs. mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training. in a cluster. If not set or empty, fall back to local training.
@ -535,16 +545,22 @@ def run_distribute_coordinator(worker_fn,
# `mode` is ignored in the local case. # `mode` is ignored in the local case.
_run_single_worker(worker_fn, strategy, None, None, None, session_config, _run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer) rpc_layer)
if eval_fn:
_run_single_worker(eval_fn, eval_strategy or strategy, None, None, None,
session_config, rpc_layer)
elif mode == CoordinatorMode.STANDALONE_CLIENT: elif mode == CoordinatorMode.STANDALONE_CLIENT:
eval_fn = eval_fn or worker_fn
eval_strategy = eval_strategy or strategy
# The client must know the cluster but servers in the cluster don't have to # The client must know the cluster but servers in the cluster don't have to
# know the client. # know the client.
if task_type in [_TaskType.CLIENT, None]: if task_type in [_TaskType.CLIENT, None]:
if strategy.between_graph: if strategy.between_graph:
_run_between_graph_client(worker_fn, strategy, cluster_spec, _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
session_config, rpc_layer) cluster_spec, session_config, rpc_layer)
else: else:
_run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
rpc_layer) cluster_spec, session_config, rpc_layer)
else: else:
# If not a client job, run the standard server. # If not a client job, run the standard server.
server = _run_std_server( server = _run_std_server(
@ -554,6 +570,9 @@ def run_distribute_coordinator(worker_fn,
if mode != CoordinatorMode.INDEPENDENT_WORKER: if mode != CoordinatorMode.INDEPENDENT_WORKER:
raise ValueError("Unexpected coordinator mode: %r" % mode) raise ValueError("Unexpected coordinator mode: %r" % mode)
eval_fn = eval_fn or worker_fn
eval_strategy = eval_strategy or strategy
# Every one starts a standard server. # Every one starts a standard server.
server = _run_std_server( server = _run_std_server(
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
@ -572,8 +591,8 @@ def run_distribute_coordinator(worker_fn,
else: else:
server.join() server.join()
elif task_type == _TaskType.EVALUATOR: elif task_type == _TaskType.EVALUATOR:
_run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id, _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
session_config, rpc_layer) task_id, session_config, rpc_layer)
else: else:
if task_type != _TaskType.PS: if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type) raise ValueError("Unexpected task_type: %r" % task_type)

View File

@ -0,0 +1,264 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training utilities for Estimator to use Distribute Coordinator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import six
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
# pylint: disable=protected-access
CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
PS = dc._TaskType.PS
WORKER = dc._TaskType.WORKER
# pylint: enable=protected-access
def _count_ps(cluster_spec):
"""Counts the number of parameter servers in cluster_spec."""
if not cluster_spec:
raise RuntimeError(
'Internal error: `_count_ps` does not expect empty cluster_spec.')
return len(cluster_spec.as_dict().get(PS, []))
def _count_worker(cluster_spec, chief_task_type):
"""Counts the number of workers (including chief) in cluster_spec."""
if not cluster_spec:
raise RuntimeError(
'Internal error: `_count_worker` does not expect empty cluster_spec.')
return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
cluster_spec.as_dict().get(chief_task_type, [])))
def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
"""Returns the global id of the given task type in a cluster."""
if not task_type:
return 0
# Sort task names in cluster by "chief"/"master", "evaluator", "worker"
# and "ps". More details can be found at the documentation of
# @{tf.estimator.RunConfig.global_id_in_cluster}.
task_type_ordered_list = []
if chief_task_type in cluster_spec.jobs:
task_type_ordered_list = [chief_task_type]
task_type_ordered_list.extend([
t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
])
if PS in cluster_spec.jobs:
task_type_ordered_list.append(PS)
# Find the right gloabl_id for current task.
next_global_id = 0
for t in task_type_ordered_list:
if t == task_type:
return next_global_id + task_id
# `cluster_spec.job_tasks` returns all task addresses of type `t`.
next_global_id += len(cluster_spec.job_tasks(t))
# It is unexpected that it passes through all task_types in
# `task_type_ordered_list`.
raise RuntimeError('Internal Error: `task_type` ({}) is not in '
'cluster_spec ({}).'.format(task_type, cluster_spec))
def _init_run_config_from_worker_context(config, worker_context):
"""Initializes run config from distribute coordinator's worker context."""
# pylint: disable=protected-access
config._service = None
config._cluster_spec = worker_context.cluster_spec
config._task_type = worker_context.task_type
config._task_id = worker_context.task_id
config._evaluation_master = worker_context.master_target
config._master = worker_context.master_target
config._is_chief = worker_context.is_chief
if config._cluster_spec:
# Distributed mode.
if config._task_type != EVALUATOR:
config._num_ps_replicas = _count_ps(config._cluster_spec)
config._num_worker_replicas = _count_worker(
config._cluster_spec, chief_task_type=CHIEF)
config._global_id_in_cluster = _get_global_id(
config._cluster_spec,
config._task_type,
config._task_id,
chief_task_type=CHIEF)
else:
# Evaluator task should not be aware of the other tasks.
config._cluster_spec = server_lib.ClusterSpec({})
config._num_ps_replicas = 0
config._num_worker_replicas = 0
config._global_id_in_cluster = None # undefined
else:
# Local mode.
config._global_id_in_cluster = 0
config._num_ps_replicas = 0
config._num_worker_replicas = 1
def init_run_config(config, tf_config):
"""Initializes RunConfig for distribution strategies."""
# pylint: disable=protected-access
if (config._experimental_distribute and
config._experimental_distribute.train_distribute):
if config._train_distribute:
raise ValueError('Either `train_distribute` or'
'`experimental_distribute.train_distribute` can be set.')
config._train_distribute = config._experimental_distribute.train_distribute
if (config._experimental_distribute and
config._experimental_distribute.eval_distribute):
if config._eval_distribute:
raise ValueError('Either `eval_distribute` or'
'`experimental_distribute.eval_distribute` can be set.')
config._eval_distribute = config._experimental_distribute.eval_distribute
cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
config._init_distributed_setting_from_environment_var({})
# Use distribute coordinator with STANDALONE_CLIENT mode if
# `experimental_distribute.remote_cluster` is set.
if (config._train_distribute and config._experimental_distribute and
config._experimental_distribute.remote_cluster):
if tf_config:
raise ValueError('Cannot set both TF_CONFIG environment variable and '
'`experimental_distribute.remote_cluster`')
config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
config._cluster_spec = config._experimental_distribute.remote_cluster
logging.info('RunConfig initialized for Distribute Coordinator with '
'STANDALONE_CLIENT mode')
return
# Don't use distribute coordinator if it is local training or cluster has a
# MASTER job or `train_distribute` is not specifed.
if (not tf_config or 'master' in cluster_spec.jobs or
not config._train_distribute):
config._distribute_coordinator_mode = None
config._init_distributed_setting_from_environment_var(tf_config)
config._maybe_overwrite_session_config_for_distributed_training()
logging.info('Not using Distribute Coordinator.')
return
# Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
assert tf_config
# Set the cluster_spec only since the distributed setting will come from
# distribute coordinator.
config._cluster_spec = cluster_spec
config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
logging.info('RunConfig initialized for Distribute Coordinator with '
'INDEPENDENT_WORKER mode')
def should_run_distribute_coordinator(config):
"""Checks the config to see whether to run distribute coordinator."""
# pylint: disable=protected-access
if (not hasattr(config, '_distribute_coordinator_mode') or
config._distribute_coordinator_mode is None):
return False
if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
config._distribute_coordinator_mode not in [
dc.CoordinatorMode.STANDALONE_CLIENT,
dc.CoordinatorMode.INDEPENDENT_WORKER
]):
logging.warning('Unexpected distribute_coordinator_mode: %r',
config._distribute_coordinator_mode)
return False
if not config.cluster_spec:
logging.warning('Running `train_and_evaluate` locally, ignoring '
'`experimental_distribute_coordinator_mode`.')
return False
return True
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
"""Run distribute coordinator for Estimator's `train_and_evaluate`.
Args:
estimator: An `Estimator` instance to train and evaluate.
train_spec: A `TrainSpec` instance to specify the training specification.
eval_spec: A `EvalSpec` instance to specify the evaluation and export
specification.
executor_cls: the evaluation executor class of Estimator.
Raises:
ValueError: if `distribute_coordinator_mode` is None in RunConfig.
"""
run_config = estimator.config
if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access
raise ValueError(
'Distribute coordinator mode is not specified in `RunConfig`.')
def _worker_fn(strategy):
"""Function for worker task."""
local_estimator = copy.deepcopy(estimator)
# pylint: disable=protected-access
local_estimator._config._train_distribute = strategy
_init_run_config_from_worker_context(
local_estimator._config, dc_context.get_current_worker_context())
local_estimator._train_distribution = strategy
# pylint: enable=protected-access
local_estimator.train(
input_fn=train_spec.input_fn,
max_steps=train_spec.max_steps,
hooks=list(train_spec.hooks))
def _eval_fn(strategy):
"""Function for evaluator task."""
local_estimator = copy.deepcopy(estimator)
# pylint: disable=protected-access
local_estimator._config._eval_distribute = strategy
_init_run_config_from_worker_context(
local_estimator._config, dc_context.get_current_worker_context())
local_estimator._eval_distribution = strategy
executor = executor_cls(local_estimator, train_spec, eval_spec)
executor._start_continuous_evaluation()
# pylint: enable=protected-access
# pylint: disable=protected-access
if (run_config._distribute_coordinator_mode ==
dc.CoordinatorMode.STANDALONE_CLIENT):
cluster_spec = run_config.cluster_spec
assert cluster_spec
else:
# The cluster_spec comes from TF_CONFIG environment variable if it is
# INDEPENDENT_WORKER mode.
cluster_spec = None
dc.run_distribute_coordinator(
_worker_fn,
run_config.train_distribute,
_eval_fn,
run_config.eval_distribute,
mode=run_config._distribute_coordinator_mode,
cluster_spec=cluster_spec,
session_config=run_config.session_config)

View File

@ -26,6 +26,7 @@ import six
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal from tensorflow.python.util import compat_internal
@ -460,7 +461,8 @@ class RunConfig(object):
train_distribute: An optional instance of train_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified, `tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during training, then Estimator will distribute the user's model during training,
according to the policy specified by that strategy. according to the policy specified by that strategy. Setting
`experimental_distribute.train_distribute` is preferred.
device_fn: A callable invoked for every `Operation` that takes the device_fn: A callable invoked for every `Operation` that takes the
`Operation` and returns the device string. If `None`, defaults to `Operation` and returns the device string. If `None`, defaults to
the device function returned by `tf.train.replica_device_setter` the device function returned by `tf.train.replica_device_setter`
@ -470,10 +472,13 @@ class RunConfig(object):
eval_distribute: An optional instance of eval_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified, `tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during evaluation, then Estimator will distribute the user's model during evaluation,
according to the policy specified by that strategy. according to the policy specified by that strategy. Setting
`experimental_distribute.eval_distribute` is preferred.
experimental_distribute: an optional experimental_distribute: an optional
`tf.contrib.distribute.DistributeConfig` object specifying `tf.contrib.distribute.DistributeConfig` object specifying
DistributionStrategy-related configuration. DistributionStrategy-related configuration. The `train_distribute` and
`eval_distribute` can be passed as parameters to `RunConfig` or set in
`experimental_distribute` but not both.
Raises: Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs` ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@ -516,9 +521,12 @@ class RunConfig(object):
eval_distribute=eval_distribute, eval_distribute=eval_distribute,
experimental_distribute=experimental_distribute) experimental_distribute=experimental_distribute)
self._init_distributed_setting_from_environment_var(tf_config) if train_distribute or eval_distribute or experimental_distribute:
logging.info('Initializing RunConfig with distribution strategies.')
self._maybe_overwrite_session_config_for_distributed_training() distribute_coordinator_training.init_run_config(self, tf_config)
else:
self._init_distributed_setting_from_environment_var(tf_config)
self._maybe_overwrite_session_config_for_distributed_training()
def _maybe_overwrite_session_config_for_distributed_training(self): def _maybe_overwrite_session_config_for_distributed_training(self):
"""Overwrites the session_config for distributed training. """Overwrites the session_config for distributed training.

View File

@ -26,6 +26,7 @@ import time
import six import six
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.estimator import run_config as run_config_lib
@ -274,8 +275,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
evaluation `input_fn`, steps, etc. evaluation `input_fn`, steps, etc.
This utility function provides consistent behavior for both local This utility function provides consistent behavior for both local
(non-distributed) and distributed configurations. Currently, the only (non-distributed) and distributed configurations. The default distribution
supported distributed training configuration is between-graph replication. configuration is parameter server-based between-graph replication. For other
types of distribution configurations such as all-reduce training, please use
[DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute). # pylint: disable=line-too-long
Overfitting: In order to avoid overfitting, it is recommended to set up the Overfitting: In order to avoid overfitting, it is recommended to set up the
training `input_fn` to shuffle the training data properly. training `input_fn` to shuffle the training data properly.
@ -426,6 +429,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
}' }'
``` ```
When `distribute` or `experimental_distribute.train_distribute` and
`experimental_distribute.remote_cluster` is set, this method will start a
client running on the current host which connects to the `remote_cluster` for
training and evaluation.
Args: Args:
estimator: An `Estimator` instance to train and evaluate. estimator: An `Estimator` instance to train and evaluate.
train_spec: A `TrainSpec` instance to specify the training specification. train_spec: A `TrainSpec` instance to specify the training specification.
@ -444,8 +452,16 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
executor = _TrainingExecutor( executor = _TrainingExecutor(
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec) estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
config = estimator.config config = estimator.config
# If `distribute_coordinator_mode` is set and running in distributed
# environment, we run `train_and_evaluate` via distribute coordinator.
if distribute_coordinator_training.should_run_distribute_coordinator(config):
logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
distribute_coordinator_training.train_and_evaluate(
estimator, train_spec, eval_spec, _TrainingExecutor)
return
if (config.task_type == run_config_lib.TaskType.EVALUATOR and if (config.task_type == run_config_lib.TaskType.EVALUATOR and
config.task_id > 0): config.task_id > 0):
raise ValueError( raise ValueError(