STT-tensorflow/third_party/toolchains/remote/configure.bzl
A. Unique TensorFlower 6dd6ad9fd7 Set up remote GPU testing.
Currently, we set the tag "local" on GPU tests so that remote CPU tests and
local GPU tests can run within the same Bazel invocation.

This change makes it possible to set REMOTE_GPU_TESTING so that GPU tests can
also run remotely. Because tags cannot use Starlark's select, we use an
autoconfig rule that defines a function returning the tags we want: "local" by
default and "remote-gpu" if REMOTE_GPU_TESTING is set to 1.
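Because `tags` is read when a BUILD file is loaded and cannot take a `select()`, the chosen tag has to be baked into generated Starlark source instead. The plain-Python sketch below mimics the placeholder substitution that `repository_ctx.template` performs. It assumes the template body is simply `return [%{gpu_test_tags}]` inside `gpu_test_tags()`; the real `execution.bzl.tpl` is not shown here, and `render_execution_bzl` is a hypothetical helper.

```python
# Assumed template text: only the %{gpu_test_tags} placeholder is taken from
# the configure rule; the surrounding function body is a guess.
TEMPLATE = "def gpu_test_tags():\n    return [%{gpu_test_tags}]\n"


def render_execution_bzl(remote_gpu_testing: bool, template: str = TEMPLATE) -> str:
    """Hypothetical helper mimicking repository_ctx.template substitutions."""
    # The configure rule stores the tag already quoted (e.g. "\"local\""),
    # so the substituted text is a valid Starlark string literal.
    gpu_test_tags = '"remote-gpu"' if remote_gpu_testing else '"local"'
    substitutions = {"%{gpu_test_tags}": gpu_test_tags}
    rendered = template
    for placeholder, value in substitutions.items():
        rendered = rendered.replace(placeholder, value)
    return rendered


print(render_execution_bzl(False))
print(render_execution_bzl(True))
```

The generated file is then loaded like any other `.bzl` file, so GPU test macros can call `gpu_test_tags()` without knowing about the environment variable.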

The platform is set via exec_compatible_with constraints, so we select on the
"remote-gpu" tag to add a constraint that is only fulfilled by GPU-enabled
platforms.
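Platform selection via `exec_compatible_with` can be modeled as set inclusion: an execution platform qualifies only if it provides every constraint the target requires. The sketch below is a Python model, not Bazel code. The constraint label `//hypothetical/platform:gpu`, the platform names, and the helper functions are invented for illustration, since the actual GPU constraint is not named in this change; `@platforms//os:linux` stands in for whatever base constraints the remote platforms share.

```python
from typing import Dict, List, Set

# Hypothetical label: the real GPU constraint is not named in this change.
GPU_CONSTRAINT = "//hypothetical/platform:gpu"
LINUX = "@platforms//os:linux"


def exec_constraints_for(tags: List[str], base: List[str]) -> List[str]:
    """Constraints a test requires from its execution platform.

    Tests tagged "remote-gpu" get one extra constraint that only GPU-enabled
    platforms provide; every other test keeps the base constraints.
    """
    constraints = list(base)
    if "remote-gpu" in tags:
        constraints.append(GPU_CONSTRAINT)
    return constraints


def compatible_platforms(constraints: List[str], platforms: Dict[str, Set[str]]) -> List[str]:
    """Names of platforms that provide every required constraint."""
    required = set(constraints)
    return sorted(name for name, provided in platforms.items() if required <= provided)


# Hypothetical remote platforms and the constraint values they provide.
platforms = {
    "remote_cpu": {LINUX},
    "remote_gpu": {LINUX, GPU_CONSTRAINT},
}
print(compatible_platforms(exec_constraints_for(["remote-gpu"], [LINUX]), platforms))
print(compatible_platforms(exec_constraints_for(["small"], [LINUX]), platforms))
```

Only the tagged test is restricted to the GPU platform; CPU-only tests can still be scheduled on either remote platform.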

PiperOrigin-RevId: 229141861
2019-01-14 01:38:12 -08:00


"""Repository rule for remote GPU autoconfiguration.

This rule generates the Starlark file `remote_execution.bzl` in the
configured repository, instantiated from the template
//third_party/toolchains/remote:execution.bzl.tpl, together with a BUILD
file. `remote_execution.bzl` provides the function `gpu_test_tags`.

`gpu_test_tags` returns:

  * `local`: if `REMOTE_GPU_TESTING` is not set to 1, allowing CPU tests to
    run remotely and GPU tests to run locally in the same Bazel invocation.
  * `remote-gpu`: if `REMOTE_GPU_TESTING` is set to 1; this allows rules to
    set an execution requirement that selects a GPU-enabled remote platform.
"""

_REMOTE_GPU_TESTING = "REMOTE_GPU_TESTING"

def _flag_enabled(repository_ctx, flag_name):
    if flag_name not in repository_ctx.os.environ:
        return False
    return repository_ctx.os.environ[flag_name].strip() == "1"

def _remote_execution_configure(repository_ctx):
    # Unless remote GPU test execution is enabled, tag GPU tests as local so
    # that remote builds can be combined with local GPU tests.
    gpu_test_tags = "\"local\""
    if _flag_enabled(repository_ctx, _REMOTE_GPU_TESTING):
        gpu_test_tags = "\"remote-gpu\""
    repository_ctx.template(
        "remote_execution.bzl",
        Label("//third_party/toolchains/remote:execution.bzl.tpl"),
        {
            "%{gpu_test_tags}": gpu_test_tags,
        },
    )
    repository_ctx.template(
        "BUILD",
        Label("//third_party/toolchains/remote:BUILD.tpl"),
    )

remote_execution_configure = repository_rule(
    implementation = _remote_execution_configure,
    environ = [_REMOTE_GPU_TESTING],
)
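Note that `_flag_enabled` accepts only the exact value `1` (after surrounding whitespace is stripped), so a setting such as `REMOTE_GPU_TESTING=true` leaves GPU tests tagged `local`. A direct Python port of that check, with a plain dict standing in for `repository_ctx.os.environ` and a hypothetical helper `gpu_test_tags_literal` for the value substituted into the template:

```python
from typing import Mapping

REMOTE_GPU_TESTING = "REMOTE_GPU_TESTING"


def flag_enabled(environ: Mapping[str, str], flag_name: str) -> bool:
    """Port of _flag_enabled: on only if the stripped value is exactly "1"."""
    if flag_name not in environ:
        return False
    return environ[flag_name].strip() == "1"


def gpu_test_tags_literal(environ: Mapping[str, str]) -> str:
    """Quoted tag string that _remote_execution_configure substitutes."""
    return '"remote-gpu"' if flag_enabled(environ, REMOTE_GPU_TESTING) else '"local"'


# Unset, "1", "1" with whitespace, and two values that do NOT enable the flag.
for value in [None, "1", " 1\n", "true", "0"]:
    env = {} if value is None else {REMOTE_GPU_TESTING: value}
    print(repr(value), gpu_test_tags_literal(env))
```

Because the variable is listed in `environ`, Bazel re-runs the repository rule when its value changes, so switching between local and remote GPU testing regenerates `remote_execution.bzl`.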