Cleanup and consolidate the usage of ds.combinations.main().
PiperOrigin-RevId: 331069116 Change-Id: I21e45fc7b4f176e5555ea6735ccaaa6ae3f20d2c
This commit is contained in:
parent
3d70bc50e7
commit
80bdf7c72f
@ -27,6 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||||
from tensorflow.python.distribute import combinations as ds_combinations
|
from tensorflow.python.distribute import combinations as ds_combinations
|
||||||
from tensorflow.python.distribute import cross_device_utils
|
from tensorflow.python.distribute import cross_device_utils
|
||||||
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
from tensorflow.python.distribute import multi_worker_test_base
|
from tensorflow.python.distribute import multi_worker_test_base
|
||||||
from tensorflow.python.distribute import multi_worker_util
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
from tensorflow.python.distribute import strategy_combinations
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
@ -371,4 +372,4 @@ class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase,
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
v2_compat.enable_v2_behavior()
|
v2_compat.enable_v2_behavior()
|
||||||
ds_combinations.main()
|
multi_process_runner.test_main()
|
||||||
|
@ -26,6 +26,7 @@ from tensorflow.python import keras
|
|||||||
from tensorflow.python.compat import v2_compat
|
from tensorflow.python.compat import v2_compat
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import combinations as ds_combinations
|
from tensorflow.python.distribute import combinations as ds_combinations
|
||||||
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
from tensorflow.python.distribute import reduce_util
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.distribute import strategy_combinations
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
@ -278,4 +279,4 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase,
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ds_combinations.main()
|
multi_process_runner.test_main()
|
||||||
|
@ -23,6 +23,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import combinations as ds_combinations
|
from tensorflow.python.distribute import combinations as ds_combinations
|
||||||
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
from tensorflow.python.distribute import strategy_combinations
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -101,4 +102,4 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ds_combinations.main()
|
multi_process_runner.test_main()
|
||||||
|
@ -26,6 +26,7 @@ import numpy as np
|
|||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import combinations as ds_combinations
|
from tensorflow.python.distribute import combinations as ds_combinations
|
||||||
|
from tensorflow.python.distribute import multi_process_runner
|
||||||
from tensorflow.python.distribute import reduce_util
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.distribute import strategy_combinations
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
@ -475,4 +476,4 @@ def _get_model():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ds_combinations.main()
|
multi_process_runner.test_main()
|
||||||
|
Loading…
Reference in New Issue
Block a user