From 44b6753cc21852c4eb1acf7a7a5e45de6b001384 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg@google.com>
Date: Wed, 3 Mar 2021 08:29:03 -0800
Subject: [PATCH] Make '--define=using_cuda_nvcc' temporarily working again
 until downstream projects have transitioned.

PiperOrigin-RevId: 360675718
Change-Id: Iacf19f732bd49258a1db6467e114effcbcdf6b38
---
 third_party/gpus/cuda/build_defs.bzl.tpl | 19 +++++++++++++++++++
 third_party/gpus/local_config_cuda.BUILD | 22 ++++++++++++++++------
 2 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl
index 3c6e23fec5d..e38e285423f 100644
--- a/third_party/gpus/cuda/build_defs.bzl.tpl
+++ b/third_party/gpus/cuda/build_defs.bzl.tpl
@@ -96,3 +96,22 @@ def cuda_header_library(
 def cuda_library(copts = [], **kwargs):
     """Wrapper over cc_library which adds default CUDA options."""
     native.cc_library(copts = cuda_default_copts() + copts, **kwargs)
+
+EnableCudaInfo = provider()
+
+def _enable_cuda_flag_impl(ctx):
+    value = ctx.build_setting_value
+    if ctx.attr.enable_override:
+        print(
+            "\n\033[1;33mWarning:\033[0m '--define=using_cuda_nvcc' will be " +
+            "unsupported soon. Use '--@local_config_cuda//:enable_cuda' " +
+            "instead."
+        )
+        value = True
+    return EnableCudaInfo(value = value)
+
+enable_cuda_flag = rule(
+    implementation = _enable_cuda_flag_impl,
+    build_setting = config.bool(flag = True),
+    attrs = {"enable_override": attr.bool()},
+)
diff --git a/third_party/gpus/local_config_cuda.BUILD b/third_party/gpus/local_config_cuda.BUILD
index 52cfd31a135..e289c23dbe0 100644
--- a/third_party/gpus/local_config_cuda.BUILD
+++ b/third_party/gpus/local_config_cuda.BUILD
@@ -1,8 +1,5 @@
-load(
-    "@bazel_skylib//rules:common_settings.bzl",
-    "bool_flag",
-    "string_flag",
-)
+load("@local_config_cuda//cuda:build_defs.bzl", "enable_cuda_flag")
+load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
 
 package(default_visibility = ["//visibility:public"])
 
@@ -10,9 +7,13 @@ package(default_visibility = ["//visibility:public"])
 #
 # Enable with '--@local_config_cuda//:enable_cuda', or indirectly with
 # ./configure or '--config=cuda'.
-bool_flag(
+enable_cuda_flag(
     name = "enable_cuda",
     build_setting_default = False,
+    enable_override = select({
+        ":define_using_cuda_nvcc": True,
+        "//conditions:default": False,
+    }),
 )
 
 # Config setting whether CUDA support has been requested.
@@ -48,3 +49,12 @@ config_setting(
     name = "is_cuda_compiler_nvcc",
     flag_values = {":cuda_compiler": "nvcc"},
 )
+
+# Config setting to keep `--define=using_cuda_nvcc=true` working.
+# TODO(b/174244321): Remove when downstream projects have been fixed, along
+# with the enable_cuda_flag rule in cuda:build_defs.bzl.tpl.
+config_setting(
+    name = "define_using_cuda_nvcc",
+    define_values = {"using_cuda_nvcc": "true"},
+    visibility = ["//visibility:private"],
+)