OSS Multiworker callback tf2 tests.
PiperOrigin-RevId: 279830188 Change-Id: I19d7f2f1a610817522563248d0948611f3f49328
This commit is contained in:
parent
4659f8a620
commit
7db012db68
@ -353,8 +353,8 @@ cuda_py_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "multi_worker_callback_test",
|
name = "multi_worker_callback_tf1_test",
|
||||||
srcs = ["multi_worker_callback_test.py"],
|
srcs = ["multi_worker_callback_tf1_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
":distribute",
|
":distribute",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
@ -376,6 +376,18 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "multi_worker_callback_tf2_test",
|
||||||
|
srcs = ["multi_worker_callback_tf2_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
"//tensorflow/python/distribute:collective_all_reduce_strategy",
|
||||||
|
"//tensorflow/python/distribute:combinations",
|
||||||
|
"//tensorflow/python/distribute:multi_process_runner",
|
||||||
|
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||||
|
"//tensorflow/python/keras/distribute:multi_worker_testing_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "multi_worker_fault_tolerance_test",
|
name = "multi_worker_fault_tolerance_test",
|
||||||
srcs = ["multi_worker_fault_tolerance_test.py"],
|
srcs = ["multi_worker_fault_tolerance_test.py"],
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests Keras multi worker callbacks."""
|
"""Tests for Keras callbacks in multi-worker training with TF1."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -24,7 +24,6 @@ import threading
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
# pylint: disable=g-direct-tensorflow-import
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
|
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
@ -115,6 +114,11 @@ def generate_callback_test_function(custom_callable):
|
|||||||
|
|
||||||
class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
|
class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
|
||||||
parameterized.TestCase):
|
parameterized.TestCase):
|
||||||
|
"""KerasMultiWorkerCallbackTest for TF1.
|
||||||
|
|
||||||
|
TODO(rchao): Migrate all tests in this class to
|
||||||
|
`multi_worker_callback_tf2_test`.
|
||||||
|
"""
|
||||||
|
|
||||||
# The callables of the actual testing content to be run go below.
|
# The callables of the actual testing content to be run go below.
|
||||||
@staticmethod
|
@staticmethod
|
@ -0,0 +1,185 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for Keras callbacks in multi-worker training with TF2."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
|
||||||
|
from tensorflow.python.distribute import combinations
|
||||||
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
|
from tensorflow.python.distribute import multi_process_runner_util
|
||||||
|
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||||
|
from tensorflow.python.keras import callbacks
|
||||||
|
from tensorflow.python.keras.distribute import multi_worker_testing_utils
|
||||||
|
from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
|
||||||
|
from tensorflow.python.lib.io import file_io
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
def _model_setup(test_obj, file_format):
  """Build the MNIST Keras model and companion objects used by the tests.

  The synthetic dataset and the model are both created inside a
  `CollectiveAllReduceStrategy` scope. The checkpoint path is derived from
  the temp dir of `test_obj` and the requested file format.

  Args:
    test_obj: The `TestCase` testing object; its `get_temp_dir()` supplies the
      directory under which the checkpoint path is placed.
    file_format: File format for checkpoints. 'tf' or 'h5'. It is appended to
      'checkpoint.' to form the file name.

  Returns:
    A tuple of (model, saving_filepath, train_ds, steps) where train_ds is
    the training dataset.
  """
  steps = 2
  batch_size = 64
  input_shape = (28, 28, 1)

  strategy = collective_strategy.CollectiveAllReduceStrategy()
  with strategy.scope():
    # TODO(b/142509827): In rare cases this errors out at C++ level with the
    # "Connect failed" error message.
    train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
        batch_size, steps)
    model = multi_worker_testing_utils.get_mnist_model(input_shape)

  # The filepath is computed here, from the caller-provided test object, so
  # that every worker is handed the same filepath to save to.
  temp_dir = test_obj.get_temp_dir()
  saving_filepath = os.path.join(temp_dir, 'checkpoint.' + file_format)
  return model, saving_filepath, train_ds, steps
|
||||||
|
|
||||||
|
|
||||||
|
class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
  """Tests Keras callbacks with each worker running via multi_process_runner.

  Every test defines a per-task function and passes it to
  `multi_process_runner.run` together with a two-worker cluster spec.
  """

  @combinations.generate(
      combinations.combine(
          mode=['eager'],
          file_format=['h5', 'tf'],
          save_weights_only=[True, False]))
  def test_model_checkpoint_saves_on_chief_but_not_otherwise(
      self, file_format, mode, save_weights_only):
    """Checks that `ModelCheckpoint` writes a checkpoint only on the chief."""

    def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
        test_obj, file_format):
      """Per-task body: trains with `ModelCheckpoint` and checks the output."""
      model, shared_filepath, train_ds, steps = _model_setup(
          test_obj, file_format)
      epochs = 2
      _, extension = os.path.splitext(shared_filepath)

      # Build a path from this task's type and index so that every worker has
      # its own path. In normal use all workers share one saving_filepath;
      # distinct paths are used here only to verify that the chief saves a
      # checkpoint while non-chief workers do not.
      temp_dir = test_obj.get_temp_dir()
      checkpoint_name = 'checkpoint_%s_%d%s' % (
          test_base.get_task_type(), test_base.get_task_index(), extension)
      saving_filepath = os.path.join(temp_dir, checkpoint_name)

      # Nothing should exist at the (unique) path before training starts.
      test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))

      checkpoint_callback = callbacks.ModelCheckpoint(
          filepath=saving_filepath, save_weights_only=save_weights_only)
      model.fit(
          x=train_ds,
          epochs=epochs,
          steps_per_epoch=steps,
          callbacks=[checkpoint_callback])

      # The checkpoint must exist exactly when this task is the chief.
      test_obj.assertEqual(
          training_state.checkpoint_exists(saving_filepath),
          test_base.is_chief())

    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
    with multi_process_runner_util.try_run_and_except_connection_error(self):
      cluster_spec = test_base.create_cluster_spec(num_workers=2)
      multi_process_runner.run(
          proc_model_checkpoint_saves_on_chief_but_not_otherwise,
          cluster_spec=cluster_spec,
          args=(self, file_format))

  @combinations.generate(combinations.combine(mode=['eager']))
  def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode):
    """Checks that `TensorBoard` writes summaries only on the chief."""

    def proc_tensorboard_saves_on_chief_but_not_otherwise(test_obj):
      """Per-task body: trains with `TensorBoard` and inspects the log dir."""
      model, _, train_ds, steps = _model_setup(test_obj, file_format='')
      epochs = 2

      # Build a log directory from this task's type and index so that every
      # worker has its own path. In normal use all workers share one
      # saving_filepath; distinct paths are used here only to verify that the
      # chief saves summaries while non-chief workers do not.
      temp_dir = test_obj.get_temp_dir()
      log_dir_name = 'logfile_%s_%d' % (
          test_base.get_task_type(), test_base.get_task_index())
      saving_filepath = os.path.join(temp_dir, log_dir_name)

      # Nothing should exist at the (unique) path before training starts.
      test_obj.assertFalse(file_io.file_exists(saving_filepath))

      tensorboard_callback = callbacks.TensorBoard(log_dir=saving_filepath)
      model.fit(
          x=train_ds,
          epochs=epochs,
          steps_per_epoch=steps,
          callbacks=[tensorboard_callback])

      # On the chief the directory should contain the summaries; on other
      # workers it should be empty (although created). The contents are
      # listed with `file_io.list_directory()` because the directory may
      # already have been created at this point.
      test_obj.assertEqual(
          bool(file_io.list_directory(saving_filepath)), test_base.is_chief())

    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
    with multi_process_runner_util.try_run_and_except_connection_error(self):
      cluster_spec = test_base.create_cluster_spec(num_workers=2)
      multi_process_runner.run(
          proc_tensorboard_saves_on_chief_but_not_otherwise,
          cluster_spec=cluster_spec,
          args=(self,))

  @combinations.generate(combinations.combine(mode=['eager']))
  def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode):
    """Checks that training with `TensorBoard` runs when 'workertemp_1' exists."""

    def proc_tensorboard_can_still_save_to_temp_even_if_it_exists(test_obj):
      """Per-task body: pre-creates the temp dir, then trains with `TensorBoard`."""
      model, _, train_ds, steps = _model_setup(test_obj, file_format='')
      epochs = 2

      temp_dir = test_obj.get_temp_dir()
      saving_filepath = os.path.join(
          temp_dir, 'logfile_%s' % (test_base.get_task_type()))
      saving_filepath_for_temp = os.path.join(saving_filepath, 'workertemp_1')

      # Create the log directory and the 'workertemp_1' directory inside it
      # before training begins.
      os.mkdir(saving_filepath)
      os.mkdir(saving_filepath_for_temp)

      # Verifies that even if `saving_filepath_for_temp` exists, tensorboard
      # can still save to temporary directory.
      test_obj.assertTrue(file_io.file_exists(saving_filepath_for_temp))

      tensorboard_callback = callbacks.TensorBoard(log_dir=saving_filepath)
      model.fit(
          x=train_ds,
          epochs=epochs,
          steps_per_epoch=steps,
          callbacks=[tensorboard_callback])

    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
    with multi_process_runner_util.try_run_and_except_connection_error(self):
      cluster_spec = test_base.create_cluster_spec(num_workers=2)
      multi_process_runner.run(
          proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
          cluster_spec=cluster_spec,
          args=(self,))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
multi_process_runner.test_main()
|
Loading…
x
Reference in New Issue
Block a user