diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl
index 5a78654a90f..bc92f91a777 100644
--- a/third_party/gpus/crosstool/BUILD.tpl
+++ b/third_party/gpus/crosstool/BUILD.tpl
@@ -68,6 +68,7 @@ cc_toolchain_config(
     linker_bin_path = "%{linker_bin_path}",
     builtin_sysroot = "%{builtin_sysroot}",
     cuda_path = "%{cuda_toolkit_path}",
+    compiler = "%{compiler}",
 )
 
 cc_toolchain(
@@ -124,6 +125,7 @@ cc_toolchain_config(
     msvc_lib_path = "%{msvc_lib_path}",
     msvc_link_path = "%{msvc_link_path}",
     msvc_ml_path = "%{msvc_ml_path}",
+    compiler = "msvc",
 )
 
 filegroup(
diff --git a/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl b/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl
index 3d4d41aa2b1..7b249c0c606 100644
--- a/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl
+++ b/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl
@@ -626,48 +626,82 @@ def _impl(ctx):
         ],
     )
 
-    default_compile_flags_feature = feature(
-        name = "default_compile_flags",
-        enabled = True,
-        flag_sets = [
-            flag_set(
-                actions = [
-                    ACTION_NAMES.assemble,
-                    ACTION_NAMES.preprocess_assemble,
-                    ACTION_NAMES.linkstamp_compile,
-                    ACTION_NAMES.c_compile,
-                    ACTION_NAMES.cpp_compile,
-                    ACTION_NAMES.cpp_header_parsing,
-                    ACTION_NAMES.cpp_module_compile,
-                    ACTION_NAMES.cpp_module_codegen,
-                    ACTION_NAMES.lto_backend,
-                    ACTION_NAMES.clif_match,
-                ],
-                flag_groups = [
-                    flag_group(
-                        flags = [
-                            "/DCOMPILER_MSVC",
-                            "/DNOMINMAX",
-                            "/D_WIN32_WINNT=0x0600",
-                            "/D_CRT_SECURE_NO_DEPRECATE",
-                            "/D_CRT_SECURE_NO_WARNINGS",
-                            "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS",
-                            "/bigobj",
-                            "/Zm500",
-                            "/J",
-                            "/Gy",
-                            "/GF",
-                            "/EHsc",
-                            "/wd4351",
-                            "/wd4291",
-                            "/wd4250",
-                            "/wd4996",
-                        ],
-                    ),
-                ],
-            ),
-        ],
-    )
+    if ctx.attr.compiler == "clang":
+      default_compile_flags_feature = feature(
+          name = "default_compile_flags",
+          enabled = True,
+          flag_sets = [
+              flag_set(
+                  actions = [
+                      ACTION_NAMES.assemble,
+                      ACTION_NAMES.preprocess_assemble,
+                      ACTION_NAMES.linkstamp_compile,
+                      ACTION_NAMES.c_compile,
+                      ACTION_NAMES.cpp_compile,
+                      ACTION_NAMES.cpp_header_parsing,
+                      ACTION_NAMES.cpp_module_compile,
+                      ACTION_NAMES.cpp_module_codegen,
+                      ACTION_NAMES.lto_backend,
+                      ACTION_NAMES.clif_match,
+                  ],
+                  flag_groups = [
+                      flag_group(
+                          flags = [
+                              "-fexperimental-new-pass-manager",
+                          ],
+                      ),
+                  ],
+              ),
+          ],
+      )
+
+    elif ctx.attr.compiler == "msvc":
+      default_compile_flags_feature = feature(
+          name = "default_compile_flags",
+          enabled = True,
+          flag_sets = [
+              flag_set(
+                  actions = [
+                      ACTION_NAMES.assemble,
+                      ACTION_NAMES.preprocess_assemble,
+                      ACTION_NAMES.linkstamp_compile,
+                      ACTION_NAMES.c_compile,
+                      ACTION_NAMES.cpp_compile,
+                      ACTION_NAMES.cpp_header_parsing,
+                      ACTION_NAMES.cpp_module_compile,
+                      ACTION_NAMES.cpp_module_codegen,
+                      ACTION_NAMES.lto_backend,
+                      ACTION_NAMES.clif_match,
+                  ],
+                  flag_groups = [
+                      flag_group(
+                          flags = [
+                              "/DCOMPILER_MSVC",
+                              "/DNOMINMAX",
+                              "/D_WIN32_WINNT=0x0600",
+                              "/D_CRT_SECURE_NO_DEPRECATE",
+                              "/D_CRT_SECURE_NO_WARNINGS",
+                              "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS",
+                              "/bigobj",
+                              "/Zm500",
+                              "/J",
+                              "/Gy",
+                              "/GF",
+                              "/EHsc",
+                              "/wd4351",
+                              "/wd4291",
+                              "/wd4250",
+                              "/wd4996",
+                          ],
+                      ),
+                  ],
+              ),
+          ],
+      )
+
+    else:
+      default_compile_flags_feature = feature(
+          name = "default_compile_flags")
 
     static_link_msvcrt_debug_feature = feature(
         name = "static_link_msvcrt_debug",
@@ -1320,6 +1354,7 @@ def _impl(ctx):
 
     if (ctx.attr.cpu == "local"):
         features = [
+            default_compile_flags_feature,
             cpp11_feature,
             stdlib_feature,
             determinism_feature,
@@ -1510,6 +1545,7 @@ cc_toolchain_config = rule(
         "msvc_lib_path": attr.string(default = "msvc_not_used"),
         "msvc_link_path": attr.string(default = "msvc_not_used"),
         "msvc_ml_path": attr.string(default = "msvc_not_used"),
+        "compiler": attr.string(values = ["clang", "msvc", "unknown"], default="unknown"),
     },
     provides = [CcToolchainConfigInfo],
     executable = True,
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index a4eccc4d235..8fa64f264dc 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -1024,8 +1024,10 @@ def _create_local_cuda_repository(repository_ctx):
     cuda_defines = {}
     cuda_defines["%{builtin_sysroot}"] = tf_sysroot
     cuda_defines["%{cuda_toolkit_path}"] = ""
+    cuda_defines["%{compiler}"] = "unknown"
     if is_cuda_clang:
         cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
+        cuda_defines["%{compiler}"] = "clang"
 
     host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
     if not host_compiler_prefix: