Add test_util.main() and test_util.set_logical_devices_to_at_least()
test_util.main() replaces combinations.main() test_util.set_logical_devices_to_at_least() replaces strategy_combinations.set_virtual_cpus_to_at_least() PiperOrigin-RevId: 335742598 Change-Id: Ie9967ed1f1fe866a83472319137aeb23a521c943
This commit is contained in:
parent
bde6a9cfee
commit
4ba6a1dc99
@ -5358,6 +5358,7 @@ py_test(
|
||||
":client_testlib",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/distribute:test_util",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
@ -812,6 +812,7 @@ py_test(
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":combinations",
|
||||
":test_util",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_combinations",
|
||||
"//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py",
|
||||
@ -837,6 +838,7 @@ py_library(
|
||||
":multi_process_runner",
|
||||
":multi_worker_test_base",
|
||||
":one_device_strategy",
|
||||
":test_util",
|
||||
":tpu_strategy",
|
||||
"//tensorflow/python:config",
|
||||
"//tensorflow/python:platform",
|
||||
@ -858,6 +860,7 @@ distribute_py_test(
|
||||
":combinations",
|
||||
":reduce_util",
|
||||
":strategy_combinations",
|
||||
":test_util",
|
||||
"//tensorflow/python:config",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python/eager:context",
|
||||
@ -948,6 +951,7 @@ distribute_py_test(
|
||||
":multi_worker_test_base",
|
||||
":reduce_util",
|
||||
":strategy_combinations",
|
||||
":test_util",
|
||||
":values",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:errors",
|
||||
@ -1235,6 +1239,7 @@ distribute_py_test(
|
||||
":combinations",
|
||||
":distribute_lib",
|
||||
":strategy_combinations",
|
||||
":test_util",
|
||||
":tpu_strategy",
|
||||
":tpu_values",
|
||||
":values",
|
||||
@ -1287,6 +1292,7 @@ distribute_py_test(
|
||||
deps = [
|
||||
":combinations",
|
||||
":strategy_combinations",
|
||||
":test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -1698,6 +1704,7 @@ distribute_py_test(
|
||||
":reduce_util",
|
||||
":strategy_combinations",
|
||||
":strategy_test_lib",
|
||||
":test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
@ -1743,10 +1750,14 @@ py_library(
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
":collective_all_reduce_strategy",
|
||||
":multi_process_runner",
|
||||
":values",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:config",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -363,11 +363,6 @@ times = combinations_lib.times
|
||||
NamedObject = combinations_lib.NamedObject
|
||||
|
||||
|
||||
def main():
|
||||
"""Tests must call this main()."""
|
||||
return multi_process_runner.test_main()
|
||||
|
||||
|
||||
# Identifies whether we're in the main process or worker processes.
|
||||
# `_multi_worker_test` decoration behaves differently in the main processs and
|
||||
# the worker processes. See the documentation of _multi_worker_test for detail.
|
||||
|
@ -25,6 +25,7 @@ import unittest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver
|
||||
from tensorflow.python.framework import combinations as framework_combinations
|
||||
from tensorflow.python.platform import test
|
||||
@ -156,4 +157,4 @@ class CombinationsOnClassMultiWorkerExpectedFailureTest(test.TestCase,
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.distribute import input_lib
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
@ -1421,4 +1422,4 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -18,6 +18,7 @@ distribute_py_test(
|
||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||
"//tensorflow/python/distribute:sharded_variable",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/distribute:test_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
@ -40,6 +41,7 @@ cuda_py_test(
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:multi_process_runner",
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||
"//tensorflow/python/distribute:test_util",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
@ -27,9 +27,9 @@ import os
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import test
|
||||
|
||||
|
||||
@ -213,4 +213,4 @@ class PeerFailureRecoverTest(test.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||
from tensorflow.python.distribute import sharded_variable
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
@ -612,4 +613,4 @@ class PSStrategySaveAndLoadTest(test.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -23,6 +23,7 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute import tpu_strategy
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
@ -286,4 +287,4 @@ class ExponentialMovingAverageTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -27,11 +27,11 @@ from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import one_device_strategy as one_device_lib
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute import tpu_strategy as tpu_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.tpu import device_assignment as device_assignment_lib
|
||||
from tensorflow.python.tpu import tpu_strategy_util
|
||||
@ -246,27 +246,9 @@ multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution(
|
||||
graph_and_eager_modes = ["graph", "eager"]
|
||||
|
||||
|
||||
# This function should be called in a test's `setUp` method with the
|
||||
# maximum value needed in any test.
|
||||
# TODO(crccw): remove after tf-nightly picks up the new API.
|
||||
def set_virtual_cpus_to_at_least(num_virtual_cpus):
|
||||
"""Create virtual CPU devices if they haven't yet been created."""
|
||||
if num_virtual_cpus < 1:
|
||||
raise ValueError("`num_virtual_cpus` must be at least 1 not %r" %
|
||||
(num_virtual_cpus,))
|
||||
physical_devices = config.list_physical_devices("CPU")
|
||||
if not physical_devices:
|
||||
raise RuntimeError("No CPUs found")
|
||||
configs = config.get_logical_device_configuration(physical_devices[0])
|
||||
if configs is None:
|
||||
logical_devices = [
|
||||
context.LogicalDeviceConfiguration() for _ in range(num_virtual_cpus)
|
||||
]
|
||||
config.set_logical_device_configuration(physical_devices[0],
|
||||
logical_devices)
|
||||
else:
|
||||
if len(configs) < num_virtual_cpus:
|
||||
raise RuntimeError("Already configured with %d < %d virtual CPUs" %
|
||||
(len(configs), num_virtual_cpus))
|
||||
test_util.set_logical_devices_to_at_least("CPU", num_virtual_cpus)
|
||||
|
||||
|
||||
strategies_minus_tpu = [
|
||||
|
@ -23,49 +23,13 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class VirtualDevicesTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
context._reset_context() # pylint: disable=protected-access
|
||||
# Need to call set_virtual_cpus_to_at_least() in setUp with the maximum
|
||||
# value needed in any test.
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(3)
|
||||
super(VirtualDevicesTest, self).setUp()
|
||||
|
||||
def test3VirtualCPUs(self):
|
||||
cpu_device = config.list_physical_devices("CPU")[0]
|
||||
self.assertLen(config.get_logical_device_configuration(cpu_device), 3)
|
||||
|
||||
def testSetVirtualCPUsAgain(self):
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(2)
|
||||
cpu_device = config.list_physical_devices("CPU")[0]
|
||||
self.assertLen(config.get_logical_device_configuration(cpu_device), 3)
|
||||
|
||||
def testSetVirtualCPUsErrors(self):
|
||||
with self.assertRaises(ValueError):
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "with 3 < 5 virtual CPUs"):
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(5)
|
||||
|
||||
@combinations.generate(combinations.combine(
|
||||
distribution=[strategy_combinations.mirrored_strategy_with_cpu_1_and_2],
|
||||
mode=["graph", "eager"]))
|
||||
def testMirrored2CPUs(self, distribution):
|
||||
with distribution.scope():
|
||||
one_per_replica = distribution.run(lambda: constant_op.constant(1))
|
||||
num_replicas = distribution.reduce(
|
||||
reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
|
||||
self.assertEqual(2, self.evaluate(num_replicas))
|
||||
|
||||
|
||||
class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
@ -100,6 +64,19 @@ class StrategyCombinationsTest(test.TestCase, parameterized.TestCase):
|
||||
reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
|
||||
self.assertEqual(self.evaluate(num_replicas), 4.)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_cpu_1_and_2
|
||||
],
|
||||
mode=["graph", "eager"]))
|
||||
def testMirrored2CPUs(self, distribution):
|
||||
with distribution.scope():
|
||||
one_per_replica = distribution.run(lambda: constant_op.constant(1))
|
||||
num_replicas = distribution.reduce(
|
||||
reduce_util.ReduceOp.SUM, one_per_replica, axis=None)
|
||||
self.assertEqual(2, self.evaluate(num_replicas))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
@ -28,6 +27,7 @@ from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import strategy_test_lib
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
|
||||
from tensorflow.python.distribute.tpu_strategy import TPUStrategy
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -757,5 +757,4 @@ class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
v2_compat.enable_v2_behavior()
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -20,8 +20,12 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.util import nest
|
||||
@ -56,3 +60,38 @@ def _gather(strategy, value):
|
||||
inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
|
||||
return strategy._gather(values.PerReplica(inputs), axis=0)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def set_logical_devices_to_at_least(device, num):
|
||||
"""Create logical devices of at least a given number."""
|
||||
if num < 1:
|
||||
raise ValueError("`num` must be at least 1 not %r" % (num,))
|
||||
physical_devices = config.list_physical_devices(device)
|
||||
if not physical_devices:
|
||||
raise RuntimeError("No {} found".format(device))
|
||||
if len(physical_devices) >= num:
|
||||
return
|
||||
# By default each physical device corresponds to one logical device. We create
|
||||
# multiple logical devices for the last physical device so that we have `num`
|
||||
# logical devices.
|
||||
num = num - len(physical_devices) + 1
|
||||
logical_devices = []
|
||||
for _ in range(num):
|
||||
if device.upper() == "GPU":
|
||||
logical_devices.append(
|
||||
context.LogicalDeviceConfiguration(memory_limit=2048))
|
||||
else:
|
||||
logical_devices.append(context.LogicalDeviceConfiguration())
|
||||
# Create logical devices from the the last device since sometimes the first
|
||||
# GPU is the primary graphic card and may has less memory available.
|
||||
config.set_logical_device_configuration(physical_devices[-1], logical_devices)
|
||||
|
||||
|
||||
def main(enable_v2_behavior=True):
|
||||
"""All-in-one main function for tf.distribute tests."""
|
||||
if enable_v2_behavior:
|
||||
v2_compat.enable_v2_behavior()
|
||||
else:
|
||||
v2_compat.disable_v2_behavior()
|
||||
# TODO(b/131360402): configure default logical devices.
|
||||
multi_process_runner.test_main()
|
||||
|
@ -1,13 +1,13 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@ -23,8 +23,10 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
@ -71,5 +73,14 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
self.evaluate(results['bar'][1]), [1.] * strategy.num_replicas_in_sync)
|
||||
|
||||
|
||||
class LogicalDevicesTest(test.TestCase):
|
||||
|
||||
def testLogicalCPUs(self):
|
||||
context._reset_context()
|
||||
test_util.set_logical_devices_to_at_least('CPU', 3)
|
||||
cpu_device = config.list_physical_devices('CPU')[0]
|
||||
self.assertLen(config.get_logical_device_configuration(cpu_device), 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -31,6 +31,7 @@ cuda_py_test(
|
||||
"//tensorflow/python/distribute:multi_worker_util",
|
||||
"//tensorflow/python/distribute:reduce_util",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/distribute:test_util",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:context",
|
||||
|
@ -857,4 +857,4 @@ if __name__ == "__main__":
|
||||
# Set default inter op thread pool size to one to ensure we don't exhaust the
|
||||
# thread pool with the additional executors to run collectives in eager.
|
||||
os.environ["TF_NUM_INTEROP_THREADS"] = "1"
|
||||
combinations.main()
|
||||
test.main()
|
||||
|
@ -1438,4 +1438,4 @@ def _make_index_slices(values, indices, dense_shape=None):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
ds_test_util.main()
|
||||
|
@ -26,6 +26,7 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute import tpu_strategy
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
|
||||
@ -1272,4 +1273,4 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
combinations.main()
|
||||
test_util.main()
|
||||
|
@ -153,6 +153,7 @@ py_test(
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:mirrored_strategy",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/distribute:test_util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/keras/optimizer_v2",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -58,7 +59,7 @@ def get_var(val, dtype, name=None):
|
||||
class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(3)
|
||||
test_util.set_logical_devices_to_at_least('CPU', 3)
|
||||
super(AutoCastVariableTest, self).setUp()
|
||||
|
||||
@ds_combinations.generate(maybe_distribute)
|
||||
|
@ -22,6 +22,7 @@ from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -35,7 +36,7 @@ from tensorflow.python.platform import test as test_lib
|
||||
class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
strategy_combinations.set_virtual_cpus_to_at_least(3)
|
||||
test_util.set_logical_devices_to_at_least("CPU", 3)
|
||||
super(LossUtilitiesTest, self).setUp()
|
||||
|
||||
def testComputeAverageLossGlobalBatchSize(self):
|
||||
|
Loading…
Reference in New Issue
Block a user