Cleanup and consolidate the usage of ds.combinations.main().

PiperOrigin-RevId: 331069116
Change-Id: I21e45fc7b4f176e5555ea6735ccaaa6ae3f20d2c
This commit is contained in:
Scott Zhu 2020-09-10 19:50:20 -07:00 committed by TensorFlower Gardener
parent 3d70bc50e7
commit 80bdf7c72f
4 changed files with 8 additions and 4 deletions

View File

@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import strategy_combinations
@ -371,4 +372,4 @@ class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase,
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
ds_combinations.main()
multi_process_runner.test_main()

View File

@ -26,6 +26,7 @@ from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
@ -278,4 +279,4 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
if __name__ == '__main__':
ds_combinations.main()
multi_process_runner.test_main()

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
@ -101,4 +102,4 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
if __name__ == "__main__":
ds_combinations.main()
multi_process_runner.test_main()

View File

@ -26,6 +26,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
@ -475,4 +476,4 @@ def _get_model():
if __name__ == "__main__":
ds_combinations.main()
multi_process_runner.test_main()