Add an option to RunConfig and train_and_evaluate to run distribute coordinator.

This is necessary to run multi-worker MirroredStrategy and CollectiveAllReduceStrategy with estimator.

PiperOrigin-RevId: 210192378
This commit is contained in:
Yuefeng Zhou 2018-08-24 19:14:44 -07:00 committed by TensorFlower Gardener
parent 9599b47303
commit ca94990804
11 changed files with 1100 additions and 26 deletions

View File

@ -35,5 +35,6 @@ py_library(
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:distribute_config",
],
)

View File

@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceSt
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.distribute.distribute_config import DistributeConfig
from tensorflow.python.training.distribute import *
from tensorflow.python.training.distribution_strategy_context import *
@ -37,6 +38,7 @@ _allowed_symbols = [
'AllReduceCrossTowerOps',
'CollectiveAllReduceStrategy',
'CrossTowerOps',
'DistributeConfig',
'DistributionStrategy',
'MirroredStrategy',
'Monitor',

View File

@ -452,6 +452,32 @@ cuda_py_test(
],
)
cuda_py_test(
name = "estimator_training_test",
size = "large",
srcs = ["estimator_training_test.py"],
additional_deps = [
":combinations",
":mirrored_strategy",
":multi_worker_test_base",
":parameter_server_strategy",
"//third_party/py/numpy",
"//tensorflow/contrib/optimizer_v2:training",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute",
"//tensorflow/python/eager:test",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:summary",
],
tags = [
"multi_and_single_gpu",
"no_pip",
],
)
py_library(
name = "single_loss_example",
srcs = ["single_loss_example.py"],

View File

@ -0,0 +1,659 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that show Distribute Coordinator works with Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import json
import os
import sys
import tempfile
import threading
from absl.testing import parameterized
import numpy as np
import six
_portpicker_import_error = None
try:
import portpicker # pylint: disable=g-import-not-at-top
except ImportError as _error: # pylint: disable=invalid-name
_portpicker_import_error = _error
portpicker = None
# pylint: disable=g-import-not-at-top
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.optimizer_v2 import adagrad
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import estimator_training as dc_training
from tensorflow.python.distribute.distribute_config import DistributeConfig
from tensorflow.python.eager import context
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.estimator import training as estimator_training
from tensorflow.python.estimator.canned import dnn_linear_combined
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export as export_lib
from tensorflow.python.feature_column import feature_column
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import server_lib
BATCH_SIZE = 10
LABEL_DIMENSION = 2
DATA = np.linspace(
0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape(
BATCH_SIZE, LABEL_DIMENSION)
EVAL_NAME = "foo"
EXPORTER_NAME = "saved_model_exporter"
MAX_STEPS = 10
CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
WORKER = dc._TaskType.WORKER
PS = dc._TaskType.PS
original_run_distribute_coordinator = dc.run_distribute_coordinator
# TODO(yuefengz): merge this method back to test_util.
def _create_local_cluster(num_workers,
num_ps,
has_eval=False,
protocol="grpc",
worker_config=None,
ps_config=None):
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
cluster_dict = {
"worker": ["localhost:%s" % port for port in worker_ports],
"ps": ["localhost:%s" % port for port in ps_ports]
}
if has_eval:
cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
cs = server_lib.ClusterSpec(cluster_dict)
workers = [
server_lib.Server(
cs,
job_name="worker",
protocol=protocol,
task_index=ix,
config=worker_config,
start=True) for ix in range(num_workers)
]
ps_servers = [
server_lib.Server(
cs,
job_name="ps",
protocol=protocol,
task_index=ix,
config=ps_config,
start=True) for ix in range(num_ps)
]
if has_eval:
evals = [
server_lib.Server(
cs,
job_name="evaluator",
protocol=protocol,
task_index=0,
config=worker_config,
start=True)
]
else:
evals = []
return workers, ps_servers, evals
def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
"""Create an in-process cluster that consists of only standard server."""
# Leave some memory for cuda runtime.
if has_eval:
gpu_mem_frac = 0.7 / (num_workers + 1)
else:
gpu_mem_frac = 0.7 / num_workers
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
# Enable collective ops which has no impact on non-collective ops.
# TODO(yuefengz, tucker): removing this after we move the initialization of
# collective mgr to the session level.
worker_config.experimental.collective_group_leader = (
"/job:worker/replica:0/task:0")
ps_config = config_pb2.ConfigProto()
ps_config.device_count["GPU"] = 0
return _create_local_cluster(
num_workers,
num_ps=num_ps,
has_eval=has_eval,
worker_config=worker_config,
ps_config=ps_config,
protocol="grpc")
def _create_cluster_spec(has_chief=False,
num_workers=1,
num_ps=0,
has_eval=False):
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
cluster_spec = {}
if has_chief:
cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
if num_workers:
cluster_spec[WORKER] = [
"localhost:%s" % portpicker.pick_unused_port()
for _ in range(num_workers)
]
if num_ps:
cluster_spec[PS] = [
"localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
]
if has_eval:
cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
return cluster_spec
def _bytes_to_str(maybe_bytes):
if isinstance(maybe_bytes, six.string_types):
return maybe_bytes
else:
return str(maybe_bytes, "utf-8")
def _strip_protocol(target):
# cluster_spec expects "host:port" strings.
if "//" in target:
return target.split("//")[1]
else:
return target
class DistributeCoordinatorIntegrationTest(test.TestCase,
parameterized.TestCase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
num_workers=3, num_ps=2, has_eval=True)
cls._cluster_spec = {
"worker": [
_strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
],
"ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
"evaluator": [
_strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
]
}
def setUp(self):
self._model_dir = tempfile.mkdtemp()
self._event = threading.Event()
super(DistributeCoordinatorIntegrationTest, self).setUp()
def dataset_input_fn(self, x, y, batch_size, shuffle):
def input_fn():
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
if shuffle:
dataset = dataset.shuffle(batch_size)
dataset = dataset.repeat(100).batch(batch_size)
return dataset
return input_fn
def _get_exporter(self, name, fc):
feature_spec = feature_column.make_parse_example_spec(fc)
serving_input_receiver_fn = (
export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
return exporter_lib.LatestExporter(
name, serving_input_receiver_fn=serving_input_receiver_fn)
def _extract_loss_and_global_step(self, event_folder):
"""Returns the loss and global step in last event."""
event_paths = glob.glob(os.path.join(event_folder, "events*"))
loss = None
global_step_count = None
for e in summary_iterator.summary_iterator(event_paths[-1]):
current_loss = None
for v in e.summary.value:
if v.tag == "loss":
current_loss = v.simple_value
# If loss is not found, global step is meaningless.
if current_loss is None:
continue
current_global_step = e.step
if global_step_count is None or current_global_step > global_step_count:
global_step_count = current_global_step
loss = current_loss
return (loss, global_step_count)
def _get_estimator(self,
train_distribute,
eval_distribute,
remote_cluster=None):
input_dimension = LABEL_DIMENSION
linear_feature_columns = [
feature_column.numeric_column("x", shape=(input_dimension,))
]
dnn_feature_columns = [
feature_column.numeric_column("x", shape=(input_dimension,))
]
return dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=linear_feature_columns,
dnn_hidden_units=(2, 2),
dnn_feature_columns=dnn_feature_columns,
label_dimension=LABEL_DIMENSION,
model_dir=self._model_dir,
dnn_optimizer=adagrad.AdagradOptimizer(0.001),
linear_optimizer=adagrad.AdagradOptimizer(0.001),
config=run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=train_distribute,
eval_distribute=eval_distribute,
remote_cluster=remote_cluster)))
def _complete_flow(self,
train_distribute,
eval_distribute,
remote_cluster=None):
estimator = self._get_estimator(train_distribute, eval_distribute,
remote_cluster)
input_dimension = LABEL_DIMENSION
train_input_fn = self.dataset_input_fn(
x={"x": DATA},
y=DATA,
batch_size=BATCH_SIZE // len(train_distribute.worker_devices),
shuffle=True)
if eval_distribute:
eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices)
else:
eval_batch_size = BATCH_SIZE
eval_input_fn = self.dataset_input_fn(
x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False)
linear_feature_columns = [
feature_column.numeric_column("x", shape=(input_dimension,))
]
dnn_feature_columns = [
feature_column.numeric_column("x", shape=(input_dimension,))
]
feature_columns = linear_feature_columns + dnn_feature_columns
estimator_training.train_and_evaluate(
estimator,
estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS),
estimator_training.EvalSpec(
name=EVAL_NAME,
input_fn=eval_input_fn,
steps=None,
exporters=self._get_exporter(EXPORTER_NAME, feature_columns),
start_delay_secs=0,
throttle_secs=1))
return estimator
def _inspect_train_and_eval_events(self, estimator):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
# Examine the training events. Use a range to check global step to avoid
# flakyness due to global step race condition.
training_loss, _ = self._extract_loss_and_global_step(self._model_dir)
self.assertIsNotNone(training_loss)
# Examine the eval events. The global step should be accurate.
eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME)
eval_loss, eval_global_step = self._extract_loss_and_global_step(
event_folder=eval_dir)
self.assertIsNotNone(eval_loss)
self.assertGreaterEqual(eval_global_step, MAX_STEPS)
# Examine the export folder.
export_dir = os.path.join(
os.path.join(self._model_dir, "export"), EXPORTER_NAME)
self.assertTrue(gfile.Exists(export_dir))
# Examine the ckpt for predict.
def predict_input_fn():
return dataset_ops.Dataset.from_tensor_slices({
"x": DATA
}).batch(BATCH_SIZE)
predicted_proba = np.array([
x[prediction_keys.PredictionKeys.PREDICTIONS]
for x in estimator.predict(predict_input_fn)
])
self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape)
@combinations.generate(
combinations.combine(
mode=["graph"],
train_distribute_cls=[
mirrored_strategy.MirroredStrategy,
parameter_server_strategy.ParameterServerStrategy
],
eval_distribute_cls=[
None, mirrored_strategy.MirroredStrategy,
parameter_server_strategy.ParameterServerStrategy
],
required_gpus=1))
def test_complete_flow_standalone_client(self, train_distribute_cls,
eval_distribute_cls):
try:
train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
except TypeError:
train_distribute = train_distribute_cls(num_gpus_per_worker=2)
if eval_distribute_cls:
eval_distribute = eval_distribute_cls()
else:
eval_distribute = None
estimator = self._complete_flow(
train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
self._inspect_train_and_eval_events(estimator)
def _mock_run_distribute_coordinator(
self,
worker_fn,
strategy,
eval_fn,
eval_strategy,
mode=dc.CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
session_config=None):
# Calls the origial `run_distribute_coordinator` method but gets task config
# from environment variables and then signals the caller.
task_type = None
task_id = None
if not cluster_spec:
cluster_spec = None
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if not cluster_spec:
cluster_spec = tf_config.get("cluster", {})
task_env = tf_config.get("task", {})
if task_env:
task_type = task_env.get("type", task_type)
task_id = int(task_env.get("index", task_id))
self._event.set()
original_run_distribute_coordinator(
worker_fn,
strategy,
eval_fn,
eval_strategy,
mode=mode,
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
session_config=session_config)
def _task_thread(self, train_distribute, eval_distribute):
with test.mock.patch.object(dc, "run_distribute_coordinator",
self._mock_run_distribute_coordinator):
self._complete_flow(train_distribute, eval_distribute)
def _run_task_in_thread(self, cluster_spec, task_type, task_id,
train_distribute, eval_distribute):
if task_type:
tf_config = {
"cluster": cluster_spec,
"task": {
"type": task_type,
"index": task_id
}
}
else:
tf_config = {
"cluster": cluster_spec,
"task": {
"type": task_type,
"index": task_id
}
}
self._event.clear()
t = threading.Thread(
target=self._task_thread, args=(train_distribute, eval_distribute))
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
t.start()
self._event.wait()
return t
def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
eval_distribute):
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_task_in_thread(cluster_spec, task_type, task_id,
train_distribute, eval_distribute)
threads[task_type].append(t)
return threads
@combinations.generate(
combinations.combine(
mode=["graph"],
train_distribute_cls=[
parameter_server_strategy.ParameterServerStrategy,
],
eval_distribute_cls=[
None, mirrored_strategy.MirroredStrategy,
parameter_server_strategy.ParameterServerStrategy
],
required_gpus=1))
def test_complete_flow_indepedent_worker_between_graph(
self, train_distribute_cls, eval_distribute_cls):
train_distribute = train_distribute_cls(
num_gpus_per_worker=context.num_gpus())
if eval_distribute_cls:
eval_distribute = eval_distribute_cls()
else:
eval_distribute = None
cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
for task_type, ts in threads.items():
if task_type == PS:
continue
for t in ts:
t.join()
estimator = self._get_estimator(train_distribute, eval_distribute)
self._inspect_train_and_eval_events(estimator)
@combinations.generate(
combinations.combine(
mode=["graph"],
train_distribute_cls=[mirrored_strategy.MirroredStrategy],
eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy],
required_gpus=1))
def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls,
eval_distribute_cls):
train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
if eval_distribute_cls:
eval_distribute = eval_distribute_cls()
else:
eval_distribute = None
cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
threads[WORKER][0].join()
threads[EVALUATOR][0].join()
estimator = self._get_estimator(train_distribute, eval_distribute)
self._inspect_train_and_eval_events(estimator)
TF_CONFIG_WITH_CHIEF = {
"cluster": {
"chief": ["fake_chief"],
},
"task": {
"type": "chief",
"index": 0
}
}
TF_CONFIG_WITH_MASTER = {
"cluster": {
"master": ["fake_master"],
},
"task": {
"type": "master",
"index": 0
}
}
TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}}
class RunConfigTest(test.TestCase):
def test_previously_unexpected_cluster_spec(self):
with test.mock.patch.dict(
"os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}):
run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
def test_should_run_distribute_coordinator(self):
"""Tests that should_run_distribute_coordinator return a correct value."""
# We don't use distribute coordinator for local training.
self.assertFalse(
dc_training.should_run_distribute_coordinator(
run_config_lib.RunConfig()))
# When `train_distribute` is not specified, don't use distribute
# coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
self.assertFalse(
dc_training.should_run_distribute_coordinator(
run_config_lib.RunConfig()))
# When `train_distribute` is specified and TF_CONFIG is detected, use
# distribute coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
config_with_train_distribute = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
config_with_eval_distribute = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
self.assertTrue(
dc_training.should_run_distribute_coordinator(
config_with_train_distribute))
self.assertFalse(
dc_training.should_run_distribute_coordinator(
config_with_eval_distribute))
# With a master in the cluster, don't run distribute coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
config = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
self.assertFalse(dc_training.should_run_distribute_coordinator(config))
def test_init_run_config_duplicate_distribute(self):
with self.assertRaises(ValueError):
run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy(),
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.MirroredStrategy()))
with self.assertRaises(ValueError):
run_config_lib.RunConfig(
eval_distribute=mirrored_strategy.MirroredStrategy(),
experimental_distribute=DistributeConfig(
eval_distribute=mirrored_strategy.MirroredStrategy()))
def test_init_run_config_none_distribute_coordinator_mode(self):
# We don't use distribute coordinator for local training.
config = run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy())
dc_training.init_run_config(config, {})
self.assertIsNone(config._distribute_coordinator_mode)
# With a master in the cluster, don't run distribute coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
config = run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy())
self.assertIsNone(config._distribute_coordinator_mode)
# When `train_distribute` is not specified, don't use distribute
# coordinator.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
config = run_config_lib.RunConfig()
self.assertFalse(hasattr(config, "_distribute_coordinator_mode"))
def test_init_run_config_independent_worker(self):
# When `train_distribute` is specified and TF_CONFIG is detected, use
# distribute coordinator with INDEPENDENT_WORKER mode.
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
config = run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy())
self.assertEqual(config._distribute_coordinator_mode,
dc.CoordinatorMode.INDEPENDENT_WORKER)
def test_init_run_config_standalone_client(self):
# When `train_distribute` is specified, TF_CONFIG is detected and
# `experimental.remote_cluster` is set use distribute coordinator with
# STANDALONE_CLIENT mode.
config = run_config_lib.RunConfig(
train_distribute=mirrored_strategy.MirroredStrategy(),
experimental_distribute=DistributeConfig(
remote_cluster={"chief": ["fake_worker"]}))
self.assertEqual(config._distribute_coordinator_mode,
dc.CoordinatorMode.STANDALONE_CLIENT)
if __name__ == "__main__":
with test.mock.patch.object(sys, "exit", os._exit):
test.main()

View File

@ -134,6 +134,7 @@ py_library(
"//tensorflow/core:protos_all_py",
"//tensorflow/python/compat",
"//tensorflow/python/data",
"//tensorflow/python/distribute:estimator_training",
"//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/keras",
"//tensorflow/python/ops/distributions",

View File

@ -8,6 +8,25 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "distribute",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":distribute_config",
":distribute_coordinator",
":distribute_coordinator_context",
],
)
py_library(
name = "distribute_config",
srcs = [
"distribute_config.py",
],
deps = [],
)
py_library(
name = "distribute_coordinator",
srcs = [
@ -81,3 +100,17 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
# Used only by estimator.
py_library(
name = "estimator_training",
srcs = [
"estimator_training.py",
],
srcs_version = "PY2AND3",
deps = [
":distribute_coordinator",
":distribute_coordinator_context",
"//tensorflow/python:training",
],
)

View File

@ -0,0 +1,45 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A configure tuple for high-level APIs for running distribution strategies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
class DistributeConfig(
collections.namedtuple(
'DistributeConfig',
['train_distribute', 'eval_distribute', 'remote_cluster'])):
"""A config tuple for distribution strategies.
Attributes:
train_distribute: a `DistributionStrategy` object for training.
eval_distribute: an optional `DistributionStrategy` object for
evaluation.
remote_cluster: a dict, `ClusterDef` or `ClusterSpec` object specifying
the cluster configurations. If this is given, the `train_and_evaluate`
method will be running as a standalone client which connects to the
cluster for training.
"""
def __new__(cls,
train_distribute=None,
eval_distribute=None,
remote_cluster=None):
return super(DistributeConfig, cls).__new__(cls, train_distribute,
eval_distribute, remote_cluster)

View File

@ -311,7 +311,11 @@ def _run_single_worker(worker_fn,
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
strategy = copy.deepcopy(strategy)
strategy.configure(session_config, cluster_spec, task_type, task_id)
# If there is an EVALUATOR task, we run single-machine eval on that task.
if task_type == _TaskType.EVALUATOR:
strategy.configure(session_config)
else:
strategy.configure(session_config, cluster_spec, task_type, task_id)
context = _WorkerContext(
strategy,
cluster_spec,
@ -340,14 +344,14 @@ def _run_std_server(cluster_spec=None,
return server
def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
rpc_layer):
def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for between-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@ -378,14 +382,14 @@ def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
eval_thread.join()
def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
rpc_layer):
def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for in-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@ -408,6 +412,8 @@ def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
strategy,
eval_fn=None,
eval_strategy=None,
mode=CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
task_type=None,
@ -488,10 +494,12 @@ def run_distribute_coordinator(worker_fn,
If `cluster_spec` is not given in any format, it becomes local training and
this coordinator will connect to a local session.
For evaluation, if "evaluator" exist in the cluster_spec, a separate thread
will be created with its `task_type` set to "evaluator". If "evaluator" is not
set in the cluster_spec, it entirely depends on the `worker_fn` for how to do
evaluation.
For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
will be created to call `eval_fn` with its `task_type` set to "evaluator". If
`eval_fn` is not defined, fall back to `worker_fn`. This implies that
evaluation will be done on a single machine if there is an "evaluator" task.
If "evaluator" doesn't exit in the cluster_spec, it entirely depends on the
`worker_fn` for how to do evaluation.
Args:
worker_fn: the function to be called. The function should accept a
@ -501,6 +509,8 @@ def run_distribute_coordinator(worker_fn,
run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`,
`cluster_spc`, `task_type` and `task_id`.
eval_fn: optional function for "evaluator" task.
eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training.
@ -535,16 +545,22 @@ def run_distribute_coordinator(worker_fn,
# `mode` is ignored in the local case.
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
if eval_fn:
_run_single_worker(eval_fn, eval_strategy or strategy, None, None, None,
session_config, rpc_layer)
elif mode == CoordinatorMode.STANDALONE_CLIENT:
eval_fn = eval_fn or worker_fn
eval_strategy = eval_strategy or strategy
# The client must know the cluster but servers in the cluster don't have to
# know the client.
if task_type in [_TaskType.CLIENT, None]:
if strategy.between_graph:
_run_between_graph_client(worker_fn, strategy, cluster_spec,
session_config, rpc_layer)
_run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
cluster_spec, session_config, rpc_layer)
else:
_run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
rpc_layer)
_run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
cluster_spec, session_config, rpc_layer)
else:
# If not a client job, run the standard server.
server = _run_std_server(
@ -554,6 +570,9 @@ def run_distribute_coordinator(worker_fn,
if mode != CoordinatorMode.INDEPENDENT_WORKER:
raise ValueError("Unexpected coordinator mode: %r" % mode)
eval_fn = eval_fn or worker_fn
eval_strategy = eval_strategy or strategy
# Every one starts a standard server.
server = _run_std_server(
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
@ -572,8 +591,8 @@ def run_distribute_coordinator(worker_fn,
else:
server.join()
elif task_type == _TaskType.EVALUATOR:
_run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
session_config, rpc_layer)
_run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
task_id, session_config, rpc_layer)
else:
if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type)

View File

@ -0,0 +1,264 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training utilities for Estimator to use Distribute Coordinator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import six
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
# pylint: disable=protected-access
CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
PS = dc._TaskType.PS
WORKER = dc._TaskType.WORKER
# pylint: enable=protected-access
def _count_ps(cluster_spec):
"""Counts the number of parameter servers in cluster_spec."""
if not cluster_spec:
raise RuntimeError(
'Internal error: `_count_ps` does not expect empty cluster_spec.')
return len(cluster_spec.as_dict().get(PS, []))
def _count_worker(cluster_spec, chief_task_type):
"""Counts the number of workers (including chief) in cluster_spec."""
if not cluster_spec:
raise RuntimeError(
'Internal error: `_count_worker` does not expect empty cluster_spec.')
return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
cluster_spec.as_dict().get(chief_task_type, [])))
def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
"""Returns the global id of the given task type in a cluster."""
if not task_type:
return 0
# Sort task names in cluster by "chief"/"master", "evaluator", "worker"
# and "ps". More details can be found at the documentation of
# @{tf.estimator.RunConfig.global_id_in_cluster}.
task_type_ordered_list = []
if chief_task_type in cluster_spec.jobs:
task_type_ordered_list = [chief_task_type]
task_type_ordered_list.extend([
t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
])
if PS in cluster_spec.jobs:
task_type_ordered_list.append(PS)
# Find the right gloabl_id for current task.
next_global_id = 0
for t in task_type_ordered_list:
if t == task_type:
return next_global_id + task_id
# `cluster_spec.job_tasks` returns all task addresses of type `t`.
next_global_id += len(cluster_spec.job_tasks(t))
# It is unexpected that it passes through all task_types in
# `task_type_ordered_list`.
raise RuntimeError('Internal Error: `task_type` ({}) is not in '
'cluster_spec ({}).'.format(task_type, cluster_spec))
def _init_run_config_from_worker_context(config, worker_context):
"""Initializes run config from distribute coordinator's worker context."""
# pylint: disable=protected-access
config._service = None
config._cluster_spec = worker_context.cluster_spec
config._task_type = worker_context.task_type
config._task_id = worker_context.task_id
config._evaluation_master = worker_context.master_target
config._master = worker_context.master_target
config._is_chief = worker_context.is_chief
if config._cluster_spec:
# Distributed mode.
if config._task_type != EVALUATOR:
config._num_ps_replicas = _count_ps(config._cluster_spec)
config._num_worker_replicas = _count_worker(
config._cluster_spec, chief_task_type=CHIEF)
config._global_id_in_cluster = _get_global_id(
config._cluster_spec,
config._task_type,
config._task_id,
chief_task_type=CHIEF)
else:
# Evaluator task should not be aware of the other tasks.
config._cluster_spec = server_lib.ClusterSpec({})
config._num_ps_replicas = 0
config._num_worker_replicas = 0
config._global_id_in_cluster = None # undefined
else:
# Local mode.
config._global_id_in_cluster = 0
config._num_ps_replicas = 0
config._num_worker_replicas = 1
def init_run_config(config, tf_config):
"""Initializes RunConfig for distribution strategies."""
# pylint: disable=protected-access
if (config._experimental_distribute and
config._experimental_distribute.train_distribute):
if config._train_distribute:
raise ValueError('Either `train_distribute` or'
'`experimental_distribute.train_distribute` can be set.')
config._train_distribute = config._experimental_distribute.train_distribute
if (config._experimental_distribute and
config._experimental_distribute.eval_distribute):
if config._eval_distribute:
raise ValueError('Either `eval_distribute` or'
'`experimental_distribute.eval_distribute` can be set.')
config._eval_distribute = config._experimental_distribute.eval_distribute
cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
config._init_distributed_setting_from_environment_var({})
# Use distribute coordinator with STANDALONE_CLIENT mode if
# `experimental_distribute.remote_cluster` is set.
if (config._train_distribute and config._experimental_distribute and
config._experimental_distribute.remote_cluster):
if tf_config:
raise ValueError('Cannot set both TF_CONFIG environment variable and '
'`experimental_distribute.remote_cluster`')
config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
config._cluster_spec = config._experimental_distribute.remote_cluster
logging.info('RunConfig initialized for Distribute Coordinator with '
'STANDALONE_CLIENT mode')
return
# Don't use distribute coordinator if it is local training or cluster has a
# MASTER job or `train_distribute` is not specifed.
if (not tf_config or 'master' in cluster_spec.jobs or
not config._train_distribute):
config._distribute_coordinator_mode = None
config._init_distributed_setting_from_environment_var(tf_config)
config._maybe_overwrite_session_config_for_distributed_training()
logging.info('Not using Distribute Coordinator.')
return
# Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
assert tf_config
# Set the cluster_spec only since the distributed setting will come from
# distribute coordinator.
config._cluster_spec = cluster_spec
config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
logging.info('RunConfig initialized for Distribute Coordinator with '
'INDEPENDENT_WORKER mode')
def should_run_distribute_coordinator(config):
"""Checks the config to see whether to run distribute coordinator."""
# pylint: disable=protected-access
if (not hasattr(config, '_distribute_coordinator_mode') or
config._distribute_coordinator_mode is None):
return False
if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
config._distribute_coordinator_mode not in [
dc.CoordinatorMode.STANDALONE_CLIENT,
dc.CoordinatorMode.INDEPENDENT_WORKER
]):
logging.warning('Unexpected distribute_coordinator_mode: %r',
config._distribute_coordinator_mode)
return False
if not config.cluster_spec:
logging.warning('Running `train_and_evaluate` locally, ignoring '
'`experimental_distribute_coordinator_mode`.')
return False
return True
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
"""Run distribute coordinator for Estimator's `train_and_evaluate`.
Args:
estimator: An `Estimator` instance to train and evaluate.
train_spec: A `TrainSpec` instance to specify the training specification.
eval_spec: A `EvalSpec` instance to specify the evaluation and export
specification.
executor_cls: the evaluation executor class of Estimator.
Raises:
ValueError: if `distribute_coordinator_mode` is None in RunConfig.
"""
run_config = estimator.config
if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access
raise ValueError(
'Distribute coordinator mode is not specified in `RunConfig`.')
def _worker_fn(strategy):
"""Function for worker task."""
local_estimator = copy.deepcopy(estimator)
# pylint: disable=protected-access
local_estimator._config._train_distribute = strategy
_init_run_config_from_worker_context(
local_estimator._config, dc_context.get_current_worker_context())
local_estimator._train_distribution = strategy
# pylint: enable=protected-access
local_estimator.train(
input_fn=train_spec.input_fn,
max_steps=train_spec.max_steps,
hooks=list(train_spec.hooks))
def _eval_fn(strategy):
"""Function for evaluator task."""
local_estimator = copy.deepcopy(estimator)
# pylint: disable=protected-access
local_estimator._config._eval_distribute = strategy
_init_run_config_from_worker_context(
local_estimator._config, dc_context.get_current_worker_context())
local_estimator._eval_distribution = strategy
executor = executor_cls(local_estimator, train_spec, eval_spec)
executor._start_continuous_evaluation()
# pylint: enable=protected-access
# pylint: disable=protected-access
if (run_config._distribute_coordinator_mode ==
dc.CoordinatorMode.STANDALONE_CLIENT):
cluster_spec = run_config.cluster_spec
assert cluster_spec
else:
# The cluster_spec comes from TF_CONFIG environment variable if it is
# INDEPENDENT_WORKER mode.
cluster_spec = None
dc.run_distribute_coordinator(
_worker_fn,
run_config.train_distribute,
_eval_fn,
run_config.eval_distribute,
mode=run_config._distribute_coordinator_mode,
cluster_spec=cluster_spec,
session_config=run_config.session_config)

View File

@ -26,6 +26,7 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
@ -460,7 +461,8 @@ class RunConfig(object):
train_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during training,
according to the policy specified by that strategy.
according to the policy specified by that strategy. Setting
`experimental_distribute.train_distribute` is preferred.
device_fn: A callable invoked for every `Operation` that takes the
`Operation` and returns the device string. If `None`, defaults to
the device function returned by `tf.train.replica_device_setter`
@ -470,10 +472,13 @@ class RunConfig(object):
eval_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during evaluation,
according to the policy specified by that strategy.
according to the policy specified by that strategy. Setting
`experimental_distribute.eval_distribute` is preferred.
experimental_distribute: an optional
`tf.contrib.distribute.DistributeConfig` object specifying
DistributionStrategy-related configuration.
DistributionStrategy-related configuration. The `train_distribute` and
`eval_distribute` can be passed as parameters to `RunConfig` or set in
`experimental_distribute` but not both.
Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@ -516,9 +521,12 @@ class RunConfig(object):
eval_distribute=eval_distribute,
experimental_distribute=experimental_distribute)
self._init_distributed_setting_from_environment_var(tf_config)
self._maybe_overwrite_session_config_for_distributed_training()
if train_distribute or eval_distribute or experimental_distribute:
logging.info('Initializing RunConfig with distribution strategies.')
distribute_coordinator_training.init_run_config(self, tf_config)
else:
self._init_distributed_setting_from_environment_var(tf_config)
self._maybe_overwrite_session_config_for_distributed_training()
def _maybe_overwrite_session_config_for_distributed_training(self):
"""Overwrites the session_config for distributed training.

View File

@ -26,6 +26,7 @@ import time
import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
@ -274,8 +275,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
evaluation `input_fn`, steps, etc.
This utility function provides consistent behavior for both local
(non-distributed) and distributed configurations. Currently, the only
supported distributed training configuration is between-graph replication.
(non-distributed) and distributed configurations. The default distribution
configuration is parameter server-based between-graph replication. For other
types of distribution configurations such as all-reduce training, please use
[DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute). # pylint: disable=line-too-long
Overfitting: In order to avoid overfitting, it is recommended to set up the
training `input_fn` to shuffle the training data properly.
@ -426,6 +429,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
}'
```
When `distribute` or `experimental_distribute.train_distribute` and
`experimental_distribute.remote_cluster` is set, this method will start a
client running on the current host which connects to the `remote_cluster` for
training and evaluation.
Args:
estimator: An `Estimator` instance to train and evaluate.
train_spec: A `TrainSpec` instance to specify the training specification.
@ -444,8 +452,16 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
executor = _TrainingExecutor(
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
config = estimator.config
# If `distribute_coordinator_mode` is set and running in distributed
# environment, we run `train_and_evaluate` via distribute coordinator.
if distribute_coordinator_training.should_run_distribute_coordinator(config):
logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
distribute_coordinator_training.train_and_evaluate(
estimator, train_spec, eval_spec, _TrainingExecutor)
return
if (config.task_type == run_config_lib.TaskType.EVALUATOR and
config.task_id > 0):
raise ValueError(