Add an option to RunConfig and train_and_evaluate to run distribute coordinator.

This is necessary to run multi-worker MirroredStrategy and CollectiveAllReduceStrategy with Estimator.

PiperOrigin-RevId: 210192378
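For context, a minimal usage sketch of the option this commit adds (not part of the diff; `my_model_fn` and `my_input_fn` are hypothetical placeholders, and strategy constructor arguments may differ by TF version):

# Example only: enable the distribute coordinator through RunConfig's new
# `experimental_distribute` argument (DistributeConfig is exported from
# tf.contrib.distribute by this change).
import tensorflow as tf

config = tf.estimator.RunConfig(
    experimental_distribute=tf.contrib.distribute.DistributeConfig(
        train_distribute=tf.contrib.distribute.MirroredStrategy(num_gpus=2)))
estimator = tf.estimator.Estimator(model_fn=my_model_fn, config=config)

# With TF_CONFIG set on each worker (and no "master" job in the cluster),
# RunConfig selects INDEPENDENT_WORKER mode and train_and_evaluate runs the
# distribute coordinator, per the commit description.
tf.estimator.train_and_evaluate(
    estimator,
    tf.estimator.TrainSpec(input_fn=my_input_fn, max_steps=1000),
    tf.estimator.EvalSpec(input_fn=my_input_fn))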
@@ -35,5 +35,6 @@ py_library(
         "//tensorflow/contrib/distribute/python:tpu_strategy",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
+        "//tensorflow/python/distribute:distribute_config",
     ],
 )
@@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceSt
 from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
 from tensorflow.contrib.distribute.python.step_fn import *
 from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
+from tensorflow.python.distribute.distribute_config import DistributeConfig
 from tensorflow.python.training.distribute import *
 from tensorflow.python.training.distribution_strategy_context import *

@@ -37,6 +38,7 @@ _allowed_symbols = [
     'AllReduceCrossTowerOps',
     'CollectiveAllReduceStrategy',
     'CrossTowerOps',
+    'DistributeConfig',
     'DistributionStrategy',
     'MirroredStrategy',
     'Monitor',
@@ -452,6 +452,32 @@ cuda_py_test(
     ],
 )

+cuda_py_test(
+    name = "estimator_training_test",
+    size = "large",
+    srcs = ["estimator_training_test.py"],
+    additional_deps = [
+        ":combinations",
+        ":mirrored_strategy",
+        ":multi_worker_test_base",
+        ":parameter_server_strategy",
+        "//third_party/py/numpy",
+        "//tensorflow/contrib/optimizer_v2:training",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/estimator:estimator_py",
+        "//tensorflow/python/feature_column",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:summary",
+    ],
+    tags = [
+        "multi_and_single_gpu",
+        "no_pip",
+    ],
+)
+
 py_library(
     name = "single_loss_example",
     srcs = ["single_loss_example.py"],
tensorflow/contrib/distribute/python/estimator_training_test.py (new file, 659 lines)
@@ -0,0 +1,659 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that show Distribute Coordinator works with Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import json
import os
import sys
import tempfile
import threading
from absl.testing import parameterized
import numpy as np
import six

_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
except ImportError as _error:  # pylint: disable=invalid-name
  _portpicker_import_error = _error
  portpicker = None

# pylint: disable=g-import-not-at-top
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.optimizer_v2 import adagrad
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import estimator_training as dc_training
from tensorflow.python.distribute.distribute_config import DistributeConfig
from tensorflow.python.eager import context
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.estimator import training as estimator_training
from tensorflow.python.estimator.canned import dnn_linear_combined
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export as export_lib
from tensorflow.python.feature_column import feature_column
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import server_lib

BATCH_SIZE = 10
LABEL_DIMENSION = 2
DATA = np.linspace(
    0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape(
        BATCH_SIZE, LABEL_DIMENSION)
EVAL_NAME = "foo"
EXPORTER_NAME = "saved_model_exporter"
MAX_STEPS = 10

CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
WORKER = dc._TaskType.WORKER
PS = dc._TaskType.PS

original_run_distribute_coordinator = dc.run_distribute_coordinator


# TODO(yuefengz): merge this method back to test_util.
def _create_local_cluster(num_workers,
                          num_ps,
                          has_eval=False,
                          protocol="grpc",
                          worker_config=None,
                          ps_config=None):
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {
      "worker": ["localhost:%s" % port for port in worker_ports],
      "ps": ["localhost:%s" % port for port in ps_ports]
  }
  if has_eval:
    cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]

  cs = server_lib.ClusterSpec(cluster_dict)

  workers = [
      server_lib.Server(
          cs,
          job_name="worker",
          protocol=protocol,
          task_index=ix,
          config=worker_config,
          start=True) for ix in range(num_workers)
  ]
  ps_servers = [
      server_lib.Server(
          cs,
          job_name="ps",
          protocol=protocol,
          task_index=ix,
          config=ps_config,
          start=True) for ix in range(num_ps)
  ]
  if has_eval:
    evals = [
        server_lib.Server(
            cs,
            job_name="evaluator",
            protocol=protocol,
            task_index=0,
            config=worker_config,
            start=True)
    ]
  else:
    evals = []

  return workers, ps_servers, evals


def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
  """Create an in-process cluster that consists of only standard server."""
  # Leave some memory for cuda runtime.
  if has_eval:
    gpu_mem_frac = 0.7 / (num_workers + 1)
  else:
    gpu_mem_frac = 0.7 / num_workers

  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # Enable collective ops which has no impact on non-collective ops.
  # TODO(yuefengz, tucker): removing this after we move the initialization of
  # collective mgr to the session level.
  worker_config.experimental.collective_group_leader = (
      "/job:worker/replica:0/task:0")

  ps_config = config_pb2.ConfigProto()
  ps_config.device_count["GPU"] = 0

  return _create_local_cluster(
      num_workers,
      num_ps=num_ps,
      has_eval=has_eval,
      worker_config=worker_config,
      ps_config=ps_config,
      protocol="grpc")


def _create_cluster_spec(has_chief=False,
                         num_workers=1,
                         num_ps=0,
                         has_eval=False):
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  cluster_spec = {}
  if has_chief:
    cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
  if num_workers:
    cluster_spec[WORKER] = [
        "localhost:%s" % portpicker.pick_unused_port()
        for _ in range(num_workers)
    ]
  if num_ps:
    cluster_spec[PS] = [
        "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
    ]
  if has_eval:
    cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
  return cluster_spec


def _bytes_to_str(maybe_bytes):
  if isinstance(maybe_bytes, six.string_types):
    return maybe_bytes
  else:
    return str(maybe_bytes, "utf-8")


def _strip_protocol(target):
  # cluster_spec expects "host:port" strings.
  if "//" in target:
    return target.split("//")[1]
  else:
    return target


class DistributeCoordinatorIntegrationTest(test.TestCase,
                                           parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 2 workers."""
    cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
        num_workers=3, num_ps=2, has_eval=True)
    cls._cluster_spec = {
        "worker": [
            _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
        ],
        "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
        "evaluator": [
            _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
        ]
    }

  def setUp(self):
    self._model_dir = tempfile.mkdtemp()
    self._event = threading.Event()
    super(DistributeCoordinatorIntegrationTest, self).setUp()

  def dataset_input_fn(self, x, y, batch_size, shuffle):

    def input_fn():
      dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
      if shuffle:
        dataset = dataset.shuffle(batch_size)
      dataset = dataset.repeat(100).batch(batch_size)
      return dataset

    return input_fn

  def _get_exporter(self, name, fc):
    feature_spec = feature_column.make_parse_example_spec(fc)
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
    return exporter_lib.LatestExporter(
        name, serving_input_receiver_fn=serving_input_receiver_fn)

  def _extract_loss_and_global_step(self, event_folder):
    """Returns the loss and global step in last event."""
    event_paths = glob.glob(os.path.join(event_folder, "events*"))

    loss = None
    global_step_count = None

    for e in summary_iterator.summary_iterator(event_paths[-1]):
      current_loss = None
      for v in e.summary.value:
        if v.tag == "loss":
          current_loss = v.simple_value

      # If loss is not found, global step is meaningless.
      if current_loss is None:
        continue

      current_global_step = e.step
      if global_step_count is None or current_global_step > global_step_count:
        global_step_count = current_global_step
        loss = current_loss

    return (loss, global_step_count)

  def _get_estimator(self,
                     train_distribute,
                     eval_distribute,
                     remote_cluster=None):
    input_dimension = LABEL_DIMENSION
    linear_feature_columns = [
        feature_column.numeric_column("x", shape=(input_dimension,))
    ]
    dnn_feature_columns = [
        feature_column.numeric_column("x", shape=(input_dimension,))
    ]

    return dnn_linear_combined.DNNLinearCombinedRegressor(
        linear_feature_columns=linear_feature_columns,
        dnn_hidden_units=(2, 2),
        dnn_feature_columns=dnn_feature_columns,
        label_dimension=LABEL_DIMENSION,
        model_dir=self._model_dir,
        dnn_optimizer=adagrad.AdagradOptimizer(0.001),
        linear_optimizer=adagrad.AdagradOptimizer(0.001),
        config=run_config_lib.RunConfig(
            experimental_distribute=DistributeConfig(
                train_distribute=train_distribute,
                eval_distribute=eval_distribute,
                remote_cluster=remote_cluster)))

  def _complete_flow(self,
                     train_distribute,
                     eval_distribute,
                     remote_cluster=None):
    estimator = self._get_estimator(train_distribute, eval_distribute,
                                    remote_cluster)

    input_dimension = LABEL_DIMENSION
    train_input_fn = self.dataset_input_fn(
        x={"x": DATA},
        y=DATA,
        batch_size=BATCH_SIZE // len(train_distribute.worker_devices),
        shuffle=True)
    if eval_distribute:
      eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices)
    else:
      eval_batch_size = BATCH_SIZE
    eval_input_fn = self.dataset_input_fn(
        x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False)

    linear_feature_columns = [
        feature_column.numeric_column("x", shape=(input_dimension,))
    ]
    dnn_feature_columns = [
        feature_column.numeric_column("x", shape=(input_dimension,))
    ]
    feature_columns = linear_feature_columns + dnn_feature_columns

    estimator_training.train_and_evaluate(
        estimator,
        estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS),
        estimator_training.EvalSpec(
            name=EVAL_NAME,
            input_fn=eval_input_fn,
            steps=None,
            exporters=self._get_exporter(EXPORTER_NAME, feature_columns),
            start_delay_secs=0,
            throttle_secs=1))
    return estimator

  def _inspect_train_and_eval_events(self, estimator):
    # Make sure nothing is stuck in limbo.
    writer_cache.FileWriterCache.clear()

    # Examine the training events. Use a range to check global step to avoid
    # flakyness due to global step race condition.
    training_loss, _ = self._extract_loss_and_global_step(self._model_dir)
    self.assertIsNotNone(training_loss)

    # Examine the eval events. The global step should be accurate.
    eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME)
    eval_loss, eval_global_step = self._extract_loss_and_global_step(
        event_folder=eval_dir)
    self.assertIsNotNone(eval_loss)
    self.assertGreaterEqual(eval_global_step, MAX_STEPS)

    # Examine the export folder.
    export_dir = os.path.join(
        os.path.join(self._model_dir, "export"), EXPORTER_NAME)
    self.assertTrue(gfile.Exists(export_dir))

    # Examine the ckpt for predict.
    def predict_input_fn():
      return dataset_ops.Dataset.from_tensor_slices({
          "x": DATA
      }).batch(BATCH_SIZE)

    predicted_proba = np.array([
        x[prediction_keys.PredictionKeys.PREDICTIONS]
        for x in estimator.predict(predict_input_fn)
    ])
    self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape)

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          train_distribute_cls=[
              mirrored_strategy.MirroredStrategy,
              parameter_server_strategy.ParameterServerStrategy
          ],
          eval_distribute_cls=[
              None, mirrored_strategy.MirroredStrategy,
              parameter_server_strategy.ParameterServerStrategy
          ],
          required_gpus=1))
  def test_complete_flow_standalone_client(self, train_distribute_cls,
                                           eval_distribute_cls):
    try:
      train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
    except TypeError:
      train_distribute = train_distribute_cls(num_gpus_per_worker=2)

    if eval_distribute_cls:
      eval_distribute = eval_distribute_cls()
    else:
      eval_distribute = None

    estimator = self._complete_flow(
        train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
    self._inspect_train_and_eval_events(estimator)

  def _mock_run_distribute_coordinator(
      self,
      worker_fn,
      strategy,
      eval_fn,
      eval_strategy,
      mode=dc.CoordinatorMode.STANDALONE_CLIENT,
      cluster_spec=None,
      session_config=None):
    # Calls the origial `run_distribute_coordinator` method but gets task config
    # from environment variables and then signals the caller.
    task_type = None
    task_id = None
    if not cluster_spec:
      cluster_spec = None
      tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
      if not cluster_spec:
        cluster_spec = tf_config.get("cluster", {})
      task_env = tf_config.get("task", {})
      if task_env:
        task_type = task_env.get("type", task_type)
        task_id = int(task_env.get("index", task_id))
    self._event.set()
    original_run_distribute_coordinator(
        worker_fn,
        strategy,
        eval_fn,
        eval_strategy,
        mode=mode,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id,
        session_config=session_config)

  def _task_thread(self, train_distribute, eval_distribute):
    with test.mock.patch.object(dc, "run_distribute_coordinator",
                                self._mock_run_distribute_coordinator):
      self._complete_flow(train_distribute, eval_distribute)

  def _run_task_in_thread(self, cluster_spec, task_type, task_id,
                          train_distribute, eval_distribute):
    if task_type:
      tf_config = {
          "cluster": cluster_spec,
          "task": {
              "type": task_type,
              "index": task_id
          }
      }
    else:
      tf_config = {
          "cluster": cluster_spec,
          "task": {
              "type": task_type,
              "index": task_id
          }
      }
    self._event.clear()
    t = threading.Thread(
        target=self._task_thread, args=(train_distribute, eval_distribute))
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(tf_config)}):
      t.start()
      self._event.wait()
    return t

  def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
                                     eval_distribute):
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_task_in_thread(cluster_spec, task_type, task_id,
                                     train_distribute, eval_distribute)
        threads[task_type].append(t)
    return threads

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          train_distribute_cls=[
              parameter_server_strategy.ParameterServerStrategy,
          ],
          eval_distribute_cls=[
              None, mirrored_strategy.MirroredStrategy,
              parameter_server_strategy.ParameterServerStrategy
          ],
          required_gpus=1))
  def test_complete_flow_indepedent_worker_between_graph(
      self, train_distribute_cls, eval_distribute_cls):
    train_distribute = train_distribute_cls(
        num_gpus_per_worker=context.num_gpus())

    if eval_distribute_cls:
      eval_distribute = eval_distribute_cls()
    else:
      eval_distribute = None

    cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
    threads = self._run_multiple_tasks_in_threads(
        cluster_spec, train_distribute, eval_distribute)
    for task_type, ts in threads.items():
      if task_type == PS:
        continue
      for t in ts:
        t.join()

    estimator = self._get_estimator(train_distribute, eval_distribute)
    self._inspect_train_and_eval_events(estimator)

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          train_distribute_cls=[mirrored_strategy.MirroredStrategy],
          eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy],
          required_gpus=1))
  def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls,
                                                    eval_distribute_cls):
    train_distribute = train_distribute_cls(num_gpus=context.num_gpus())

    if eval_distribute_cls:
      eval_distribute = eval_distribute_cls()
    else:
      eval_distribute = None

    cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
    threads = self._run_multiple_tasks_in_threads(
        cluster_spec, train_distribute, eval_distribute)
    threads[WORKER][0].join()
    threads[EVALUATOR][0].join()

    estimator = self._get_estimator(train_distribute, eval_distribute)
    self._inspect_train_and_eval_events(estimator)


TF_CONFIG_WITH_CHIEF = {
    "cluster": {
        "chief": ["fake_chief"],
    },
    "task": {
        "type": "chief",
        "index": 0
    }
}

TF_CONFIG_WITH_MASTER = {
    "cluster": {
        "master": ["fake_master"],
    },
    "task": {
        "type": "master",
        "index": 0
    }
}

TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}}


class RunConfigTest(test.TestCase):

  def test_previously_unexpected_cluster_spec(self):
    with test.mock.patch.dict(
        "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}):
      run_config_lib.RunConfig(
          experimental_distribute=DistributeConfig(
              train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))

  def test_should_run_distribute_coordinator(self):
    """Tests that should_run_distribute_coordinator return a correct value."""
    # We don't use distribute coordinator for local training.
    self.assertFalse(
        dc_training.should_run_distribute_coordinator(
            run_config_lib.RunConfig()))

    # When `train_distribute` is not specified, don't use distribute
    # coordinator.
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
      self.assertFalse(
          dc_training.should_run_distribute_coordinator(
              run_config_lib.RunConfig()))

    # When `train_distribute` is specified and TF_CONFIG is detected, use
    # distribute coordinator.
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
      config_with_train_distribute = run_config_lib.RunConfig(
          experimental_distribute=DistributeConfig(
              train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
      config_with_eval_distribute = run_config_lib.RunConfig(
          experimental_distribute=DistributeConfig(
              eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
    self.assertTrue(
        dc_training.should_run_distribute_coordinator(
            config_with_train_distribute))
    self.assertFalse(
        dc_training.should_run_distribute_coordinator(
            config_with_eval_distribute))

    # With a master in the cluster, don't run distribute coordinator.
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
      config = run_config_lib.RunConfig(
          experimental_distribute=DistributeConfig(
              train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
      self.assertFalse(dc_training.should_run_distribute_coordinator(config))

  def test_init_run_config_duplicate_distribute(self):
    with self.assertRaises(ValueError):
      run_config_lib.RunConfig(
          train_distribute=mirrored_strategy.MirroredStrategy(),
          experimental_distribute=DistributeConfig(
              train_distribute=mirrored_strategy.MirroredStrategy()))

    with self.assertRaises(ValueError):
      run_config_lib.RunConfig(
          eval_distribute=mirrored_strategy.MirroredStrategy(),
          experimental_distribute=DistributeConfig(
              eval_distribute=mirrored_strategy.MirroredStrategy()))

  def test_init_run_config_none_distribute_coordinator_mode(self):
    # We don't use distribute coordinator for local training.
    config = run_config_lib.RunConfig(
        train_distribute=mirrored_strategy.MirroredStrategy())
    dc_training.init_run_config(config, {})
    self.assertIsNone(config._distribute_coordinator_mode)

    # With a master in the cluster, don't run distribute coordinator.
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
      config = run_config_lib.RunConfig(
          train_distribute=mirrored_strategy.MirroredStrategy())
      self.assertIsNone(config._distribute_coordinator_mode)

    # When `train_distribute` is not specified, don't use distribute
    # coordinator.
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
      config = run_config_lib.RunConfig()
      self.assertFalse(hasattr(config, "_distribute_coordinator_mode"))

  def test_init_run_config_independent_worker(self):
    # When `train_distribute` is specified and TF_CONFIG is detected, use
    # distribute coordinator with INDEPENDENT_WORKER mode.
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
      config = run_config_lib.RunConfig(
          train_distribute=mirrored_strategy.MirroredStrategy())
    self.assertEqual(config._distribute_coordinator_mode,
                     dc.CoordinatorMode.INDEPENDENT_WORKER)

  def test_init_run_config_standalone_client(self):
    # When `train_distribute` is specified, TF_CONFIG is detected and
    # `experimental.remote_cluster` is set use distribute coordinator with
    # STANDALONE_CLIENT mode.
    config = run_config_lib.RunConfig(
        train_distribute=mirrored_strategy.MirroredStrategy(),
        experimental_distribute=DistributeConfig(
            remote_cluster={"chief": ["fake_worker"]}))
    self.assertEqual(config._distribute_coordinator_mode,
                     dc.CoordinatorMode.STANDALONE_CLIENT)


if __name__ == "__main__":
  with test.mock.patch.object(sys, "exit", os._exit):
    test.main()
@@ -134,6 +134,7 @@ py_library(
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/compat",
         "//tensorflow/python/data",
+        "//tensorflow/python/distribute:estimator_training",
         "//tensorflow/python/feature_column:feature_column_py",
         "//tensorflow/python/keras",
         "//tensorflow/python/ops/distributions",
@@ -8,6 +8,25 @@ exports_files(["LICENSE"])

 load("//tensorflow:tensorflow.bzl", "py_test")

+py_library(
+    name = "distribute",
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":distribute_config",
+        ":distribute_coordinator",
+        ":distribute_coordinator_context",
+    ],
+)
+
+py_library(
+    name = "distribute_config",
+    srcs = [
+        "distribute_config.py",
+    ],
+    deps = [],
+)
+
 py_library(
     name = "distribute_coordinator",
     srcs = [
@@ -81,3 +100,17 @@ py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+# Used only by estimator.
+py_library(
+    name = "estimator_training",
+    srcs = [
+        "estimator_training.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":distribute_coordinator",
+        ":distribute_coordinator_context",
+        "//tensorflow/python:training",
+    ],
+)
tensorflow/python/distribute/distribute_config.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A configure tuple for high-level APIs for running distribution strategies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections


class DistributeConfig(
    collections.namedtuple(
        'DistributeConfig',
        ['train_distribute', 'eval_distribute', 'remote_cluster'])):
  """A config tuple for distribution strategies.

  Attributes:
    train_distribute: a `DistributionStrategy` object for training.
    eval_distribute: an optional `DistributionStrategy` object for
      evaluation.
    remote_cluster: a dict, `ClusterDef` or `ClusterSpec` object specifying
      the cluster configurations. If this is given, the `train_and_evaluate`
      method will be running as a standalone client which connects to the
      cluster for training.
  """

  def __new__(cls,
              train_distribute=None,
              eval_distribute=None,
              remote_cluster=None):
    return super(DistributeConfig, cls).__new__(cls, train_distribute,
                                                eval_distribute, remote_cluster)
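As a usage note (a sketch only, not part of the patch): setting `remote_cluster` is what switches `train_and_evaluate` into standalone-client mode, as described in the docstring above. The host addresses below are placeholders.

# Sketch: standalone-client mode via DistributeConfig.remote_cluster.
import tensorflow as tf

config = tf.estimator.RunConfig(
    experimental_distribute=tf.contrib.distribute.DistributeConfig(
        train_distribute=tf.contrib.distribute.MirroredStrategy(num_gpus=1),
        remote_cluster={"worker": ["host1:2222", "host2:2222"]}))
# The coordinator mode becomes STANDALONE_CLIENT; note that combining a
# remote_cluster with a TF_CONFIG environment variable raises an error.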
@@ -311,7 +311,11 @@ def _run_single_worker(worker_fn,
                        worker_barrier=None):
   """Runs a single worker by calling `worker_fn` under context."""
   strategy = copy.deepcopy(strategy)
-  strategy.configure(session_config, cluster_spec, task_type, task_id)
+  # If there is an EVALUATOR task, we run single-machine eval on that task.
+  if task_type == _TaskType.EVALUATOR:
+    strategy.configure(session_config)
+  else:
+    strategy.configure(session_config, cluster_spec, task_type, task_id)
   context = _WorkerContext(
       strategy,
       cluster_spec,
@@ -340,14 +344,14 @@ def _run_std_server(cluster_spec=None,
   return server


-def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
-                              rpc_layer):
+def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+                              cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for between-graph replication."""
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
         target=_run_single_worker,
-        args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+        args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
@@ -378,14 +382,14 @@ def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
     eval_thread.join()


-def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
-                         rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+                         cluster_spec, session_config, rpc_layer):
   """Runs a standalone client for in-graph replication."""
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
         target=_run_single_worker,
-        args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
               session_config),
         kwargs={
             "rpc_layer": rpc_layer,
@@ -408,6 +412,8 @@ def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
 # is the special task when we support cluster_spec propagation.
 def run_distribute_coordinator(worker_fn,
                                strategy,
+                               eval_fn=None,
+                               eval_strategy=None,
                                mode=CoordinatorMode.STANDALONE_CLIENT,
                                cluster_spec=None,
                                task_type=None,
@@ -488,10 +494,12 @@ def run_distribute_coordinator(worker_fn,
   If `cluster_spec` is not given in any format, it becomes local training and
   this coordinator will connect to a local session.

-  For evaluation, if "evaluator" exist in the cluster_spec, a separate thread
-  will be created with its `task_type` set to "evaluator". If "evaluator" is not
-  set in the cluster_spec, it entirely depends on the `worker_fn` for how to do
-  evaluation.
+  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
+  will be created to call `eval_fn` with its `task_type` set to "evaluator". If
+  `eval_fn` is not defined, fall back to `worker_fn`. This implies that
+  evaluation will be done on a single machine if there is an "evaluator" task.
+  If "evaluator" doesn't exit in the cluster_spec, it entirely depends on the
+  `worker_fn` for how to do evaluation.

   Args:
     worker_fn: the function to be called. The function should accept a
@@ -501,6 +509,8 @@ def run_distribute_coordinator(worker_fn,
       run between-graph replicated training or not, whether to run init ops,
       etc. This object will also be configured given `session_config`,
       `cluster_spc`, `task_type` and `task_id`.
+    eval_fn: optional function for "evaluator" task.
+    eval_strategy: optional DistributionStrategy object for "evaluator" task.
     mode: in which mode this distribute coordinator runs.
     cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
       in a cluster. If not set or empty, fall back to local training.
@@ -535,16 +545,22 @@ def run_distribute_coordinator(worker_fn,
     # `mode` is ignored in the local case.
     _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                        rpc_layer)
+    if eval_fn:
+      _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None,
+                         session_config, rpc_layer)
   elif mode == CoordinatorMode.STANDALONE_CLIENT:
+    eval_fn = eval_fn or worker_fn
+    eval_strategy = eval_strategy or strategy
+
     # The client must know the cluster but servers in the cluster don't have to
     # know the client.
     if task_type in [_TaskType.CLIENT, None]:
       if strategy.between_graph:
-        _run_between_graph_client(worker_fn, strategy, cluster_spec,
-                                  session_config, rpc_layer)
+        _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+                                  cluster_spec, session_config, rpc_layer)
       else:
-        _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
-                             rpc_layer)
+        _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+                             cluster_spec, session_config, rpc_layer)
     else:
       # If not a client job, run the standard server.
       server = _run_std_server(
@@ -554,6 +570,9 @@ def run_distribute_coordinator(worker_fn,
     if mode != CoordinatorMode.INDEPENDENT_WORKER:
       raise ValueError("Unexpected coordinator mode: %r" % mode)

+    eval_fn = eval_fn or worker_fn
+    eval_strategy = eval_strategy or strategy
+
     # Every one starts a standard server.
     server = _run_std_server(
         cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
@@ -572,8 +591,8 @@ def run_distribute_coordinator(worker_fn,
       else:
         server.join()
     elif task_type == _TaskType.EVALUATOR:
-      _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
-                         session_config, rpc_layer)
+      _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
+                         task_id, session_config, rpc_layer)
     else:
       if task_type != _TaskType.PS:
         raise ValueError("Unexpected task_type: %r" % task_type)
tensorflow/python/distribute/estimator_training.py (new file, 264 lines)
@@ -0,0 +1,264 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training utilities for Estimator to use Distribute Coordinator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import six

from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib

# pylint: disable=protected-access
CHIEF = dc._TaskType.CHIEF
EVALUATOR = dc._TaskType.EVALUATOR
PS = dc._TaskType.PS
WORKER = dc._TaskType.WORKER

# pylint: enable=protected-access


def _count_ps(cluster_spec):
  """Counts the number of parameter servers in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_ps` does not expect empty cluster_spec.')

  return len(cluster_spec.as_dict().get(PS, []))


def _count_worker(cluster_spec, chief_task_type):
  """Counts the number of workers (including chief) in cluster_spec."""
  if not cluster_spec:
    raise RuntimeError(
        'Internal error: `_count_worker` does not expect empty cluster_spec.')

  return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
      cluster_spec.as_dict().get(chief_task_type, [])))


def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
  """Returns the global id of the given task type in a cluster."""
  if not task_type:
    return 0

  # Sort task names in cluster by "chief"/"master", "evaluator", "worker"
  # and "ps". More details can be found at the documentation of
  # @{tf.estimator.RunConfig.global_id_in_cluster}.
  task_type_ordered_list = []
  if chief_task_type in cluster_spec.jobs:
    task_type_ordered_list = [chief_task_type]
  task_type_ordered_list.extend([
      t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
  ])
  if PS in cluster_spec.jobs:
    task_type_ordered_list.append(PS)

  # Find the right gloabl_id for current task.
  next_global_id = 0
  for t in task_type_ordered_list:
    if t == task_type:
      return next_global_id + task_id
    # `cluster_spec.job_tasks` returns all task addresses of type `t`.
    next_global_id += len(cluster_spec.job_tasks(t))

  # It is unexpected that it passes through all task_types in
  # `task_type_ordered_list`.
  raise RuntimeError('Internal Error: `task_type` ({}) is not in '
                     'cluster_spec ({}).'.format(task_type, cluster_spec))


def _init_run_config_from_worker_context(config, worker_context):
  """Initializes run config from distribute coordinator's worker context."""

  # pylint: disable=protected-access
  config._service = None
  config._cluster_spec = worker_context.cluster_spec
  config._task_type = worker_context.task_type
  config._task_id = worker_context.task_id
  config._evaluation_master = worker_context.master_target
  config._master = worker_context.master_target
  config._is_chief = worker_context.is_chief

  if config._cluster_spec:
    # Distributed mode.
    if config._task_type != EVALUATOR:

      config._num_ps_replicas = _count_ps(config._cluster_spec)
      config._num_worker_replicas = _count_worker(
          config._cluster_spec, chief_task_type=CHIEF)
      config._global_id_in_cluster = _get_global_id(
          config._cluster_spec,
          config._task_type,
          config._task_id,
          chief_task_type=CHIEF)
    else:
      # Evaluator task should not be aware of the other tasks.
      config._cluster_spec = server_lib.ClusterSpec({})
      config._num_ps_replicas = 0
      config._num_worker_replicas = 0
      config._global_id_in_cluster = None  # undefined
  else:
    # Local mode.
    config._global_id_in_cluster = 0
    config._num_ps_replicas = 0
    config._num_worker_replicas = 1


def init_run_config(config, tf_config):
  """Initializes RunConfig for distribution strategies."""
  # pylint: disable=protected-access
  if (config._experimental_distribute and
      config._experimental_distribute.train_distribute):
    if config._train_distribute:
      raise ValueError('Either `train_distribute` or'
                       '`experimental_distribute.train_distribute` can be set.')
    config._train_distribute = config._experimental_distribute.train_distribute

  if (config._experimental_distribute and
      config._experimental_distribute.eval_distribute):
    if config._eval_distribute:
      raise ValueError('Either `eval_distribute` or'
                       '`experimental_distribute.eval_distribute` can be set.')
    config._eval_distribute = config._experimental_distribute.eval_distribute

  cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
  config._init_distributed_setting_from_environment_var({})

  # Use distribute coordinator with STANDALONE_CLIENT mode if
  # `experimental_distribute.remote_cluster` is set.
  if (config._train_distribute and config._experimental_distribute and
      config._experimental_distribute.remote_cluster):
    if tf_config:
      raise ValueError('Cannot set both TF_CONFIG environment variable and '
                       '`experimental_distribute.remote_cluster`')
    config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
    config._cluster_spec = config._experimental_distribute.remote_cluster
    logging.info('RunConfig initialized for Distribute Coordinator with '
                 'STANDALONE_CLIENT mode')
    return

  # Don't use distribute coordinator if it is local training or cluster has a
  # MASTER job or `train_distribute` is not specifed.
  if (not tf_config or 'master' in cluster_spec.jobs or
      not config._train_distribute):
    config._distribute_coordinator_mode = None
    config._init_distributed_setting_from_environment_var(tf_config)
    config._maybe_overwrite_session_config_for_distributed_training()
    logging.info('Not using Distribute Coordinator.')
    return

  # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
  assert tf_config

  # Set the cluster_spec only since the distributed setting will come from
  # distribute coordinator.
  config._cluster_spec = cluster_spec
  config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
  logging.info('RunConfig initialized for Distribute Coordinator with '
               'INDEPENDENT_WORKER mode')


def should_run_distribute_coordinator(config):
  """Checks the config to see whether to run distribute coordinator."""
  # pylint: disable=protected-access
  if (not hasattr(config, '_distribute_coordinator_mode') or
      config._distribute_coordinator_mode is None):
    return False
  if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
      config._distribute_coordinator_mode not in [
          dc.CoordinatorMode.STANDALONE_CLIENT,
          dc.CoordinatorMode.INDEPENDENT_WORKER
      ]):
    logging.warning('Unexpected distribute_coordinator_mode: %r',
                    config._distribute_coordinator_mode)
    return False
  if not config.cluster_spec:
    logging.warning('Running `train_and_evaluate` locally, ignoring '
                    '`experimental_distribute_coordinator_mode`.')
    return False
  return True


def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: A `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks))

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._eval_distribution = strategy

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from TF_CONFIG environment variable if it is
    # INDEPENDENT_WORKER mode.
    cluster_spec = None

  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
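A hedged sketch of the TF_CONFIG environment a single worker process would set so that `init_run_config` above selects INDEPENDENT_WORKER mode (host addresses and task assignment are placeholders):

# Sketch: TF_CONFIG for one worker in INDEPENDENT_WORKER mode.
import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222"],
        "ps": ["host3:2222"],
    },
    "task": {"type": "worker", "index": 0},
})
# With a `train_distribute` strategy on RunConfig and no "master" job in the
# cluster, init_run_config() picks dc.CoordinatorMode.INDEPENDENT_WORKER for
# this process and train_and_evaluate runs the distribute coordinator.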
@@ -26,6 +26,7 @@ import six

 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat_internal
@ -460,7 +461,8 @@ class RunConfig(object):
|
|||||||
train_distribute: An optional instance of
|
train_distribute: An optional instance of
|
||||||
`tf.contrib.distribute.DistributionStrategy`. If specified,
|
`tf.contrib.distribute.DistributionStrategy`. If specified,
|
||||||
then Estimator will distribute the user's model during training,
|
then Estimator will distribute the user's model during training,
|
||||||
according to the policy specified by that strategy.
|
according to the policy specified by that strategy. Setting
|
||||||
|
`experimental_distribute.train_distribute` is preferred.
|
||||||
device_fn: A callable invoked for every `Operation` that takes the
|
device_fn: A callable invoked for every `Operation` that takes the
|
||||||
`Operation` and returns the device string. If `None`, defaults to
|
`Operation` and returns the device string. If `None`, defaults to
|
||||||
the device function returned by `tf.train.replica_device_setter`
|
the device function returned by `tf.train.replica_device_setter`
|
||||||
@ -470,10 +472,13 @@ class RunConfig(object):
|
|||||||
eval_distribute: An optional instance of
|
eval_distribute: An optional instance of
|
||||||
`tf.contrib.distribute.DistributionStrategy`. If specified,
|
`tf.contrib.distribute.DistributionStrategy`. If specified,
|
||||||
then Estimator will distribute the user's model during evaluation,
|
then Estimator will distribute the user's model during evaluation,
|
||||||
according to the policy specified by that strategy.
|
according to the policy specified by that strategy. Setting
|
||||||
|
`experimental_distribute.eval_distribute` is preferred.
|
||||||
experimental_distribute: an optional
|
experimental_distribute: an optional
|
||||||
`tf.contrib.distribute.DistributeConfig` object specifying
|
`tf.contrib.distribute.DistributeConfig` object specifying
|
||||||
DistributionStrategy-related configuration.
|
DistributionStrategy-related configuration. The `train_distribute` and
|
||||||
|
`eval_distribute` can be passed as parameters to `RunConfig` or set in
|
||||||
|
`experimental_distribute` but not both.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
|
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
|
||||||
@ -516,9 +521,12 @@ class RunConfig(object):
|
|||||||
eval_distribute=eval_distribute,
|
eval_distribute=eval_distribute,
|
||||||
experimental_distribute=experimental_distribute)
|
experimental_distribute=experimental_distribute)
|
||||||
|
|
||||||
self._init_distributed_setting_from_environment_var(tf_config)
|
if train_distribute or eval_distribute or experimental_distribute:
|
||||||
|
logging.info('Initializing RunConfig with distribution strategies.')
|
||||||
self._maybe_overwrite_session_config_for_distributed_training()
|
distribute_coordinator_training.init_run_config(self, tf_config)
|
||||||
|
else:
|
||||||
|
self._init_distributed_setting_from_environment_var(tf_config)
|
||||||
|
self._maybe_overwrite_session_config_for_distributed_training()
|
||||||
|
|
||||||
def _maybe_overwrite_session_config_for_distributed_training(self):
|
def _maybe_overwrite_session_config_for_distributed_training(self):
|
||||||
"""Overwrites the session_config for distributed training.
|
"""Overwrites the session_config for distributed training.
|
||||||
|
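As a rough illustration of the RunConfig change above (the strategy choice and cluster addresses are placeholders, not part of this change), a config that routes its distribution settings through `experimental_distribute`, and therefore through `distribute_coordinator_training.init_run_config`, could be built like this:

import tensorflow as tf

# Distribution settings go through `experimental_distribute`; per the docstring
# above, passing `train_distribute` both here and as a direct RunConfig
# argument is not allowed.
distribute_config = tf.contrib.distribute.DistributeConfig(
    train_distribute=tf.contrib.distribute.MirroredStrategy(),
    remote_cluster={'worker': ['host1:2222', 'host2:2222']})
run_config = tf.estimator.RunConfig(
    experimental_distribute=distribute_config)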
@@ -26,6 +26,7 @@ import time
 import six

 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import exporter as exporter_lib
 from tensorflow.python.estimator import run_config as run_config_lib
@@ -274,8 +275,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
   evaluation `input_fn`, steps, etc.

   This utility function provides consistent behavior for both local
-  (non-distributed) and distributed configurations. Currently, the only
-  supported distributed training configuration is between-graph replication.
+  (non-distributed) and distributed configurations. The default distribution
+  configuration is parameter server-based between-graph replication. For other
+  types of distribution configurations such as all-reduce training, please use
+  [DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute). # pylint: disable=line-too-long

   Overfitting: In order to avoid overfitting, it is recommended to set up the
   training `input_fn` to shuffle the training data properly.
@@ -426,6 +429,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
   }'
   ```

+  When `distribute` or `experimental_distribute.train_distribute` and
+  `experimental_distribute.remote_cluster` is set, this method will start a
+  client running on the current host which connects to the `remote_cluster` for
+  training and evaluation.
+
   Args:
     estimator: An `Estimator` instance to train and evaluate.
     train_spec: A `TrainSpec` instance to specify the training specification.
@@ -444,8 +452,16 @@ def train_and_evaluate(estimator, train_spec, eval_spec):

   executor = _TrainingExecutor(
       estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)

   config = estimator.config
+
+  # If `distribute_coordinator_mode` is set and running in distributed
+  # environment, we run `train_and_evaluate` via distribute coordinator.
+  if distribute_coordinator_training.should_run_distribute_coordinator(config):
+    logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
+    distribute_coordinator_training.train_and_evaluate(
+        estimator, train_spec, eval_spec, _TrainingExecutor)
+    return

   if (config.task_type == run_config_lib.TaskType.EVALUATOR and
       config.task_id > 0):
     raise ValueError(
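Putting the pieces together, a hedged end-to-end sketch of how the new branch in `train_and_evaluate` could be exercised; the estimator, feature column, input function, and cluster addresses are illustrative placeholders, and whether the coordinator path is actually taken still depends on `should_run_distribute_coordinator`:

import tensorflow as tf

def input_fn():
  # Tiny synthetic dataset; real code would read from files or a generator.
  features = {'x': [[1.0], [2.0], [3.0], [4.0]]}
  labels = [[2.0], [4.0], [6.0], [8.0]]
  return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)

run_config = tf.estimator.RunConfig(
    experimental_distribute=tf.contrib.distribute.DistributeConfig(
        train_distribute=tf.contrib.distribute.MirroredStrategy(),
        remote_cluster={'worker': ['host1:2222', 'host2:2222']}))

estimator = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('x')],
    config=run_config)

# With distribution settings present, train_and_evaluate hands control to the
# distribute coordinator instead of the plain _TrainingExecutor path.
tf.estimator.train_and_evaluate(
    estimator,
    tf.estimator.TrainSpec(input_fn=input_fn, max_steps=100),
    tf.estimator.EvalSpec(input_fn=input_fn))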