diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index 3c6e23fec5d..e38e285423f 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -96,3 +96,22 @@ def cuda_header_library( def cuda_library(copts = [], **kwargs): """Wrapper over cc_library which adds default CUDA options.""" native.cc_library(copts = cuda_default_copts() + copts, **kwargs) + +EnableCudaInfo = provider() + +def _enable_cuda_flag_impl(ctx): + value = ctx.build_setting_value + if ctx.attr.enable_override: + print( + "\n\033[1;33mWarning:\033[0m '--define=using_cuda_nvcc' will be " + + "unsupported soon. Use '--@local_config_cuda//:enable_cuda' " + + "instead." + ) + value = True + return EnableCudaInfo(value = value) + +enable_cuda_flag = rule( + implementation = _enable_cuda_flag_impl, + build_setting = config.bool(flag = True), + attrs = {"enable_override": attr.bool()}, +) diff --git a/third_party/gpus/local_config_cuda.BUILD b/third_party/gpus/local_config_cuda.BUILD index 52cfd31a135..e289c23dbe0 100644 --- a/third_party/gpus/local_config_cuda.BUILD +++ b/third_party/gpus/local_config_cuda.BUILD @@ -1,8 +1,5 @@ -load( - "@bazel_skylib//rules:common_settings.bzl", - "bool_flag", - "string_flag", -) +load("@local_config_cuda//cuda:build_defs.bzl", "enable_cuda_flag") +load("@bazel_skylib//rules:common_settings.bzl", "string_flag") package(default_visibility = ["//visibility:public"]) @@ -10,9 +7,13 @@ package(default_visibility = ["//visibility:public"]) # # Enable with '--@local_config_cuda//:enable_cuda', or indirectly with # ./configure or '--config=cuda'. -bool_flag( +enable_cuda_flag( name = "enable_cuda", build_setting_default = False, + enable_override = select({ + ":define_using_cuda_nvcc": True, + "//conditions:default": False, + }), ) # Config setting whether CUDA support has been requested. @@ -48,3 +49,12 @@ config_setting( name = "is_cuda_compiler_nvcc", flag_values = {":cuda_compiler": "nvcc"}, ) + +# Config setting to keep `--define=using_cuda_nvcc=true` working. +# TODO(b/174244321): Remove when downstream projects have been fixed, along +# with the enable_cuda_flag rule in cuda:build_defs.bzl.tpl. +config_setting( + name = "define_using_cuda_nvcc", + define_values = {"using_cuda_nvcc": "true"}, + visibility = ["//visibility:private"], +)