tf.distribute: Add MultiWorkerMirroredStrategy to keras_premade_models_test

PiperOrigin-RevId: 332501279
Change-Id: Iada8a744d555b73ce9f70a855e1ef4290ca12c9a
This commit is contained in:
Priya Gupta 2020-09-18 12:17:18 -07:00 committed by TensorFlower Gardener
parent 7edae146ce
commit effefef880

View File

@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.engine import sequential
@ -39,10 +40,15 @@ def strategy_combinations_eager_data_fn():
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
# NOTE: TPUStrategy not tested because the models in this test are
# sparse and do not work with TPUs.
],
mode=['eager'],
data_fn=[get_numpy, get_dataset])
data_fn=['numpy', 'dataset'])
def get_numpy():
@ -66,7 +72,7 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
model = linear.LinearModel()
opt = gradient_descent.SGD(learning_rate=0.1)
model.compile(opt, 'mse')
if data_fn == get_numpy:
if data_fn == 'numpy':
inputs, output = get_numpy()
hist = model.fit(inputs, output, epochs=5)
else:
@ -84,7 +90,7 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
wide_deep_model.compile(
optimizer=[linear_opt, dnn_opt],
loss='mse')
if data_fn == get_numpy:
if data_fn == 'numpy':
inputs, output = get_numpy()
hist = wide_deep_model.fit(inputs, output, epochs=5)
else:
@ -93,4 +99,4 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
if __name__ == '__main__':
test.main()
multi_process_runner.test_main()