OSS multi-worker callback TF2 tests.

PiperOrigin-RevId: 279830188
Change-Id: I19d7f2f1a610817522563248d0948611f3f49328
Rick Chao 2019-11-11 15:27:02 -08:00 committed by TensorFlower Gardener
parent 4659f8a620
commit 7db012db68
3 changed files with 205 additions and 4 deletions


@@ -353,8 +353,8 @@ cuda_py_test(
)
cuda_py_test(
name = "multi_worker_callback_test",
srcs = ["multi_worker_callback_test.py"],
name = "multi_worker_callback_tf1_test",
srcs = ["multi_worker_callback_tf1_test.py"],
additional_deps = [
":distribute",
"//tensorflow/python:platform",
@@ -376,6 +376,18 @@ cuda_py_test(
],
)
cuda_py_test(
name = "multi_worker_callback_tf2_test",
srcs = ["multi_worker_callback_tf2_test.py"],
additional_deps = [
"//tensorflow/python/distribute:collective_all_reduce_strategy",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/keras/distribute:multi_worker_testing_utils",
],
)
cuda_py_test(
name = "multi_worker_fault_tolerance_test",
srcs = ["multi_worker_fault_tolerance_test.py"],


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests Keras multi worker callbacks."""
"""Tests for Keras callbacks in multi-worker training with TF1."""
from __future__ import absolute_import
from __future__ import division
@@ -24,7 +24,6 @@ import threading
from absl.testing import parameterized
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python import keras
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations
@@ -115,6 +114,11 @@ def generate_callback_test_function(custom_callable):
class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
parameterized.TestCase):
"""KerasMultiWorkerCallbackTest for TF1.
TODO(rchao): Migrate all tests in this class to
`multi_worker_callback_tf2_test`.
"""
# The callables of the actual testing content to be run go below.
@staticmethod


@@ -0,0 +1,185 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras callbacks in multi-worker training with TF2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_process_runner_util
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.keras import callbacks
from tensorflow.python.keras.distribute import multi_worker_testing_utils
from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import test


def _model_setup(test_obj, file_format):
"""Set up a MNIST Keras model for testing purposes.
This function builds a MNIST Keras model and returns relevant information
for testing.
Args:
test_obj: The `TestCase` testing object.
file_format: File format for checkpoints. 'tf' or 'h5'.
Returns:
A tuple of (model, saving_filepath, train_ds, steps) where train_ds is
the training dataset.
"""
batch_size = 64
steps = 2
with collective_strategy.CollectiveAllReduceStrategy().scope():
# TODO(b/142509827): In rare cases this errors out at C++ level with the
# "Connect failed" error message.
train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
batch_size, steps)
model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
# Pass saving_filepath from the parent thread to ensure every worker has the
# same filepath to save.
saving_filepath = os.path.join(test_obj.get_temp_dir(),
'checkpoint.' + file_format)
return model, saving_filepath, train_ds, steps
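

# Not part of the commit: an illustration of the cluster configuration that
# `CollectiveAllReduceStrategy` resolves inside each worker subprocess,
# normally serialized into the `TF_CONFIG` environment variable.
# `multi_process_runner` is expected to derive it from the `cluster_spec`
# passed to `run()`; the addresses and the task index below are placeholders.
_EXAMPLE_TF_CONFIG = {
    'cluster': {'worker': ['localhost:12345', 'localhost:23456']},
    'task': {'type': 'worker', 'index': 0},
}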


class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):

@combinations.generate(
combinations.combine(
mode=['eager'],
file_format=['h5', 'tf'],
save_weights_only=[True, False]))
def test_model_checkpoint_saves_on_chief_but_not_otherwise(
self, file_format, mode, save_weights_only):
def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
test_obj, file_format):
model, saving_filepath, train_ds, steps = _model_setup(
test_obj, file_format)
num_epoch = 2
extension = os.path.splitext(saving_filepath)[1]
      # Incorporate task type and index into saving_filepath to ensure every
      # worker has a unique path. Note that in a normal use case the
      # saving_filepath would be the same for all workers, but we use
      # different ones here just to test that the chief saves the checkpoint
      # while non-chief workers do not.
saving_filepath = os.path.join(
test_obj.get_temp_dir(), 'checkpoint_%s_%d%s' %
(test_base.get_task_type(), test_base.get_task_index(), extension))
# The saving_filepath shouldn't exist at the beginning (as it's unique).
test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))
model.fit(
x=train_ds,
epochs=num_epoch,
steps_per_epoch=steps,
callbacks=[
callbacks.ModelCheckpoint(
filepath=saving_filepath, save_weights_only=save_weights_only)
])
# If it's chief, the model should be saved; if not, the model shouldn't.
test_obj.assertEqual(
training_state.checkpoint_exists(saving_filepath),
test_base.is_chief())
# TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
with multi_process_runner_util.try_run_and_except_connection_error(self):
multi_process_runner.run(
proc_model_checkpoint_saves_on_chief_but_not_otherwise,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self, file_format))
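
  # Not part of the commit: a sketch of how the checkpoint written by the
  # chief in the test above could be restored. With save_weights_only=True,
  # `load_weights` needs a model with the same architecture; with
  # save_weights_only=False the full model would instead be reloaded via
  # `tf.keras.models.load_model`. This helper is illustrative only and is
  # not called by the tests.
  @staticmethod
  def _restore_weights_sketch(model, saving_filepath):
    model.load_weights(saving_filepath)
    return model
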
@combinations.generate(combinations.combine(mode=['eager']))
def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode):
def proc_tensorboard_saves_on_chief_but_not_otherwise(test_obj):
model, _, train_ds, steps = _model_setup(test_obj, file_format='')
num_epoch = 2
      # Incorporate task type and index into saving_filepath to ensure every
      # worker has a unique path. Note that in a normal use case the
      # saving_filepath would be the same for all workers, but we use
      # different ones here just to test that the chief saves summaries
      # while non-chief workers do not.
saving_filepath = os.path.join(
test_obj.get_temp_dir(), 'logfile_%s_%d' %
(test_base.get_task_type(), test_base.get_task_index()))
# The saving_filepath shouldn't exist at the beginning (as it's unique).
test_obj.assertFalse(file_io.file_exists(saving_filepath))
model.fit(
x=train_ds,
epochs=num_epoch,
steps_per_epoch=steps,
callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
# If it's chief, the summaries should be saved in the filepath; if not,
# the directory should be empty (although created). Using
# `file_io.list_directory()` since the directory may be created at this
# point.
test_obj.assertEqual(
bool(file_io.list_directory(saving_filepath)), test_base.is_chief())
# TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
with multi_process_runner_util.try_run_and_except_connection_error(self):
multi_process_runner.run(
proc_tensorboard_saves_on_chief_but_not_otherwise,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self,))
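
  # Not part of the commit: the per-worker temporary directory layout that
  # the next test exercises. Judging from the 'workertemp_1' name used below,
  # non-chief worker N is assumed to write under '<log_dir>/workertemp_<N>';
  # this helper only illustrates the path convention and is not a TensorFlow
  # API.
  @staticmethod
  def _worker_temp_dir_sketch(log_dir, task_id):
    return os.path.join(log_dir, 'workertemp_%d' % task_id)
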
@combinations.generate(combinations.combine(mode=['eager']))
def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode):
def proc_tensorboard_can_still_save_to_temp_even_if_it_exists(test_obj):
model, _, train_ds, steps = _model_setup(test_obj, file_format='')
num_epoch = 2
saving_filepath = os.path.join(test_obj.get_temp_dir(),
'logfile_%s' % (test_base.get_task_type()))
saving_filepath_for_temp = os.path.join(saving_filepath, 'workertemp_1')
os.mkdir(saving_filepath)
os.mkdir(saving_filepath_for_temp)
      # Verifies that even if `saving_filepath_for_temp` exists, TensorBoard
      # can still save to a temporary directory.
test_obj.assertTrue(file_io.file_exists(saving_filepath_for_temp))
model.fit(
x=train_ds,
epochs=num_epoch,
steps_per_epoch=steps,
callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
# TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
with multi_process_runner_util.try_run_and_except_connection_error(self):
multi_process_runner.run(
proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
cluster_spec=test_base.create_cluster_spec(num_workers=2),
args=(self,))
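

# Not part of the commit: a minimal illustration of the chief-vs-non-chief
# distinction the assertions above rely on. With
# `create_cluster_spec(num_workers=2)` there is no dedicated 'chief' task, so
# worker 0 is assumed to act as the chief (the usual
# MultiWorkerMirroredStrategy convention); this sketch is not the actual
# `test_base.is_chief` implementation.
def _is_chief_sketch(task_type, task_id):
  # A dedicated 'chief' task is always chief; otherwise worker 0 acts as one.
  return task_type == 'chief' or (task_type == 'worker' and task_id == 0)

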
if __name__ == '__main__':
multi_process_runner.test_main()