OSS multi-worker callback TF2 tests.
PiperOrigin-RevId: 279830188
Change-Id: I19d7f2f1a610817522563248d0948611f3f49328
commit 7db012db68
parent 4659f8a620
tensorflow/python/keras/distribute/BUILD
@@ -353,8 +353,8 @@ cuda_py_test(
 )
 
 cuda_py_test(
-    name = "multi_worker_callback_test",
-    srcs = ["multi_worker_callback_test.py"],
+    name = "multi_worker_callback_tf1_test",
+    srcs = ["multi_worker_callback_tf1_test.py"],
     additional_deps = [
         ":distribute",
         "//tensorflow/python:platform",
@@ -376,6 +376,18 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "multi_worker_callback_tf2_test",
+    srcs = ["multi_worker_callback_tf2_test.py"],
+    additional_deps = [
+        "//tensorflow/python/distribute:collective_all_reduce_strategy",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:multi_process_runner",
+        "//tensorflow/python/distribute:multi_worker_test_base",
+        "//tensorflow/python/keras/distribute:multi_worker_testing_utils",
+    ],
+)
+
 cuda_py_test(
     name = "multi_worker_fault_tolerance_test",
     srcs = ["multi_worker_fault_tolerance_test.py"],
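With this BUILD change in place, the new TF2 target can be exercised like any other test in the package, e.g. `bazel test //tensorflow/python/keras/distribute:multi_worker_callback_tf2_test` (exact GPU/CUDA configuration flags depend on the local Bazel setup).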
tensorflow/python/keras/distribute/multi_worker_callback_tf1_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests Keras multi worker callbacks."""
+"""Tests for Keras callbacks in multi-worker training with TF1."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -24,7 +24,6 @@ import threading
 
 from absl.testing import parameterized
 
-# pylint: disable=g-direct-tensorflow-import
 from tensorflow.python import keras
 from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
 from tensorflow.python.distribute import combinations
@@ -115,6 +114,11 @@ def generate_callback_test_function(custom_callable):
 
 class KerasMultiWorkerCallbackTest(test_base.IndependentWorkerTestBase,
                                    parameterized.TestCase):
+  """KerasMultiWorkerCallbackTest for TF1.
+
+  TODO(rchao): Migrate all tests in this class to
+  `multi_worker_callback_tf2_test`.
+  """
 
   # The callables of the actual testing content to be run go below.
   @staticmethod
tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py
@@ -0,0 +1,185 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras callbacks in multi-worker training with TF2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+
+from absl.testing import parameterized
+
+from tensorflow.python.distribute import collective_all_reduce_strategy as collective_strategy
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import multi_process_runner
+from tensorflow.python.distribute import multi_process_runner_util
+from tensorflow.python.distribute import multi_worker_test_base as test_base
+from tensorflow.python.keras import callbacks
+from tensorflow.python.keras.distribute import multi_worker_testing_utils
+from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.platform import test
+
+
+def _model_setup(test_obj, file_format):
+  """Set up a MNIST Keras model for testing purposes.
+
+  This function builds a MNIST Keras model and returns relevant information
+  for testing.
+
+  Args:
+    test_obj: The `TestCase` testing object.
+    file_format: File format for checkpoints. 'tf' or 'h5'.
+
+  Returns:
+    A tuple of (model, saving_filepath, train_ds, steps) where train_ds is
+    the training dataset.
+  """
+  batch_size = 64
+  steps = 2
+  with collective_strategy.CollectiveAllReduceStrategy().scope():
+    # TODO(b/142509827): In rare cases this errors out at C++ level with the
+    # "Connect failed" error message.
+    train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
+        batch_size, steps)
+    model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
+  # Pass saving_filepath from the parent thread to ensure every worker has the
+  # same filepath to save.
+  saving_filepath = os.path.join(test_obj.get_temp_dir(),
+                                 'checkpoint.' + file_format)
+  return model, saving_filepath, train_ds, steps
+
+
+class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase):
+
+  @combinations.generate(
+      combinations.combine(
+          mode=['eager'],
+          file_format=['h5', 'tf'],
+          save_weights_only=[True, False]))
+  def test_model_checkpoint_saves_on_chief_but_not_otherwise(
+      self, file_format, mode, save_weights_only):
+
+    def proc_model_checkpoint_saves_on_chief_but_not_otherwise(
+        test_obj, file_format):
+      model, saving_filepath, train_ds, steps = _model_setup(
+          test_obj, file_format)
+      num_epoch = 2
+      extension = os.path.splitext(saving_filepath)[1]
+
+      # Incorporate type/index information and thread id in saving_filepath to
+      # ensure every worker has a unique path. Note that in normal use case the
+      # saving_filepath will be the same for all workers, but we use different
+      # ones here just to test out chief saves checkpoint but non-chief doesn't.
+      saving_filepath = os.path.join(
+          test_obj.get_temp_dir(), 'checkpoint_%s_%d%s' %
+          (test_base.get_task_type(), test_base.get_task_index(), extension))
+
+      # The saving_filepath shouldn't exist at the beginning (as it's unique).
+      test_obj.assertFalse(training_state.checkpoint_exists(saving_filepath))
+
+      model.fit(
+          x=train_ds,
+          epochs=num_epoch,
+          steps_per_epoch=steps,
+          callbacks=[
+              callbacks.ModelCheckpoint(
+                  filepath=saving_filepath, save_weights_only=save_weights_only)
+          ])
+
+      # If it's chief, the model should be saved; if not, the model shouldn't.
+      test_obj.assertEqual(
+          training_state.checkpoint_exists(saving_filepath),
+          test_base.is_chief())
+
+    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
+    with multi_process_runner_util.try_run_and_except_connection_error(self):
+      multi_process_runner.run(
+          proc_model_checkpoint_saves_on_chief_but_not_otherwise,
+          cluster_spec=test_base.create_cluster_spec(num_workers=2),
+          args=(self, file_format))
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_tensorboard_saves_on_chief_but_not_otherwise(self, mode):
+
+    def proc_tensorboard_saves_on_chief_but_not_otherwise(test_obj):
+      model, _, train_ds, steps = _model_setup(test_obj, file_format='')
+      num_epoch = 2
+
+      # Incorporate type/index information and thread id in saving_filepath to
+      # ensure every worker has a unique path. Note that in normal use case the
+      # saving_filepath will be the same for all workers, but we use different
+      # ones here just to test out chief saves summaries but non-chief doesn't.
+      saving_filepath = os.path.join(
+          test_obj.get_temp_dir(), 'logfile_%s_%d' %
+          (test_base.get_task_type(), test_base.get_task_index()))
+
+      # The saving_filepath shouldn't exist at the beginning (as it's unique).
+      test_obj.assertFalse(file_io.file_exists(saving_filepath))
+
+      model.fit(
+          x=train_ds,
+          epochs=num_epoch,
+          steps_per_epoch=steps,
+          callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
+
+      # If it's chief, the summaries should be saved in the filepath; if not,
+      # the directory should be empty (although created). Using
+      # `file_io.list_directory()` since the directory may be created at this
+      # point.
+      test_obj.assertEqual(
+          bool(file_io.list_directory(saving_filepath)), test_base.is_chief())
+
+    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
+    with multi_process_runner_util.try_run_and_except_connection_error(self):
+      multi_process_runner.run(
+          proc_tensorboard_saves_on_chief_but_not_otherwise,
+          cluster_spec=test_base.create_cluster_spec(num_workers=2),
+          args=(self,))
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_tensorboard_can_still_save_to_temp_even_if_it_exists(self, mode):
+
+    def proc_tensorboard_can_still_save_to_temp_even_if_it_exists(test_obj):
+      model, _, train_ds, steps = _model_setup(test_obj, file_format='')
+      num_epoch = 2
+
+      saving_filepath = os.path.join(test_obj.get_temp_dir(),
+                                     'logfile_%s' % (test_base.get_task_type()))
+
+      saving_filepath_for_temp = os.path.join(saving_filepath, 'workertemp_1')
+      os.mkdir(saving_filepath)
+      os.mkdir(saving_filepath_for_temp)
+
+      # Verifies that even if `saving_filepath_for_temp` exists, tensorboard
+      # can still save to temporary directory.
+      test_obj.assertTrue(file_io.file_exists(saving_filepath_for_temp))
+
+      model.fit(
+          x=train_ds,
+          epochs=num_epoch,
+          steps_per_epoch=steps,
+          callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)])
+
+    # TODO(b/141948186): Remove this `with` block once b/141948186 is resolved.
+    with multi_process_runner_util.try_run_and_except_connection_error(self):
+      multi_process_runner.run(
+          proc_tensorboard_can_still_save_to_temp_even_if_it_exists,
+          cluster_spec=test_base.create_cluster_spec(num_workers=2),
+          args=(self,))
+
+
+if __name__ == '__main__':
+  multi_process_runner.test_main()
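
The new file uses one pattern throughout: each test defines a proc_* function holding the actual assertions, and multi_process_runner.run executes that function once in every process of a simulated two-worker cluster, so per-worker behavior (such as chief-only checkpointing) can be asserted in each process. Below is a minimal sketch of that pattern, distilled from the tests above; the _proc_example function and ExamplePatternTest class are illustrative names, not part of the commit.

    from tensorflow.python.distribute import multi_process_runner
    from tensorflow.python.distribute import multi_worker_test_base as test_base
    from tensorflow.python.platform import test


    def _proc_example(test_obj, expected_num_workers):
      # Illustration only. This body runs once inside each simulated worker
      # process; every process sees its own task type and index, which the
      # tests above use to build worker-unique file paths.
      test_obj.assertIn(test_base.get_task_index(),
                        range(expected_num_workers))


    class ExamplePatternTest(test.TestCase):

      def test_body_runs_on_every_worker(self):
        # Spawns a two-worker cluster and runs _proc_example in each process.
        multi_process_runner.run(
            _proc_example,
            cluster_spec=test_base.create_cluster_spec(num_workers=2),
            args=(self, 2))


    if __name__ == '__main__':
      multi_process_runner.test_main()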