diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7825f7e5dd2..70ea4ebe1e2 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5358,6 +5358,7 @@ py_test(
         ":client_testlib",
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/distribute:test_util",
         "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 7551b751166..9962aec07f3 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -812,6 +812,7 @@ py_test(
     python_version = "PY3",
     deps = [
         ":combinations",
+        ":test_util",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_combinations",
         "//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
@@ -837,6 +838,7 @@ py_library(
         ":multi_process_runner",
         ":multi_worker_test_base",
         ":one_device_strategy",
+        ":test_util",
         ":tpu_strategy",
         "//tensorflow/python:config",
         "//tensorflow/python:platform",
@@ -858,6 +860,7 @@ distribute_py_test(
         ":combinations",
         ":reduce_util",
         ":strategy_combinations",
+        ":test_util",
         "//tensorflow/python:config",
         "//tensorflow/python:constant_op",
         "//tensorflow/python/eager:context",
@@ -948,6 +951,7 @@ distribute_py_test(
         ":multi_worker_test_base",
         ":reduce_util",
         ":strategy_combinations",
+        ":test_util",
         ":values",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:errors",
@@ -1235,6 +1239,7 @@ distribute_py_test(
         ":combinations",
         ":distribute_lib",
         ":strategy_combinations",
+        ":test_util",
         ":tpu_strategy",
         ":tpu_values",
         ":values",
@@ -1287,6 +1292,7 @@ distribute_py_test(
     deps = [
         ":combinations",
         ":strategy_combinations",
+        ":test_util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
@@ -1698,6 +1704,7 @@ distribute_py_test(
         ":reduce_util",
         ":strategy_combinations",
         ":strategy_test_lib",
+        ":test_util",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
@@ -1743,10 +1750,14 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":collective_all_reduce_strategy",
+        ":multi_process_runner",
         ":values",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:config",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:util",
+        "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/eager:context",
     ],
 )
 
diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py
index d1679c22c77..c9d3d7d9a9a 100644
--- a/tensorflow/python/distribute/combinations.py
+++ b/tensorflow/python/distribute/combinations.py
@@ -363,11 +363,6 @@ times = combinations_lib.times
 NamedObject = combinations_lib.NamedObject
 
 
-def main():
-  """Tests must call this main()."""
-  return multi_process_runner.test_main()
-
-
 # Identifies whether we're in the main process or worker processes.
 # `_multi_worker_test` decoration behaves differently in the main processs and
 # the worker processes. See the documentation of _multi_worker_test for detail.
diff --git a/tensorflow/python/distribute/combinations_test.py b/tensorflow/python/distribute/combinations_test.py
index fd1646fa9b7..02ddcbef632 100644
--- a/tensorflow/python/distribute/combinations_test.py
+++ b/tensorflow/python/distribute/combinations_test.py
@@ -25,6 +25,7 @@ import unittest
 from absl.testing import parameterized
 
 from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
 from tensorflow.python.framework import combinations as framework_combinations
 from tensorflow.python.platform import test
@@ -156,4 +157,4 @@ class CombinationsOnClassMultiWorkerExpectedFailureTest(test.TestCase,
 
 
 if __name__ == "__main__":
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index b266dd25bc0..5abd6f483d3 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import test_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
@@ -1421,4 +1422,4 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
 
 
 if __name__ == "__main__":
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/integration_test/BUILD b/tensorflow/python/distribute/integration_test/BUILD
index c1849afcb70..5ea98593ffe 100644
--- a/tensorflow/python/distribute/integration_test/BUILD
+++ b/tensorflow/python/distribute/integration_test/BUILD
@@ -18,6 +18,7 @@ distribute_py_test(
         "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:sharded_variable",
         "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/distribute:test_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
@@ -40,6 +41,7 @@ cuda_py_test(
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:multi_process_runner",
         "//tensorflow/python/distribute:multi_worker_test_base",
+        "//tensorflow/python/distribute:test_util",
         "//tensorflow/python/eager:test",
     ],
 )
diff --git a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
index 6d822ca1b97..02dee6f6adb 100644
--- a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
+++ b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py
@@ -27,9 +27,9 @@ import os
 import tensorflow as tf
 
 from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
-from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
+from tensorflow.python.distribute import test_util
 from tensorflow.python.eager import test
 
 
@@ -213,4 +213,4 @@ class PeerFailureRecoverTest(test.TestCase):
 
 
 if __name__ == "__main__":
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py
index 704279b4b04..b8ac71ef203 100644
--- a/tensorflow/python/distribute/integration_test/saved_model_test.py
+++ b/tensorflow/python/distribute/integration_test/saved_model_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import sharded_variable
 from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
@@ -612,4 +613,4 @@ class PSStrategySaveAndLoadTest(test.TestCase):
 
 
 if __name__ == "__main__":
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/moving_averages_test.py b/tensorflow/python/distribute/moving_averages_test.py
index 577a6c1168f..22522ea2389 100644
--- a/tensorflow/python/distribute/moving_averages_test.py
+++ b/tensorflow/python/distribute/moving_averages_test.py
@@ -23,6 +23,7 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
@@ -286,4 +287,4 @@ class ExponentialMovingAverageTest(test.TestCase, parameterized.TestCase):
 
 
 if __name__ == "__main__":
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index 57273e9bb15..0c35613c1a0 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -27,11 +27,11 @@ from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import one_device_strategy as one_device_lib
+from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute import tpu_strategy as tpu_lib
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import remote
-from tensorflow.python.framework import config
 from tensorflow.python.platform import flags
 from tensorflow.python.tpu import device_assignment as device_assignment_lib
 from tensorflow.python.tpu import tpu_strategy_util
@@ -246,27 +246,9 @@ multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution(
 graph_and_eager_modes = ["graph", "eager"]
 
 
-# This function should be called in a test's `setUp` method with the
-# maximum value needed in any test.
+# TODO(crccw): remove after tf-nightly picks up the new API.
 def set_virtual_cpus_to_at_least(num_virtual_cpus):
-  """Create virtual CPU devices if they haven't yet been created."""
-  if num_virtual_cpus < 1:
-    raise ValueError("`num_virtual_cpus` must be at least 1 not %r" %
-                     (num_virtual_cpus,))
-  physical_devices = config.list_physical_devices("CPU")
-  if not physical_devices:
-    raise RuntimeError("No CPUs found")
-  configs = config.get_logical_device_configuration(physical_devices[0])
-  if configs is None:
-    logical_devices = [
-        context.LogicalDeviceConfiguration() for _ in range(num_virtual_cpus)
-    ]
-    config.set_logical_device_configuration(physical_devices[0],
-                                            logical_devices)
-  else:
-    if len(configs) < num_virtual_cpus:
-      raise RuntimeError("Already configured with %d < %d virtual CPUs" %
-                         (len(configs), num_virtual_cpus))
+  test_util.set_logical_devices_to_at_least("CPU", num_virtual_cpus)
 
 
 strategies_minus_tpu = [
diff --git a/tensorflow/python/distribute/strategy_combinations_test.py b/tensorflow/python/distribute/strategy_combinations_test.py
index 38ace7da42d..1157520d654 100644
--- a/tensorflow/python/distribute/strategy_combinations_test.py
+++ b/tensorflow/python/distribute/strategy_combinations_test.py
@@ -23,49 +23,13 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
-from tensorflow.python.eager import context
+from tensorflow.python.distribute import test_util
 from tensorflow.python.eager import def_function
-from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-class VirtualDevicesTest(test.TestCase, parameterized.TestCase):
-
-  def setUp(self):
-    context._reset_context()  # pylint: disable=protected-access
-    # Need to call set_virtual_cpus_to_at_least() in setUp with the maximum
-    # value needed in any test.
-    strategy_combinations.set_virtual_cpus_to_at_least(3)
-    super(VirtualDevicesTest, self).setUp()
-
-  def test3VirtualCPUs(self):
-    cpu_device = config.list_physical_devices("CPU")[0]
-    self.assertLen(config.get_logical_device_configuration(cpu_device), 3)
-
-  def testSetVirtualCPUsAgain(self):
-    strategy_combinations.set_virtual_cpus_to_at_least(2)
-    cpu_device = config.list_physical_devices("CPU")[0]
-    self.assertLen(config.get_logical_device_configuration(cpu_device), 3)
-
-  def testSetVirtualCPUsErrors(self):
-    with self.assertRaises(ValueError):
-      strategy_combinations.set_virtual_cpus_to_at_least(0)
-    with self.assertRaisesRegex(RuntimeError, "with 3 < 5 virtual CPUs"):
-      strategy_combinations.set_virtual_cpus_to_at_least(5)
-
-  @combinations.generate(combinations.combine(
-      distribution=[strategy_combinations.mirrored_strategy_with_cpu_1_and_2],
-      mode=["graph", "eager"]))
-  def testMirrored2CPUs(self, distribution):
-    with distribution.scope():
-      one_per_replica = distribution.run(lambda: constant_op.constant(1))
-      num_replicas = distribution.reduce(
-          reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
-      self.assertEqual(2, self.evaluate(num_replicas))
-
-
 class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
 
   @combinations.generate(
@@ -100,6 +64,19 @@ class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
           reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
       self.assertEqual(self.evaluate(num_replicas), 4.)
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_cpu_1_and_2
+          ],
+          mode=["graph", "eager"]))
+  def testMirrored2CPUs(self, distribution):
+    with distribution.scope():
+      one_per_replica = distribution.run(lambda: constant_op.constant(1))
+      num_replicas = distribution.reduce(
+          reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
+      self.assertEqual(2, self.evaluate(num_replicas))
+
 
 if __name__ == "__main__":
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py
index 199851ab6c2..2c556db6c04 100644
--- a/tensorflow/python/distribute/strategy_common_test.py
+++ b/tensorflow/python/distribute/strategy_common_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 from absl.testing import parameterized
 
-from tensorflow.python.compat import v2_compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
@@ -28,6 +27,7 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import strategy_test_lib
+from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
 from tensorflow.python.distribute.tpu_strategy import TPUStrategy
 from tensorflow.python.eager import def_function
@@ -757,5 +757,4 @@ class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase):
 
 
 if __name__ == '__main__':
-  v2_compat.enable_v2_behavior()
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/test_util.py b/tensorflow/python/distribute/test_util.py
index a6c861e5931..82867edb4c2 100644
--- a/tensorflow/python/distribute/test_util.py
+++ b/tensorflow/python/distribute/test_util.py
@@ -20,8 +20,12 @@ from __future__ import print_function
 
 import functools
 
+from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute import collective_all_reduce_strategy
+from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import values
+from tensorflow.python.eager import context
+from tensorflow.python.framework import config
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.util import nest
@@ -56,3 +60,38 @@ def _gather(strategy, value):
   inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
   return strategy._gather(values.PerReplica(inputs), axis=0)
   # pylint: enable=protected-access
+
+
+def set_logical_devices_to_at_least(device, num):
+  """Create logical devices of at least a given number."""
+  if num < 1:
+    raise ValueError("`num` must be at least 1 not %r" % (num,))
+  physical_devices = config.list_physical_devices(device)
+  if not physical_devices:
+    raise RuntimeError("No {} found".format(device))
+  if len(physical_devices) >= num:
+    return
+  # By default each physical device corresponds to one logical device. We create
+  # multiple logical devices for the last physical device so that we have `num`
+  # logical devices.
+  num = num - len(physical_devices) + 1
+  logical_devices = []
+  for _ in range(num):
+    if device.upper() == "GPU":
+      logical_devices.append(
+          context.LogicalDeviceConfiguration(memory_limit=2048))
+    else:
+      logical_devices.append(context.LogicalDeviceConfiguration())
+  # Create logical devices from the the last device since sometimes the first
+  # GPU is the primary graphic card and may has less memory available.
+  config.set_logical_device_configuration(physical_devices[-1], logical_devices)
+
+
+def main(enable_v2_behavior=True):
+  """All-in-one main function for tf.distribute tests."""
+  if enable_v2_behavior:
+    v2_compat.enable_v2_behavior()
+  else:
+    v2_compat.disable_v2_behavior()
+  # TODO(b/131360402): configure default logical devices.
+  multi_process_runner.test_main()
diff --git a/tensorflow/python/distribute/test_util_test.py b/tensorflow/python/distribute/test_util_test.py
index 7dab2e199b1..165f97be6e2 100644
--- a/tensorflow/python/distribute/test_util_test.py
+++ b/tensorflow/python/distribute/test_util_test.py
@@ -1,13 +1,13 @@
 # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Licensed under the Apache License, Version 2.0 (the 'License');
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
+# distributed under the License is distributed on an 'AS IS' BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
@@ -23,8 +23,10 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import test_util
+from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
+from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 
@@ -71,5 +73,14 @@ class GatherTest(test.TestCase, parameterized.TestCase):
         self.evaluate(results['bar'][1]), [1.] * strategy.num_replicas_in_sync)
 
 
+class LogicalDevicesTest(test.TestCase):
+
+  def testLogicalCPUs(self):
+    context._reset_context()
+    test_util.set_logical_devices_to_at_least('CPU', 3)
+    cpu_device = config.list_physical_devices('CPU')[0]
+    self.assertLen(config.get_logical_device_configuration(cpu_device), 3)
+
+
 if __name__ == '__main__':
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/v1/BUILD b/tensorflow/python/distribute/v1/BUILD
index 2b1e46b52d3..3c45d9d441e 100644
--- a/tensorflow/python/distribute/v1/BUILD
+++ b/tensorflow/python/distribute/v1/BUILD
@@ -31,6 +31,7 @@ cuda_py_test(
         "//tensorflow/python/distribute:multi_worker_util",
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/distribute:test_util",
         "//tensorflow/python/distribute:values",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
diff --git a/tensorflow/python/distribute/v1/cross_device_ops_test.py b/tensorflow/python/distribute/v1/cross_device_ops_test.py
index be145aba174..9914505f51c 100644
--- a/tensorflow/python/distribute/v1/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/v1/cross_device_ops_test.py
@@ -857,4 +857,4 @@ if __name__ == "__main__":
   # Set default inter op thread pool size to one to ensure we don't exhaust the
   # thread pool with the additional executors to run collectives in eager.
   os.environ["TF_NUM_INTEROP_THREADS"] = "1"
-  combinations.main()
+  test.main()
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 0943acb04c5..8a9f0acbd75 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -1438,4 +1438,4 @@ def _make_index_slices(values, indices, dense_shape=None):
 
 
 if __name__ == "__main__":
-  combinations.main()
+  ds_test_util.main()
diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py
index a2f27a053e8..e9eb9b77460 100644
--- a/tensorflow/python/distribute/vars_test.py
+++ b/tensorflow/python/distribute/vars_test.py
@@ -26,6 +26,7 @@ from absl.testing import parameterized
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
@@ -1272,4 +1273,4 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
 
 
 if __name__ == "__main__":
-  combinations.main()
+  test_util.main()
diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD
index 249609af375..a21c41c774d 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/BUILD
+++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -153,6 +153,7 @@ py_test(
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:mirrored_strategy",
         "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/distribute:test_util",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/keras/optimizer_v2",
         "@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
index 8bb1dd1a2d4..738333039da 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.distribute import combinations as ds_combinations
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import test_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
@@ -58,7 +59,7 @@ def get_var(val, dtype, name=None):
 class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
 
   def setUp(self):
-    strategy_combinations.set_virtual_cpus_to_at_least(3)
+    test_util.set_logical_devices_to_at_least('CPU', 3)
     super(AutoCastVariableTest, self).setUp()
 
   @ds_combinations.generate(maybe_distribute)
diff --git a/tensorflow/python/ops/nn_loss_scaling_utilities_test.py b/tensorflow/python/ops/nn_loss_scaling_utilities_test.py
index 4f96f9ba6a3..7f150b34bb3 100644
--- a/tensorflow/python/ops/nn_loss_scaling_utilities_test.py
+++ b/tensorflow/python/ops/nn_loss_scaling_utilities_test.py
@@ -22,6 +22,7 @@ from absl.testing import parameterized
 
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import test_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -35,7 +36,7 @@ from tensorflow.python.platform import test as test_lib
 class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
 
   def setUp(self):
-    strategy_combinations.set_virtual_cpus_to_at_least(3)
+    test_util.set_logical_devices_to_at_least("CPU", 3)
     super(LossUtilitiesTest, self).setUp()
 
   def testComputeAverageLossGlobalBatchSize(self):