tf.distribute: Add MultiWorkerMirroredStrategy to keras_premade_models_test
PiperOrigin-RevId: 332501279 Change-Id: Iada8a744d555b73ce9f70a855e1ef4290ca12c9a
This commit is contained in:
parent
7edae146ce
commit
effefef880
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.framework import test_combinations as combinations
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
@ -39,10 +40,15 @@ def strategy_combinations_eager_data_fn():
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.one_device_strategy_gpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_two_gpus
|
||||
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
# NOTE: TPUStrategy not tested because the models in this test are
|
||||
# sparse and do not work with TPUs.
|
||||
],
|
||||
mode=['eager'],
|
||||
data_fn=[get_numpy, get_dataset])
|
||||
data_fn=['numpy', 'dataset'])
|
||||
|
||||
|
||||
def get_numpy():
|
||||
@ -66,7 +72,7 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
|
||||
model = linear.LinearModel()
|
||||
opt = gradient_descent.SGD(learning_rate=0.1)
|
||||
model.compile(opt, 'mse')
|
||||
if data_fn == get_numpy:
|
||||
if data_fn == 'numpy':
|
||||
inputs, output = get_numpy()
|
||||
hist = model.fit(inputs, output, epochs=5)
|
||||
else:
|
||||
@ -84,7 +90,7 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
|
||||
wide_deep_model.compile(
|
||||
optimizer=[linear_opt, dnn_opt],
|
||||
loss='mse')
|
||||
if data_fn == get_numpy:
|
||||
if data_fn == 'numpy':
|
||||
inputs, output = get_numpy()
|
||||
hist = wide_deep_model.fit(inputs, output, epochs=5)
|
||||
else:
|
||||
@ -93,4 +99,4 @@ class KerasPremadeModelsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
multi_process_runner.test_main()
|
||||
|
Loading…
Reference in New Issue
Block a user