Set 2 virtual cpus and 2 virtual gpus by default for test cases.
PiperOrigin-RevId: 340551057 Change-Id: I2a72cac5cd56c1efc8ba338d2bcf66bad4560fab
This commit is contained in:
parent
e82b266a54
commit
a64738beb2
@ -635,4 +635,5 @@ class ExperimentalCompatibilityTest(test.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_util.main()
|
||||
# TODO(b/172304955): enable logical devices.
|
||||
test_util.main(config_logical_devices=False)
|
||||
|
@ -613,4 +613,5 @@ class PSStrategySaveAndLoadTest(test.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_util.main()
|
||||
# TODO(b/172304955): enable logical devices.
|
||||
test_util.main(config_logical_devices=False)
|
||||
|
@ -287,4 +287,5 @@ class ExponentialMovingAverageTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_util.main()
|
||||
# TODO(b/172304955): enable logical devices.
|
||||
test_util.main(config_logical_devices=False)
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl import app
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
@ -87,8 +89,17 @@ def set_logical_devices_to_at_least(device, num):
|
||||
config.set_logical_device_configuration(physical_devices[-1], logical_devices)
|
||||
|
||||
|
||||
def main(enable_v2_behavior=True):
|
||||
def _set_logical_devices():
|
||||
if config.list_physical_devices("GPU"):
|
||||
set_logical_devices_to_at_least("GPU", 2)
|
||||
if config.list_physical_devices("CPU"):
|
||||
set_logical_devices_to_at_least("CPU", 2)
|
||||
|
||||
|
||||
def main(enable_v2_behavior=True, config_logical_devices=True):
|
||||
"""All-in-one main function for tf.distribute tests."""
|
||||
if config_logical_devices:
|
||||
app.call_after_init(_set_logical_devices)
|
||||
if enable_v2_behavior:
|
||||
v2_compat.enable_v2_behavior()
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user