diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7825f7e5dd2..70ea4ebe1e2 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5358,6 +5358,7 @@ py_test( ":client_testlib", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:test_util", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 7551b751166..9962aec07f3 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -812,6 +812,7 @@ py_test( python_version = "PY3", deps = [ ":combinations", + ":test_util", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_combinations", "//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py", @@ -837,6 +838,7 @@ py_library( ":multi_process_runner", ":multi_worker_test_base", ":one_device_strategy", + ":test_util", ":tpu_strategy", "//tensorflow/python:config", "//tensorflow/python:platform", @@ -858,6 +860,7 @@ distribute_py_test( ":combinations", ":reduce_util", ":strategy_combinations", + ":test_util", "//tensorflow/python:config", "//tensorflow/python:constant_op", "//tensorflow/python/eager:context", @@ -948,6 +951,7 @@ distribute_py_test( ":multi_worker_test_base", ":reduce_util", ":strategy_combinations", + ":test_util", ":values", "//tensorflow/python:control_flow_ops", "//tensorflow/python:errors", @@ -1235,6 +1239,7 @@ distribute_py_test( ":combinations", ":distribute_lib", ":strategy_combinations", + ":test_util", ":tpu_strategy", ":tpu_values", ":values", @@ -1287,6 +1292,7 @@ distribute_py_test( deps = [ ":combinations", ":strategy_combinations", + ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", @@ -1698,6 +1704,7 @@ distribute_py_test( ":reduce_util", ":strategy_combinations", ":strategy_test_lib", + ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -1743,10 +1750,14 @@ py_library( srcs_version = "PY3", deps = [ ":collective_all_reduce_strategy", + ":multi_process_runner", ":values", "//tensorflow/python:array_ops", + "//tensorflow/python:config", "//tensorflow/python:framework_ops", "//tensorflow/python:util", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/eager:context", ], ) diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py index d1679c22c77..c9d3d7d9a9a 100644 --- a/tensorflow/python/distribute/combinations.py +++ b/tensorflow/python/distribute/combinations.py @@ -363,11 +363,6 @@ times = combinations_lib.times NamedObject = combinations_lib.NamedObject -def main(): - """Tests must call this main().""" - return multi_process_runner.test_main() - - # Identifies whether we're in the main process or worker processes. # `_multi_worker_test` decoration behaves differently in the main processs and # the worker processes. See the documentation of _multi_worker_test for detail. diff --git a/tensorflow/python/distribute/combinations_test.py b/tensorflow/python/distribute/combinations_test.py index fd1646fa9b7..02ddcbef632 100644 --- a/tensorflow/python/distribute/combinations_test.py +++ b/tensorflow/python/distribute/combinations_test.py @@ -25,6 +25,7 @@ import unittest from absl.testing import parameterized from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver from tensorflow.python.framework import combinations as framework_combinations from tensorflow.python.platform import test @@ -156,4 +157,4 @@ class CombinationsOnClassMultiWorkerExpectedFailureTest(test.TestCase, if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index b266dd25bc0..5abd6f483d3 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -35,6 +35,7 @@ from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test @@ -1421,4 +1422,4 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/integration_test/BUILD b/tensorflow/python/distribute/integration_test/BUILD index c1849afcb70..5ea98593ffe 100644 --- a/tensorflow/python/distribute/integration_test/BUILD +++ b/tensorflow/python/distribute/integration_test/BUILD @@ -18,6 +18,7 @@ distribute_py_test( "//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:test_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -40,6 +41,7 @@ cuda_py_test( "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:multi_process_runner", "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute:test_util", "//tensorflow/python/eager:test", ], ) diff --git a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py index 6d822ca1b97..02dee6f6adb 100644 --- a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py +++ b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py @@ -27,9 +27,9 @@ import os import tensorflow as tf from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib -from tensorflow.python.distribute import combinations from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python.distribute import test_util from tensorflow.python.eager import test @@ -213,4 +213,4 @@ class PeerFailureRecoverTest(test.TestCase): if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py index 704279b4b04..b8ac71ef203 100644 --- a/tensorflow/python/distribute/integration_test/saved_model_test.py +++ b/tensorflow/python/distribute/integration_test/saved_model_test.py @@ -39,6 +39,7 @@ from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute import sharded_variable from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -612,4 +613,4 @@ class PSStrategySaveAndLoadTest(test.TestCase): if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/moving_averages_test.py b/tensorflow/python/distribute/moving_averages_test.py index 577a6c1168f..22522ea2389 100644 --- a/tensorflow/python/distribute/moving_averages_test.py +++ b/tensorflow/python/distribute/moving_averages_test.py @@ -23,6 +23,7 @@ from absl.testing import parameterized from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import def_function from tensorflow.python.eager import test @@ -286,4 +287,4 @@ class ExponentialMovingAverageTest(test.TestCase, parameterized.TestCase): if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py index 57273e9bb15..0c35613c1a0 100644 --- a/tensorflow/python/distribute/strategy_combinations.py +++ b/tensorflow/python/distribute/strategy_combinations.py @@ -27,11 +27,11 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import mirrored_strategy as mirrored_lib from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import one_device_strategy as one_device_lib +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import tpu_strategy as tpu_lib from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver from tensorflow.python.eager import context from tensorflow.python.eager import remote -from tensorflow.python.framework import config from tensorflow.python.platform import flags from tensorflow.python.tpu import device_assignment as device_assignment_lib from tensorflow.python.tpu import tpu_strategy_util @@ -246,27 +246,9 @@ multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution( graph_and_eager_modes = ["graph", "eager"] -# This function should be called in a test's `setUp` method with the -# maximum value needed in any test. +# TODO(crccw): remove after tf-nightly picks up the new API. def set_virtual_cpus_to_at_least(num_virtual_cpus): - """Create virtual CPU devices if they haven't yet been created.""" - if num_virtual_cpus < 1: - raise ValueError("`num_virtual_cpus` must be at least 1 not %r" % - (num_virtual_cpus,)) - physical_devices = config.list_physical_devices("CPU") - if not physical_devices: - raise RuntimeError("No CPUs found") - configs = config.get_logical_device_configuration(physical_devices[0]) - if configs is None: - logical_devices = [ - context.LogicalDeviceConfiguration() for _ in range(num_virtual_cpus) - ] - config.set_logical_device_configuration(physical_devices[0], - logical_devices) - else: - if len(configs) < num_virtual_cpus: - raise RuntimeError("Already configured with %d < %d virtual CPUs" % - (len(configs), num_virtual_cpus)) + test_util.set_logical_devices_to_at_least("CPU", num_virtual_cpus) strategies_minus_tpu = [ diff --git a/tensorflow/python/distribute/strategy_combinations_test.py b/tensorflow/python/distribute/strategy_combinations_test.py index 38ace7da42d..1157520d654 100644 --- a/tensorflow/python/distribute/strategy_combinations_test.py +++ b/tensorflow/python/distribute/strategy_combinations_test.py @@ -23,49 +23,13 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations -from tensorflow.python.eager import context +from tensorflow.python.distribute import test_util from tensorflow.python.eager import def_function -from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class VirtualDevicesTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - context._reset_context() # pylint: disable=protected-access - # Need to call set_virtual_cpus_to_at_least() in setUp with the maximum - # value needed in any test. - strategy_combinations.set_virtual_cpus_to_at_least(3) - super(VirtualDevicesTest, self).setUp() - - def test3VirtualCPUs(self): - cpu_device = config.list_physical_devices("CPU")[0] - self.assertLen(config.get_logical_device_configuration(cpu_device), 3) - - def testSetVirtualCPUsAgain(self): - strategy_combinations.set_virtual_cpus_to_at_least(2) - cpu_device = config.list_physical_devices("CPU")[0] - self.assertLen(config.get_logical_device_configuration(cpu_device), 3) - - def testSetVirtualCPUsErrors(self): - with self.assertRaises(ValueError): - strategy_combinations.set_virtual_cpus_to_at_least(0) - with self.assertRaisesRegex(RuntimeError, "with 3 < 5 virtual CPUs"): - strategy_combinations.set_virtual_cpus_to_at_least(5) - - @combinations.generate(combinations.combine( - distribution=[strategy_combinations.mirrored_strategy_with_cpu_1_and_2], - mode=["graph", "eager"])) - def testMirrored2CPUs(self, distribution): - with distribution.scope(): - one_per_replica = distribution.run(lambda: constant_op.constant(1)) - num_replicas = distribution.reduce( - reduce_util.ReduceOp.SUM, one_per_replica, axis=None) - self.assertEqual(2, self.evaluate(num_replicas)) - - class StrategyCombinationsTest(test.TestCase, parameterized.TestCase): @combinations.generate( @@ -100,6 +64,19 @@ class StrategyCombinationsTest(test.TestCase, parameterized.TestCase): reduce_util.ReduceOp.SUM, one_per_replica, axis=None) self.assertEqual(self.evaluate(num_replicas), 4.) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_cpu_1_and_2 + ], + mode=["graph", "eager"])) + def testMirrored2CPUs(self, distribution): + with distribution.scope(): + one_per_replica = distribution.run(lambda: constant_op.constant(1)) + num_replicas = distribution.reduce( + reduce_util.ReduceOp.SUM, one_per_replica, axis=None) + self.assertEqual(2, self.evaluate(num_replicas)) + if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py index 199851ab6c2..2c556db6c04 100644 --- a/tensorflow/python/distribute/strategy_common_test.py +++ b/tensorflow/python/distribute/strategy_common_test.py @@ -20,7 +20,6 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribution_strategy_context as ds_context @@ -28,6 +27,7 @@ from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_test_lib +from tensorflow.python.distribute import test_util from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy from tensorflow.python.distribute.tpu_strategy import TPUStrategy from tensorflow.python.eager import def_function @@ -757,5 +757,4 @@ class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase): if __name__ == '__main__': - v2_compat.enable_v2_behavior() - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/test_util.py b/tensorflow/python/distribute/test_util.py index a6c861e5931..82867edb4c2 100644 --- a/tensorflow/python/distribute/test_util.py +++ b/tensorflow/python/distribute/test_util.py @@ -20,8 +20,12 @@ from __future__ import print_function import functools +from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import collective_all_reduce_strategy +from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.framework import config from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.util import nest @@ -56,3 +60,38 @@ def _gather(strategy, value): inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values] return strategy._gather(values.PerReplica(inputs), axis=0) # pylint: enable=protected-access + + +def set_logical_devices_to_at_least(device, num): + """Create logical devices of at least a given number.""" + if num < 1: + raise ValueError("`num` must be at least 1 not %r" % (num,)) + physical_devices = config.list_physical_devices(device) + if not physical_devices: + raise RuntimeError("No {} found".format(device)) + if len(physical_devices) >= num: + return + # By default each physical device corresponds to one logical device. We create + # multiple logical devices for the last physical device so that we have `num` + # logical devices. + num = num - len(physical_devices) + 1 + logical_devices = [] + for _ in range(num): + if device.upper() == "GPU": + logical_devices.append( + context.LogicalDeviceConfiguration(memory_limit=2048)) + else: + logical_devices.append(context.LogicalDeviceConfiguration()) + # Create logical devices from the the last device since sometimes the first + # GPU is the primary graphic card and may has less memory available. + config.set_logical_device_configuration(physical_devices[-1], logical_devices) + + +def main(enable_v2_behavior=True): + """All-in-one main function for tf.distribute tests.""" + if enable_v2_behavior: + v2_compat.enable_v2_behavior() + else: + v2_compat.disable_v2_behavior() + # TODO(b/131360402): configure default logical devices. + multi_process_runner.test_main() diff --git a/tensorflow/python/distribute/test_util_test.py b/tensorflow/python/distribute/test_util_test.py index 7dab2e199b1..165f97be6e2 100644 --- a/tensorflow/python/distribute/test_util_test.py +++ b/tensorflow/python/distribute/test_util_test.py @@ -1,13 +1,13 @@ # Copyright 2020 The TensorFlow Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, +# distributed under the License is distributed on an 'AS IS' BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. @@ -23,8 +23,10 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import test_util +from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -71,5 +73,14 @@ class GatherTest(test.TestCase, parameterized.TestCase): self.evaluate(results['bar'][1]), [1.] * strategy.num_replicas_in_sync) +class LogicalDevicesTest(test.TestCase): + + def testLogicalCPUs(self): + context._reset_context() + test_util.set_logical_devices_to_at_least('CPU', 3) + cpu_device = config.list_physical_devices('CPU')[0] + self.assertLen(config.get_logical_device_configuration(cpu_device), 3) + + if __name__ == '__main__': - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/v1/BUILD b/tensorflow/python/distribute/v1/BUILD index 2b1e46b52d3..3c45d9d441e 100644 --- a/tensorflow/python/distribute/v1/BUILD +++ b/tensorflow/python/distribute/v1/BUILD @@ -31,6 +31,7 @@ cuda_py_test( "//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:test_util", "//tensorflow/python/distribute:values", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:context", diff --git a/tensorflow/python/distribute/v1/cross_device_ops_test.py b/tensorflow/python/distribute/v1/cross_device_ops_test.py index be145aba174..9914505f51c 100644 --- a/tensorflow/python/distribute/v1/cross_device_ops_test.py +++ b/tensorflow/python/distribute/v1/cross_device_ops_test.py @@ -857,4 +857,4 @@ if __name__ == "__main__": # Set default inter op thread pool size to one to ensure we don't exhaust the # thread pool with the additional executors to run collectives in eager. os.environ["TF_NUM_INTEROP_THREADS"] = "1" - combinations.main() + test.main() diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 0943acb04c5..8a9f0acbd75 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -1438,4 +1438,4 @@ def _make_index_slices(values, indices, dense_shape=None): if __name__ == "__main__": - combinations.main() + ds_test_util.main() diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py index a2f27a053e8..e9eb9b77460 100644 --- a/tensorflow/python/distribute/vars_test.py +++ b/tensorflow/python/distribute/vars_test.py @@ -26,6 +26,7 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver @@ -1272,4 +1273,4 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase): if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD index 249609af375..a21c41c774d 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/BUILD +++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD @@ -153,6 +153,7 @@ py_test( "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:mirrored_strategy", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:test_util", "//tensorflow/python/eager:context", "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py index 8bb1dd1a2d4..738333039da 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py @@ -28,6 +28,7 @@ from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op @@ -58,7 +59,7 @@ def get_var(val, dtype, name=None): class AutoCastVariableTest(test.TestCase, parameterized.TestCase): def setUp(self): - strategy_combinations.set_virtual_cpus_to_at_least(3) + test_util.set_logical_devices_to_at_least('CPU', 3) super(AutoCastVariableTest, self).setUp() @ds_combinations.generate(maybe_distribute) diff --git a/tensorflow/python/ops/nn_loss_scaling_utilities_test.py b/tensorflow/python/ops/nn_loss_scaling_utilities_test.py index 4f96f9ba6a3..7f150b34bb3 100644 --- a/tensorflow/python/ops/nn_loss_scaling_utilities_test.py +++ b/tensorflow/python/ops/nn_loss_scaling_utilities_test.py @@ -22,6 +22,7 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -35,7 +36,7 @@ from tensorflow.python.platform import test as test_lib class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase): def setUp(self): - strategy_combinations.set_virtual_cpus_to_at_least(3) + test_util.set_logical_devices_to_at_least("CPU", 3) super(LossUtilitiesTest, self).setUp() def testComputeAverageLossGlobalBatchSize(self):