STT-tensorflow/tensorflow/python/distribute/combinations.py
Ran Chen 3a62bc8168 Use the same worker pool for tests that require the same number of workers
We used to use one worker pool per strategy combination, but it's not necessary.
If the cluster topology is the same they can share the same worker pool. This
reduces the overhead of initializing worker pools, which can take O(10s) for
GPU builds.

PiperOrigin-RevId: 339117353
Change-Id: I1f631f79597b07991528c77482c44c201a01abe4
2020-10-26 14:31:49 -07:00


# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for `tf.distribute.Strategy`.
Additionally it provides `generate()`, `combine()` and `times()` with
`tf.distribute.Strategy` customizations as a default.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import re
import sys
import types
import unittest
import six
from tensorflow.python.client import session
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
# TODO(rchao): Rename `distribution` parameter to `strategy` or
# `distribute_strategy` in all tests.
class DistributionParameter(combinations_lib.ParameterModifier):
"""Transforms arguments of type `NamedDistribution`.
  Converts all arguments of type `NamedDistribution` to the value of their
`strategy` property.
"""
def modified_arguments(self, kwargs, requested_parameters):
# Get the parameter that indicates if we need to set the `_use_policy` flag
# on the strategy object. This is a temporary flag for testing the variable
# policy rollout.
use_var_policy = kwargs.get("use_var_policy", None)
distribution_arguments = {}
for k, v in kwargs.items():
if isinstance(v, NamedDistribution):
strategy = v.strategy
if use_var_policy:
strategy.extended._use_var_policy = use_var_policy
distribution_arguments[k] = strategy
return distribution_arguments
class ClusterParameters(combinations_lib.ParameterModifier):
"""Adds cluster parameters if a `NamedDistribution` has it.
It needs to be before DistributionParameter.
"""
def modified_arguments(self, kwargs, requested_parameters):
strategy = None
for _, v in kwargs.items():
if isinstance(v, NamedDistribution):
if strategy is not None and _num_total_workers(v.has_chief,
v.num_workers) > 1:
raise ValueError("Only support one NamedDistribution for multi worker"
"tests.")
strategy = v
if strategy:
has_chief = strategy.has_chief
num_workers = strategy.num_workers
runner = strategy.runner
if "has_chief" in kwargs and kwargs["has_chief"] != has_chief:
raise ValueError(
"both has_chief and strategy specified but are not compatible")
if "num_workers" in kwargs and kwargs["num_workers"] != num_workers:
raise ValueError(
"both num_workers and strategy specified but are not compatible")
else:
has_chief = kwargs.get("has_chief", False)
num_workers = kwargs.get("num_workers", 1)
runner = kwargs.get("runner", None)
    # Always set cluster parameters if they're requested, so that generate()
    # works when there's no strategy in the combinations.
update = {}
if "has_chief" in requested_parameters:
update["has_chief"] = has_chief
if "num_workers" in requested_parameters:
update["num_workers"] = num_workers
if "runner" in requested_parameters:
update["runner"] = runner
return update
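# A hedged illustration of where the cluster parameters come from. They can be
# listed directly in `combine()`, or carried by a `NamedDistribution`; the
# values and the strategy name below are made up for the example:
#
#   combine(mode=["eager"], has_chief=[True], num_workers=[2])
#
#   # or, alternatively, via a strategy that declares its own cluster topology:
#   combine(mode=["eager"],
#           distribution=[some_multi_worker_named_distribution])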
class DistributionCombination(combinations_lib.TestCombination):
"""Sets up distribution strategy for tests."""
def should_execute_combination(self, kwargs):
distributions = [
v for v in kwargs.values() if isinstance(v, NamedDistribution)
]
if test_util.is_xla_enabled() and any(d.no_xla for d in distributions):
return (
False,
"n/a: skipping strategy combination with no_xla=True in XLA tests")
return (True, None)
def parameter_modifiers(self):
return [
DistributionParameter(),
combinations_lib.OptionalParameter("use_var_policy"),
]
class ClusterCombination(combinations_lib.TestCombination):
"""Sets up multi worker tests."""
def parameter_modifiers(self):
return [ClusterParameters()]
class GPUCombination(combinations_lib.TestCombination):
"""Enable tests to request GPU hardware and skip non-GPU combinations.
This class expects test_combinations to be generated with `NamedDistribution`
wrapping instances of `tf.distribute.Strategy`.
  Optionally, the `required_gpus` argument is supported. GPU hardware is
  required if its value is `True` or > 0.
Attributes:
GPU_TEST: The environment is considered to have GPU hardware available if
the name of the program contains "test_gpu" or "test_xla_gpu".
"""
GPU_TEST = re.search(r"(test_gpu|test_xla_gpu)$", sys.argv[0])
def should_execute_combination(self, kwargs):
distributions = [
v for v in kwargs.values() if isinstance(v, NamedDistribution)
]
required_gpus = kwargs.get("required_gpus", None)
if distributions and required_gpus:
raise ValueError("Do not use `required_gpus` and arguments of type "
"NamedDistribution together.")
number_of_required_gpus = max([required_gpus or 0] +
[d.required_gpus or 0 for d in distributions])
if not number_of_required_gpus and GPUCombination.GPU_TEST:
return (False, "Test that doesn't require GPUs.")
elif (number_of_required_gpus > 0
and context.num_gpus() < number_of_required_gpus):
return (False, ("Only {} of {} required GPUs are available.".format(
context.num_gpus(), number_of_required_gpus)))
else:
return (True, None)
def parameter_modifiers(self):
return [combinations_lib.OptionalParameter("required_gpus")]
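# A hedged usage sketch for `required_gpus`; the test name and GPU count are
# illustrative:
#
#   @combinations.generate(
#       combinations.combine(mode=["graph", "eager"], required_gpus=[2]))
#   def testWithTwoGPUs(self):
#     ...  # Skipped unless at least 2 GPUs are visible to TensorFlow.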
class TPUCombination(combinations_lib.TestCombination):
"""Allow to request TPU hardware and skip non-TPU combinations.
This class expects test_combinations to be generated with `NamedDistribution`
wrapping instances of `tf.distribute.Strategy`.
  Optionally, the `required_tpus` parameter is supported. TPU hardware is
  required if its value is `True` or > 0.
Optionally, the `use_cloud_tpu` parameter is supported. If TPU hardware is
required by `required_tpus`, it specifically must be a Cloud TPU (specified
with `--tpu`) if `use_cloud_tpu` is `True`.
Attributes:
TPU_TEST: The environment is considered to have TPU hardware available if
the name of the program contains "test_tpu".
"""
TPU_TEST = "test_tpu" in sys.argv[0]
def should_execute_combination(self, kwargs):
distributions = [
v for v in kwargs.values() if isinstance(v, NamedDistribution)
]
# TODO(isaprykin): Migrate all tests away from using 'required_tpu' in favor
# of 'required_tpus'.
if "required_tpus" in kwargs and "required_tpu" in kwargs:
raise ValueError("Do not use `required_tpu`. Both `required_tpus` and "
"`required_tpu` were specified.")
required_tpus = kwargs.get("required_tpus", None) or kwargs.get(
"required_tpu", None)
if distributions and required_tpus:
raise ValueError("Do not use `required_tpus` and arguments of type "
"NamedDistribution together.")
# TODO(isaprykin): Add support for a particular number of TPUs. Right now
# it's binary.
number_of_required_tpus = max([required_tpus or 0] +
[d.required_tpu or 0 for d in distributions])
use_cloud_tpu = any([kwargs.get("use_cloud_tpu")] +
[d.use_cloud_tpu for d in distributions])
tpu = hasattr(flags.FLAGS, "tpu") and flags.FLAGS.tpu or ""
if not number_of_required_tpus and TPUCombination.TPU_TEST:
return (False, "Test that doesn't require TPUs.")
if number_of_required_tpus and not TPUCombination.TPU_TEST:
return (False, "Test requires a TPU, but it's not available.")
if use_cloud_tpu and not tpu:
return (False, "Test requires a Cloud TPU, but none specified.")
if not use_cloud_tpu and tpu:
return (False, "Test requires local TPU, but Cloud TPU specified.")
return (True, None)
def parameter_modifiers(self):
return [
combinations_lib.OptionalParameter("required_tpus"),
combinations_lib.OptionalParameter("required_tpu"),
combinations_lib.OptionalParameter("use_cloud_tpu"),
]
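# A hedged usage sketch for the TPU parameters; the values below are
# illustrative:
#
#   @combinations.generate(
#       combinations.combine(mode=["graph"], required_tpus=[1],
#                            use_cloud_tpu=[True]))
#   def testOnCloudTPU(self):
#     ...  # Runs only in *test_tpu builds and only when --tpu is set.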
class NamedDistribution(object):
"""Wraps a `tf.distribute.Strategy` and adds a name for test titles."""
def __init__(self,
name,
distribution_fn,
required_gpus=None,
required_tpu=False,
use_cloud_tpu=False,
has_chief=False,
num_workers=1,
pool_runner_fn=None,
no_xla=False):
"""Initialize NamedDistribution.
Args:
name: Name that will be a part of the name of the test case.
distribution_fn: A callable that creates a `tf.distribute.Strategy`.
required_gpus: The number of GPUs that the strategy requires.
required_tpu: Whether the strategy requires TPU.
use_cloud_tpu: Whether the strategy requires cloud TPU.
has_chief: Whether the strategy requires a chief worker.
num_workers: The number of workers that the strategy requires.
pool_runner_fn: An optional callable that returns a MultiProcessPoolRunner
to run the test.
no_xla: Whether to skip in XLA tests.
"""
object.__init__(self)
self._name = name
self._distribution_fn = distribution_fn
self.required_gpus = required_gpus
self.required_tpu = required_tpu
self.use_cloud_tpu = use_cloud_tpu
self.has_chief = has_chief
self.num_workers = num_workers
self._pool_runner_fn = pool_runner_fn
self.no_xla = no_xla
@property
def runner(self):
if self._pool_runner_fn is not None:
return self._pool_runner_fn()
return None
@property
def strategy(self):
return self._distribution_fn()
def __repr__(self):
return self._name
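# A hedged sketch of constructing `NamedDistribution` objects. The factory
# functions below (_make_mirrored, _make_collective, _get_pool_runner) are
# stand-ins, not names defined in TensorFlow:
#
#   mirrored_one_gpu = NamedDistribution(
#       "MirroredOneGPU", _make_mirrored, required_gpus=1)
#
#   multi_worker_mirrored_2x1_cpu = NamedDistribution(
#       "MultiWorkerMirrored2x1CPU",
#       _make_collective,                  # returns a tf.distribute.Strategy
#       has_chief=True,
#       num_workers=2,
#       pool_runner_fn=_get_pool_runner)   # returns a MultiProcessPoolRunner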
def concat(*combined):
"""Concats combinations."""
result = []
for one in combined:
result += one
return result
@tf_export("__internal__.distribute.combinations.generate", v1=[])
def generate(combinations, test_combinations=()):
# pylint: disable=g-doc-args,g-doc-return-or-yield
"""Distributed adapter of `tf.__internal__.test.combinations.generate`.
  All tests that use a distribution strategy should use this instead of
  `tf.__internal__.test.combinations.generate`. This function supports strategy
  combinations, GPU/TPU requirements, and multi-worker setups.
See `tf.__internal__.test.combinations.generate` for usage.
"""
# pylint: enable=g-doc-args,g-doc-return-or-yield
default_combinations = (
framework_combinations.EagerGraphCombination(),
framework_combinations.TFVersionCombination(),
ClusterCombination(),
DistributionCombination(),
GPUCombination(),
TPUCombination(),
)
# We apply our own decoration to handle multi worker tests before applying
# framework.test_combinations.generate. The order is important since we need
# framework.test_combinations.generate to apply all parameter modifiers first.
combination_decorator = combinations_lib.generate(
combinations, test_combinations=default_combinations + test_combinations)
def decorator(test_method_or_class):
if isinstance(test_method_or_class, type):
# If it's a test class.
class_object = test_method_or_class
# Decorate each test method with _multi_worker_test.
for name, test_method in six.iteritems(class_object.__dict__.copy()):
if (name.startswith(unittest.TestLoader.testMethodPrefix) and
isinstance(test_method, types.FunctionType)):
setattr(class_object, name, _multi_worker_test(test_method))
return combination_decorator(class_object)
else:
return combination_decorator(_multi_worker_test(test_method_or_class))
return decorator
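# A hedged sketch of class-level decoration. When `generate()` is applied to a
# test class, every `test*` method is wrapped with `_multi_worker_test` before
# the framework combinations are applied; the class and values below are
# illustrative:
#
#   @generate(combine(mode=["eager"], has_chief=[True], num_workers=[2]))
#   class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
#
#     def testAllReduce(self):
#       # In the main process this spawns 3 subprocesses (a chief and 2
#       # workers) and reruns the whole test in each of them.
#       ...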
combine = combinations_lib.combine
times = combinations_lib.times
NamedObject = combinations_lib.NamedObject
# Identifies whether we're in the main process or worker processes.
# `_multi_worker_test` decoration behaves differently in the main process and
# the worker processes. See the documentation of _multi_worker_test for details.
_running_in_worker = False
_TestResult = collections.namedtuple("_TestResult", ["status", "message"])
def _test_runner(test_id):
"""Executes the test with the given test_id.
This is a simple wrapper around TestRunner to be used with
multi_process_runner. Similar to test.main(), but it executes only one test
specified by test_id and returns whether the test succeeds. If the test fails,
the function prints failures and errors to stdout.
Args:
test_id: TestCase.id()
Returns:
    A boolean indicating whether the test succeeds.
"""
global _running_in_worker
# No need to restore the value of _running_in_worker since it should always be
# True in worker processes.
_running_in_worker = True
test = unittest.defaultTestLoader.loadTestsFromName(test_id)
runner = unittest.TextTestRunner()
result = runner.run(test)
# Treat expected failures as failures, so that the main process can get
# them and fail as expected. Also treat errors as failures to simplify the
# handling.
failures = result.failures + result.expectedFailures + result.errors
if failures:
ret = _TestResult(status="failure", message=failures[0][1])
elif result.skipped:
ret = _TestResult(status="skipped", message=result.skipped[0][1])
else:
# Treat unexpectedSuccesses as OK so that the test case in the main process
    # succeeds as well.
ret = _TestResult(status="ok", message=None)
# Print tracebacks to stdout and multi_process_runner will collect
# them and stream back to the main process.
if ret.message:
print(ret.message)
return ret
def _multi_worker_test(test_method):
"""Decorate test_method so that it runs in each worker.
  We use `multi_process_runner` to simulate multiple workers. Since we run
  this function in both the main process and all worker processes, this
  decoration behaves differently in the main process and the worker processes.
  In the main process, it spawns subprocesses and runs the test on each of
  them; in a worker process, it executes the test in the same way as a normal
  test, e.g. setUp()/tearDown() are called before/after the test.
Args:
test_method: a function which must be a test method.
Returns:
Decorated `test_method`. Note that the decorated function has additional
arguments.
"""
def decorator(self, has_chief, num_workers, runner, **kwargs):
if _num_total_workers(has_chief, num_workers) == 1 or _running_in_worker:
      # We're in a worker process or the test is for a single worker. In either
      # case we execute the test method directly instead of spawning
      # subprocesses.
# For MultiWorkerMirroredStrategy(CollectiveAllReduceStrategy), install a
# session that connects to the local server. This is necessary for multi
# worker graph mode tests to work. Those tests cannot use their graphs or
# sessions, including the one returned by self.cached_session(). Since
# existing tests may already be doing so, we only install the session for
# multi worker tests.
with _multi_worker_session(kwargs):
test_method(self, **kwargs)
return
    # We're in the main process. We spawn subprocesses and run the *test* on
    # each of them. Note that we're not directly executing test_method passed
    # to _multi_worker_test, because we need setUp()/tearDown() to be called
    # and all the decorations on the test method to be applied. The conceptual
    # call stack is:
# [main process]test.main()
# [main process]test_runner.run(test)
# [main process]wrapper by combinations.generate()
# [main process]_multi_worker_test.decorator()
# # A sub process goes through the same code path as the main
# # process.
# [sub process]_test_runner()
# [sub process]test_runner.run(test)
# [sub process]wrapper by combinations.generate()
# [sub process]_multi_worker_test.decorator()
# # _running_in_worker is True
# [sub process]test_method()
test_id = self.id()
if runner:
results = runner.run(_test_runner, args=(test_id,))
else:
cluster_spec = multi_worker_test_base.create_cluster_spec(
has_chief=has_chief,
num_workers=num_workers,
num_ps=0,
has_eval=False)
results = multi_process_runner.run(
_test_runner, cluster_spec, args=(test_id,)).return_value
skip_reason = None
for result in results:
if result.status == "failure":
        # We can't tell which worker the return value comes from, so we fail on
        # the first error.
self.fail(result.message)
break
elif result.status == "skipped":
# Record the skip reason, but do not actually skip the test in case some
# processes fail instead.
skip_reason = result.message
if skip_reason is not None:
self.skipTest(skip_reason)
argspec = tf_inspect.getfullargspec(test_method)
decorator_args = (argspec.args or []) + ["has_chief", "num_workers", "runner"]
decorator_argspec = argspec._replace(args=decorator_args)
return tf_decorator.make_decorator(
test_method, decorator, decorator_argspec=decorator_argspec)
def _num_total_workers(has_chief, num_workers):
"""Returns the number of workers including the chief."""
if has_chief:
return num_workers + 1
return num_workers
def _multi_worker_session(kwargs):
"""Returns a context manager that enters a session that is configured for the MultiWorkerMirroredStrategy.
Args:
kwargs: a dict. Keyword arguments passed to the test.
Returns:
    A context manager. If MultiWorkerMirroredStrategy is the only strategy in
    kwargs and we're in graph mode, it's the session that is configured for
    that strategy. Otherwise, it's a no-op context manager.
"""
strategy = None
for _, v in kwargs.items():
if isinstance(v, distribute_lib.StrategyBase):
if strategy is not None:
logging.warning(
"The test uses multiple strategies. Skipping "
"entering a session that is configured for the strategy.")
return ops.NullContextmanager()
strategy = v
if context.executing_eagerly() or not isinstance(
strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
return ops.NullContextmanager()
sess_config = copy.deepcopy(context.context().config)
sess_config = strategy.update_config_proto(sess_config)
target = strategy.cluster_resolver.master()
return session.Session(config=sess_config, target=target).as_default()