Multi-worker tutorial: add the MWMS + CTL (MultiWorkerMirroredStrategy with a custom training loop) workflow, which the upcoming tutorial will cover, to multi_worker_tutorial_test.

Also fix the flakiness of the test and re-enable it in TAP.

PiperOrigin-RevId: 339966024
Change-Id: Icb866f8a7054fa88f2e474c02960982a57c542b3
parent 3d1f1b062d
commit 64edb2fb2d
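For orientation before the diff, here is a minimal sketch of the MWMS + custom training loop (CTL) pattern that the new testMwmsWithCtl exercises. It is a hedged illustration, not the tutorial code itself: the random in-memory dataset, layer sizes, and fixed step count are placeholders, and each worker process would run it with its own TF_CONFIG environment set.

import tensorflow as tf

PER_WORKER_BATCH_SIZE = 64
NUM_WORKERS = 2
GLOBAL_BATCH_SIZE = PER_WORKER_BATCH_SIZE * NUM_WORKERS


def dataset_fn(input_context):
  # Placeholder dataset standing in for the tutorial's MNIST pipeline.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  features = tf.random.uniform([256, 28, 28])
  labels = tf.random.uniform([256], maxval=10, dtype=tf.int64)
  ds = tf.data.Dataset.from_tensor_slices((features, labels))
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.repeat().batch(batch_size)


strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
  # Model and optimizer must be created under the strategy scope.
  model = tf.keras.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28)),
      tf.keras.layers.Dense(10)])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

multi_worker_dataset = strategy.distribute_datasets_from_function(dataset_fn)


@tf.function
def train_step(iterator):
  def step_fn(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
      logits = model(x, training=True)
      per_example_loss = tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True,
          reduction=tf.keras.losses.Reduction.NONE)(y, logits)
      # Scale by the global batch size so gradients aggregate correctly
      # across workers.
      loss = tf.nn.compute_average_loss(
          per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


iterator = iter(multi_worker_dataset)
for _ in range(10):  # A handful of steps; the real test below loops per epoch.
  train_step(iterator)

The test added in this commit additionally tracks epoch and step counters in a tf.train.Checkpoint so that an interrupted worker can resume training from the last saved state.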
@@ -789,29 +789,6 @@ py_library(
    ],
)

py_test(
    name = "multi_worker_tutorial_test",
    srcs = ["multi_worker_tutorial_test.py"],
    python_version = "PY3",
    shard_count = 5,
    tags = [
        "noasan",  # TODO(b/156029134)
        "nomsan",  # TODO(b/156029134)
        "notap",  # TODO(b/165865820): restore when not flaky
        "notsan",  # TODO(b/156029134)
    ],
    deps = [
        "//tensorflow/python:platform",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute:collective_all_reduce_strategy",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:multi_process_runner",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras/optimizer_v2",
    ],
)

distribute_py_test(
    name = "saved_model_save_load_test",
    size = "medium",
@@ -1,228 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for multi-worker training tutorial."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import os
import re
import zipfile
from absl import logging
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking_util
from tensorflow.python.util import nest


class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
  """Test multi-worker training flow demo'ed in go/multi-worker-with-keras."""

  @contextlib.contextmanager
  def skip_fetch_failure_exception(self):
    try:
      yield
    except zipfile.BadZipfile as e:
      self.skipTest('Data loading error: Bad magic number for file header.')
    except Exception as e:  # pylint: disable=broad-except
      if 'URL fetch failure' in str(e):
        self.skipTest('URL fetch error not considered failure of the test.')
      else:
        raise

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'],
          shard_policy=[None] + list(distribute_options.AutoShardPolicy)))
  def testMultiWorkerTutorial(self, mode, shard_policy):
    """Test multi-worker training flow demo'ed in go/multi-worker-with-keras.

    This test should be kept in sync with the code samples in
    go/multi-worker-with-keras.

    Args:
      mode: Runtime mode.
      shard_policy: None or any of tf.data.experimental.AutoShardPolicy for
        testing.
    """
    if shard_policy is distribute_options.AutoShardPolicy.FILE:
      self.skipTest('TensorSliceDataset is not shardable with FILE policy.')

    def mnist_dataset(batch_size):
      with self.skip_fetch_failure_exception():
        (x_train, y_train), _ = mnist.load_data()
      # The `x` arrays are in uint8 and have values in the range [0, 255].
      # We need to convert them to float32 with values in the range [0, 1]
      x_train = x_train / np.float32(255)
      y_train = y_train.astype(np.int64)
      train_dataset = dataset_ops.DatasetV2.from_tensor_slices(
          (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
      return train_dataset

    def build_and_compile_cnn_model():
      model = keras.Sequential([
          keras.layers.Input(shape=(28, 28)),
          keras.layers.Reshape(target_shape=(28, 28, 1)),
          keras.layers.Conv2D(32, 3, activation='relu'),
          keras.layers.Flatten(),
          keras.layers.Dense(128, activation='relu'),
          keras.layers.Dense(10)
      ])
      model.compile(
          loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
          optimizer=gradient_descent.SGD(learning_rate=0.001),
          metrics=['accuracy'])
      return model

    per_worker_batch_size = 64

    single_worker_dataset = mnist_dataset(per_worker_batch_size)
    single_worker_model = build_and_compile_cnn_model()
    single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

    num_workers = 4

    def fn(model_path, checkpoint_dir):
      global_batch_size = per_worker_batch_size * num_workers
      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
      with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model()

      callbacks = [
          keras.callbacks.ModelCheckpoint(
              filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
      ]

      multi_worker_dataset = mnist_dataset(global_batch_size)
      if shard_policy:
        options = dataset_ops.Options()
        options.experimental_distribute.auto_shard_policy = shard_policy
        multi_worker_dataset = multi_worker_dataset.with_options(options)

      multi_worker_model.fit(
          multi_worker_dataset,
          epochs=2,
          steps_per_epoch=20,
          callbacks=callbacks)

      def _is_chief(task_type, task_id):
        return task_type is None or task_type == 'chief' or (
            task_type == 'worker' and task_id == 0)

      def _get_temp_dir(dirpath, task_id):
        base_dirpath = 'workertemp_' + str(task_id)
        temp_dir = os.path.join(dirpath, base_dirpath)
        file_io.recursive_create_dir_v2(temp_dir)
        return temp_dir

      def write_filepath(filepath, task_type, task_id):
        dirpath = os.path.dirname(filepath)
        base = os.path.basename(filepath)
        if not _is_chief(task_type, task_id):
          dirpath = _get_temp_dir(dirpath, task_id)
        return os.path.join(dirpath, base)

      task_type, task_id = (strategy.cluster_resolver.task_type,
                            strategy.cluster_resolver.task_id)
      write_model_path = write_filepath(model_path, task_type, task_id)

      multi_worker_model.save(write_model_path)
      if not _is_chief(task_type, task_id):
        file_io.delete_recursively_v2(os.path.dirname(write_model_path))

      # Make sure chief finishes saving before non-chief's assertions.
      multi_process_runner.get_barrier().wait()

      if not file_io.file_exists_v2(model_path):
        raise RuntimeError()
      if file_io.file_exists_v2(write_model_path) != _is_chief(
          task_type, task_id):
        raise RuntimeError()

      loaded_model = keras.saving.save.load_model(model_path)
      loaded_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

      checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
      write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
      checkpoint_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

      checkpoint_manager.save()
      if not _is_chief(task_type, task_id):
        file_io.delete_recursively_v2(write_checkpoint_dir)

      # Make sure chief finishes saving before non-chief's assertions.
      multi_process_runner.get_barrier().wait()

      if not file_io.file_exists_v2(checkpoint_dir):
        raise RuntimeError()
      if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
          task_type, task_id):
        raise RuntimeError()

      latest_checkpoint = checkpoint_management.latest_checkpoint(
          checkpoint_dir)
      checkpoint.restore(latest_checkpoint)
      multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)

      logging.info('testMultiWorkerTutorial successfully ends')

    model_path = os.path.join(self.get_temp_dir(), 'model.tf')
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
    try:
      mpr_result = multi_process_runner.run(
          fn,
          multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
          args=(model_path, checkpoint_dir),
          return_output=True)
    except errors_impl.UnavailableError as e:
      self.skipTest('Skipping error: {}: {}'.format(type(e), str(e)))

    self.assertTrue(
        any([
            'testMultiWorkerTutorial successfully ends' in msg
            for msg in mpr_result.stdout
        ]))

    def extract_accuracy(worker_id, input_string):
      match = re.match(
          r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
          input_string)
      return None if match is None else float(match.group(1))

    for worker_id in range(num_workers):
      accu_result = nest.map_structure(
          lambda x: extract_accuracy(worker_id, x),  # pylint: disable=cell-var-from-loop
          mpr_result.stdout)
      self.assertTrue(
          any(accu_result), 'Every worker is supposed to have accuracy result.')


if __name__ == '__main__':
  multi_process_runner.test_main()
@@ -103,3 +103,18 @@ tpu_py_test(
        "//tensorflow/python:extra_py_tests_deps",
    ],
)

tf_py_test(
    name = "multi_worker_tutorial_test",
    srcs = ["multi_worker_tutorial_test.py"],
    python_version = "PY3",
    shard_count = 3,
    tags = [
        "noasan",  # TODO(b/156029134)
        "nomsan",  # TODO(b/156029134)
        "notsan",  # TODO(b/156029134)
    ],
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)
@@ -0,0 +1,346 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for multi-worker training tutorial."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import os
import re
import unittest
import uuid
import zipfile
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

PER_WORKER_BATCH_SIZE = 64
NUM_WORKERS = 2
NUM_EPOCHS = 2
NUM_STEPS_PER_EPOCH = 50


def _is_chief(task_type, task_id):
  return task_type is None or task_type == 'chief' or (task_type == 'worker' and
                                                       task_id == 0)


def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir


def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)


class MultiWorkerTutorialTest(parameterized.TestCase, tf.test.TestCase):
  """Test of multi-worker training flow in tutorials on tensorflow.org.

  Please see below test method docs for what actual tutorial is being covered.
  """

  # TODO(rchao): Add a test to demonstrate gather with MWMS.

  @contextlib.contextmanager
  def skip_fetch_failure_exception(self):
    try:
      yield
    except zipfile.BadZipfile as e:
      # There can be a race when multiple processes are downloading the data.
      # Skip the test if that results in loading errors.
      self.skipTest('Data loading error: Bad magic number for file header.')
    except Exception as e:  # pylint: disable=broad-except
      if 'URL fetch failure' in str(e):
        self.skipTest('URL fetch error not considered failure of the test.')
      else:
        raise

  def mnist_dataset(self):
    path_to_use = 'mnist_{}.npz'.format(str(uuid.uuid4()))
    with self.skip_fetch_failure_exception():
      (x_train,
       y_train), _ = tf.keras.datasets.mnist.load_data(path=path_to_use)
    # The `x` arrays are in uint8 and have values in the range [0, 255].
    # We need to convert them to float32 with values in the range [0, 1]
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(60000)
    return train_dataset

  def dataset_fn(self, global_batch_size, input_context):
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = self.mnist_dataset()
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    dataset = dataset.batch(batch_size)
    return dataset

  def build_cnn_model(self):
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(28, 28)),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

  def build_and_compile_cnn_model(self):
    model = self.build_cnn_model()
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model

  @tf.__internal__.test.combinations.generate(
      tf.__internal__.test.combinations.combine(
          mode=['eager'], tf_api_version=2))
  def testSingleWorkerModelFit(self):
    single_worker_dataset = self.mnist_dataset().batch(
        PER_WORKER_BATCH_SIZE)
    single_worker_model = self.build_and_compile_cnn_model()
    single_worker_model.fit(single_worker_dataset, epochs=NUM_EPOCHS)

  @tf.__internal__.test.combinations.generate(
      tf.__internal__.test.combinations.combine(
          mode=['eager'], tf_api_version=2))
  def testMwmsWithModelFit(self, mode):
    """Test multi-worker training flow demo'ed in go/multi-worker-with-keras.

    This test should be kept in sync with the code samples in
    go/multi-worker-with-keras.

    Args:
      mode: Runtime mode.
    """
    def fn(model_path, checkpoint_dir):
      global_batch_size = PER_WORKER_BATCH_SIZE * NUM_WORKERS
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      with strategy.scope():
        multi_worker_model = self.build_and_compile_cnn_model()

      callbacks = [
          tf.keras.callbacks.ModelCheckpoint(
              filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
      ]

      multi_worker_dataset = strategy.distribute_datasets_from_function(
          lambda input_context: self.dataset_fn(global_batch_size, input_context
                                               ))

      multi_worker_model.fit(
          multi_worker_dataset,
          epochs=NUM_EPOCHS,
          steps_per_epoch=50,
          callbacks=callbacks)

      task_type, task_id = (strategy.cluster_resolver.task_type,
                            strategy.cluster_resolver.task_id)
      write_model_path = write_filepath(model_path, task_type, task_id)

      multi_worker_model.save(write_model_path)
      if not _is_chief(task_type, task_id):
        tf.io.gfile.rmtree(os.path.dirname(write_model_path))

      # Make sure chief finishes saving before non-chief's assertions.
      tf.__internal__.distribute.multi_process_runner.get_barrier().wait()

      if not tf.io.gfile.exists(model_path):
        raise RuntimeError()
      if tf.io.gfile.exists(write_model_path) != _is_chief(task_type, task_id):
        raise RuntimeError()

      with strategy.scope():
        loaded_model = tf.keras.models.load_model(model_path)
      loaded_model.fit(multi_worker_dataset, epochs=1, steps_per_epoch=1)

      checkpoint = tf.train.Checkpoint(model=multi_worker_model)
      write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
      checkpoint_manager = tf.train.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

      checkpoint_manager.save()
      if not _is_chief(task_type, task_id):
        tf.io.gfile.rmtree(write_checkpoint_dir)

      # Make sure chief finishes saving before non-chief's assertions.
      tf.__internal__.distribute.multi_process_runner.get_barrier().wait()

      if not tf.io.gfile.exists(checkpoint_dir):
        raise RuntimeError()
      if tf.io.gfile.exists(write_checkpoint_dir) != _is_chief(
          task_type, task_id):
        raise RuntimeError()

      latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
      checkpoint.restore(latest_checkpoint)
      multi_worker_model.fit(multi_worker_dataset, epochs=1, steps_per_epoch=1)

      logging.info('testMwmsWithModelFit successfully ends')

    model_path = os.path.join(self.get_temp_dir(), 'model.tf')
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
    try:
      mpr_result = tf.__internal__.distribute.multi_process_runner.run(
          fn,
          tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
              num_workers=NUM_WORKERS),
          args=(model_path, checkpoint_dir),
          return_output=True)
    except tf.errors.UnavailableError:
      self.skipTest('Skipping rare disconnection among the workers.')

    self.assertTrue(
        any([
            'testMwmsWithModelFit successfully ends' in msg
            for msg in mpr_result.stdout
        ]))

    def extract_accuracy(worker_id, input_string):
      match = re.match(
          r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
          input_string)
      return None if match is None else float(match.group(1))

    for worker_id in range(NUM_WORKERS):
      accu_result = tf.nest.map_structure(
          lambda x: extract_accuracy(worker_id, x),  # pylint: disable=cell-var-from-loop
          mpr_result.stdout)
      self.assertTrue(
          any(accu_result), 'Every worker is supposed to have accuracy result.')

  @tf.__internal__.test.combinations.generate(
      tf.__internal__.test.combinations.combine(
          mode=['eager'], tf_api_version=2))
  def testMwmsWithCtl(self, mode):
    """Test multi-worker CTL training flow demo'ed in a to-be-added tutorial."""

    def proc_func(checkpoint_dir):
      global_batch_size = PER_WORKER_BATCH_SIZE * NUM_WORKERS
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      try:

        with strategy.scope():
          multi_worker_model = self.build_cnn_model()

        multi_worker_dataset = strategy.distribute_datasets_from_function(
            lambda input_context: self.dataset_fn(global_batch_size,  # pylint: disable=g-long-lambda
                                                  input_context))
        optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
        train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='train_accuracy')

        @tf.function
        def train_step(iterator):
          """Training step function."""

          def step_fn(inputs):
            """Per-Replica step function."""
            x, y = inputs
            with tf.GradientTape() as tape:
              predictions = multi_worker_model(x, training=True)
              per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(
                  from_logits=True,
                  reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
              loss = tf.nn.compute_average_loss(
                  per_batch_loss, global_batch_size=global_batch_size)

            grads = tape.gradient(loss, multi_worker_model.trainable_variables)
            optimizer.apply_gradients(
                zip(grads, multi_worker_model.trainable_variables))
            train_accuracy.update_state(y, predictions)

            return loss

          per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
          return strategy.reduce(
              tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

        epoch = tf.Variable(
            initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
        step_in_epoch = tf.Variable(
            initial_value=tf.constant(0, dtype=tf.dtypes.int64),
            name='step_in_epoch')

        task_type, task_id = (strategy.cluster_resolver.task_type,
                              strategy.cluster_resolver.task_id)
        checkpoint = tf.train.Checkpoint(
            model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)
        write_checkpoint_dir = write_filepath(checkpoint_dir, task_type,
                                              task_id)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint:
          checkpoint.restore(latest_checkpoint)

        while epoch.numpy() < NUM_EPOCHS:
          iterator = iter(multi_worker_dataset)
          total_loss = 0.0
          num_batches = 0

          while step_in_epoch.numpy() < NUM_STEPS_PER_EPOCH:
            total_loss += train_step(iterator)
            num_batches += 1
            step_in_epoch.assign_add(1)

          train_loss = total_loss / num_batches
          logging.info('Epoch: %d, accuracy: %f, train_loss: %f.',
                       epoch.numpy(), train_accuracy.result(), train_loss)

          train_accuracy.reset_states()

          checkpoint_manager.save()
          if not _is_chief(task_type, task_id):
            tf.io.gfile.rmtree(write_checkpoint_dir)

          epoch.assign_add(1)
          step_in_epoch.assign(0)

      except tf.errors.UnavailableError as e:
        logging.info('UnavailableError occurred: %r', e)
        raise unittest.SkipTest('Skipping test due to UnavailableError')

      logging.info('testMwmsWithCtl successfully ends')

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')

    mpr_result = tf.__internal__.distribute.multi_process_runner.run(
        proc_func,
        tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
            num_workers=NUM_WORKERS),
        return_output=True,
        args=(checkpoint_dir,))

    self.assertTrue(
        any([
            'testMwmsWithCtl successfully ends' in msg
            for msg in mpr_result.stdout
        ]))


if __name__ == '__main__':
  tf.__internal__.distribute.multi_process_runner.test_main()