diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index d3f5dd3e662..c55f5750da7 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -115,6 +115,16 @@ class ExecutableBuildOptions {
     return *this;
   }
 
+  // Thread pool for parallel compilation.
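+  //
+  // The pool is owned by the caller and must outlive compilation; a minimal
+  // usage sketch (names here are illustrative, not part of this change):
+  //
+  //   tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
+  //                                       "xla_compile", /*num_threads=*/8);
+  //   build_options.set_compile_thread_pool(&pool);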
+  tensorflow::thread::ThreadPool* compile_thread_pool() const {
+    return compile_thread_pool_;
+  }
+  ExecutableBuildOptions& set_compile_thread_pool(
+      tensorflow::thread::ThreadPool* compile_thread_pool) {
+    compile_thread_pool_ = compile_thread_pool;
+    return *this;
+  }
+
  private:
   int device_ordinal_ = -1;
   Shape result_layout_;
@@ -128,6 +138,7 @@ class ExecutableBuildOptions {
   absl::optional<DeviceAssignment> device_assignment_;
   bool alias_passthrough_params_ = false;
   bool run_backend_only_ = false;
+  tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 5cf84854f66..623b8262178 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/threadpool.h"
 
 namespace xla {
 
@@ -165,6 +166,9 @@ class Compiler {
     // given algorithm over those buffers, to see which variant is fastest.  Any
     // space allocated will be deallocated before the compilation returns.
     se::DeviceMemoryAllocator* device_allocator = nullptr;
+
+    // An optional thread pool for parallel compilation.
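+    // Backends that do not support parallel compilation may ignore it.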
+    tensorflow::thread::ThreadPool* thread_pool = nullptr;
   };
 
   virtual ~Compiler() {}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 767917f26f3..f9bacdd8145 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1452,7 +1452,11 @@ cc_library(
         "//tensorflow/stream_executor:stream_executor_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:AsmParser",
+        "@llvm-project//llvm:BitReader",
+        "@llvm-project//llvm:BitWriter",
         "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:TransformUtils",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
     ],
@@ -1517,7 +1521,7 @@ cc_library(
         "//tensorflow/stream_executor:stream_executor_headers",
         "//tensorflow/stream_executor/cuda:cuda_diagnostics",
         "//tensorflow/stream_executor/gpu:asm_compiler",
-    ]),
+    ]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
 )
 
 cc_library(
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
index 974db02b1b3..f6409b476b5 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -108,12 +108,17 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
 AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
                                     llvm::Module* llvm_module,
                                     GpuVersion gpu_version,
-                                    se::StreamExecutor* stream_exec) {
+                                    se::StreamExecutor* stream_exec,
+                                    bool relocatable) {
   if (rocdl_dir_.empty()) {
     // Compute rocdl_dir_ just once and cache it in this member.
     rocdl_dir_ = GetROCDLDir(module->config());
   }
 
+  if (relocatable) {
+    return Unimplemented(
+        "relocatable target binary is not implemented for AMDGPU");
+  }
+
   std::vector<uint8> hsaco;
   {
     XLA_SCOPED_LOGGING_TIMER(
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
index acc5e021e3d..36318badeef 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
@@ -41,7 +41,8 @@ class AMDGPUCompiler : public GpuCompiler {
 
   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+      bool relocatable) override;
 
  private:
   // The parent directory of ROCm-Device-Libs IR libraries.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 8a694860ed0..b5e3c14c791 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -24,11 +24,15 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/Bitcode/BitcodeReader.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
+#include "llvm/Transforms/Utils/SplitModule.h"
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/protobuf_util.h"
@@ -114,11 +118,13 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/blocking_counter.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/regexp.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/threadpool.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/util/env_var.h"
@@ -641,6 +647,124 @@ static Status CompileModuleToLlvmIrImpl(
   return Status::OK();
 }
 
+StatusOr<std::pair<std::string, std::vector<uint8>>>
+GpuCompiler::CompileToTargetBinary(const HloModule& module,
+                                   std::unique_ptr<llvm::Module> llvm_module,
+                                   se::StreamExecutor* stream_exec,
+                                   const CompileOptions& options) {
+  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
+
+  const auto compile_single_module =
+      [this, stream_exec, &module](
+          llvm::Module* llvm_module,
+          bool relocatable) -> StatusOr<BackendCompileResult> {
+    {
+      XLA_SCOPED_LOGGING_TIMER(
+          "GpuCompiler::RunBackend - Running LLVM verifier");
+
+      std::string err;
+      llvm::raw_string_ostream err_stream(err);
+
+      // verifyModule() returns true if the module is broken.
+      TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
+          << "Invalid LLVM IR before optimizations:\n"
+          << err_stream.str()
+          << "\nThis probably indicates a bug in the HLO -> LLVM IR "
+             "lowering. "
+             "Rerun with --xla_dump_to to get the IR and looks for files "
+             "with "
+             "name containing: *"
+          << FilenameFor(module, "", "") << "*";
+    }
+    GpuVersion gpu_version = GetGpuVersion(stream_exec);
+    return CompileTargetBinary(&module, llvm_module, gpu_version, stream_exec,
+                               relocatable);
+  };
+
+  tensorflow::thread::ThreadPool* thread_pool = options.thread_pool;
+  if (!thread_pool) {
+    return compile_single_module(llvm_module.get(), /*relocatable=*/false);
+  }
+
+  // Test whether LinkModules is supported.
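+  // The base implementation returns UNIMPLEMENTED; in that case fall back to
+  // compiling the whole module in one piece.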
+  if (this->LinkModules(stream_exec, {}).status().code() ==
+      tensorflow::error::Code::UNIMPLEMENTED) {
+    return compile_single_module(llvm_module.get(), /*relocatable=*/false);
+  }
+
+  std::vector<std::unique_ptr<llvm::Module>> llvm_modules;
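+  // Count the externally visible function definitions so the module is never
+  // split into more pieces than there are functions.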
+  int num_functions = 0;
+  for (llvm::Function& func : llvm_module->functions()) {
+    if (!func.isDeclaration() &&
+        func.getLinkage() == llvm::GlobalValue::LinkageTypes::ExternalLinkage) {
+      num_functions++;
+    }
+  }
+
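+  // Partition the module so each piece can be compiled independently.
+  // PreserveLocals keeps internal symbols in the same partition as their
+  // users, so each piece stays self-contained.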
+  llvm::SplitModule(
+      std::move(llvm_module),
+      std::max<unsigned>(
+          1, std::min<unsigned>(thread_pool->NumThreads(), num_functions)),
+      [&](std::unique_ptr<llvm::Module> module) {
+        llvm_modules.push_back(std::move(module));
+      },
+      /*PreserveLocals=*/true);
+
+  std::vector<StatusOr<BackendCompileResult>> compile_results(
+      llvm_modules.size());
+  tensorflow::BlockingCounter counter(llvm_modules.size());
+  for (int i = 0; i < llvm_modules.size(); i++) {
+    thread_pool->Schedule([&compile_results, compile_single_module, i,
+                           &llvm_modules, &counter] {
+      llvm::Module* original_module = llvm_modules[i].get();
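+      // llvm::LLVMContext is not thread-safe, so give each worker its own
+      // context and clone the split module into it by round-tripping the
+      // module through its textual IR.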
+      llvm::LLVMContext context;
+      std::string buffer;
+      llvm::raw_string_ostream error(buffer);
+      llvm::DiagnosticPrinterRawOStream printer(error);
+      auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
+                                  void* Context) {
+        auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
+        diag_info.print(*printer);
+      };
+      context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
+
+      std::unique_ptr<llvm::Module> new_llvm_module;
+      {
+        std::string ir;
+        {
+          llvm::raw_string_ostream os(ir);
+          original_module->print(os, nullptr);
+        }
+        llvm::SMDiagnostic err;
+        new_llvm_module = llvm::parseAssemblyString(ir, err, context);
+      }
+
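+      // Compile the piece to relocatable device code; all pieces are linked
+      // into a single binary after the workers finish.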
+      compile_results[i] =
+          compile_single_module(new_llvm_module.get(), /*relocatable=*/true);
+      counter.DecrementCount();
+    });
+  }
+  counter.Wait();
+
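+  // Concatenate the PTX of all pieces and collect the relocatable binaries
+  // for the final device link.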
+  std::string ptx_snippets;
+  std::vector<std::vector<uint8>> submodule_compile_results;
+  for (auto& maybe_result : compile_results) {
+    TF_ASSIGN_OR_RETURN(auto result, maybe_result);
+    if (result.second.empty()) {
+      continue;
+    }
+    ptx_snippets += result.first;
+    ptx_snippets += "\n";
+    submodule_compile_results.push_back(result.second);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<uint8> backend_result,
+      this->LinkModules(stream_exec, std::move(submodule_compile_results)));
+
+  return std::make_pair(ptx_snippets, backend_result);
+}
+
 StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
     const CompileOptions& options) {
@@ -650,15 +774,6 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
   TF_RET_CHECK(stream_exec != nullptr);
 
   llvm::LLVMContext llvm_context;
-  std::string buffer;
-  llvm::raw_string_ostream error(buffer);
-  llvm::DiagnosticPrinterRawOStream printer(error);
-  auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
-                              void* Context) {
-    auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
-    diag_info.print(*printer);
-  };
-  llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer);
 
   GpuDeviceInfo gpu_device_info;
   gpu_device_info.threads_per_block_limit =
@@ -724,34 +839,16 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
 
   llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
 
-  {
-    XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier");
-
-    std::string err;
-    llvm::raw_string_ostream err_stream(err);
-
-    // verifyModule() returns true if the module is broken.
-    TF_RET_CHECK(!llvm::verifyModule(*llvm_module, &err_stream))
-        << "Invalid LLVM IR before optimizations:\n"
-        << err_stream.str()
-        << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
-           "Rerun with --xla_dump_to to get the IR and looks for files with "
-           "name containing: *"
-        << FilenameFor(*module, "", "") << "*";
-  }
-
-  GpuVersion gpu_version = GetGpuVersion(stream_exec);
-
   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
   TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
-                      CompileTargetBinary(module.get(), llvm_module.get(),
-                                          gpu_version, stream_exec));
-
+                      CompileToTargetBinary(*module, std::move(llvm_module),
+                                            stream_exec, options));
   if (DumpingEnabledForHloModule(*module)) {
     DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
                             thunk_schedule->ToString());
   }
 
+  GpuVersion gpu_version = GetGpuVersion(stream_exec);
   auto* gpu_executable = new GpuExecutable(
       backend_result.first, backend_result.second, gpu_version,
       std::move(thunk_schedule), std::move(module),
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index 4e57aa9def7..1d42976e352 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -83,8 +83,8 @@ class GpuCompiler : public LLVMCompiler {
 
   virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
   CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
-                      GpuVersion gpu_version,
-                      se::StreamExecutor* stream_exec) = 0;
+                      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+                      bool relocatable) = 0;
 
   Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
 
@@ -96,6 +96,10 @@ class GpuCompiler : public LLVMCompiler {
   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                      AotCompilationOptions const& options) override;
 
+  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileToTargetBinary(
+      const HloModule& module, std::unique_ptr<llvm::Module> llvm_module,
+      se::StreamExecutor* stream_exec, const CompileOptions& options);
+
   se::Platform::Id PlatformId() const override { return platform_id_; }
 
   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
@@ -115,6 +119,12 @@ class GpuCompiler : public LLVMCompiler {
   }
 
  private:
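+  // Links relocatable device binaries produced by CompileTargetBinary with
+  // relocatable=true into a single binary. Backends that do not implement
+  // linking keep this default, and CompileToTargetBinary falls back to
+  // compiling the whole module at once.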
+  virtual StatusOr<std::vector<uint8>> LinkModules(
+      se::StreamExecutor* stream_exec,
+      std::vector<std::vector<uint8>> modules) {
+    return Unimplemented("LinkModules is not implemented.");
+  }
+
   se::Platform::Id platform_id_;
 
   // The triple that represents our target.
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 3225cd2d5a5..070b8a1fcfb 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -51,6 +51,7 @@ limitations under the License.
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
+#include "tensorflow/stream_executor/gpu/gpu_driver.h"
 
 namespace xla {
 namespace gpu {
@@ -299,7 +300,8 @@ StatusOr<std::pair<std::string, std::vector<uint8>>>
 NVPTXCompiler::CompileTargetBinary(const HloModule* module,
                                    llvm::Module* llvm_module,
                                    GpuVersion gpu_version,
-                                   se::StreamExecutor* stream_exec) {
+                                   se::StreamExecutor* stream_exec,
+                                   bool relocatable) {
   std::pair<int, int> compute_capability =
       absl::get<std::pair<int, int>>(gpu_version);
 
@@ -338,7 +340,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
 
   std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
       stream_exec, ptx, compute_capability.first, compute_capability.second,
-      module->config());
+      module->config(), relocatable);
 
   return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
                                                     std::move(cubin));
@@ -346,7 +348,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
 
 std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
     se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
-    int cc_minor, const HloModuleConfig& hlo_module_config) {
+    int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable) {
   XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
   tensorflow::profiler::TraceMe activity(
       "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
@@ -361,7 +363,7 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
     tensorflow::mutex_lock lock(mutex_);
     std::tie(iter, inserted) = compilation_cache_.emplace(
         std::piecewise_construct,
-        std::forward_as_tuple(ptx, cc_major, cc_minor),
+        std::forward_as_tuple(ptx, cc_major, cc_minor, relocatable),
         std::forward_as_tuple());
     cache_ptx = &iter->first.ptx;
     cache_value = &iter->second;
@@ -375,9 +377,13 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
     if (inserted) {
       CHECK(!cache_value->compilation_done);
       if (!ptx.empty()) {
-        StatusOr<std::vector<uint8>> maybe_cubin =
-            se::CompileGpuAsm(stream_exec->device_ordinal(), cache_ptx->c_str(),
-                              PtxOptsFromConfig(hlo_module_config));
+        auto ptxas_config = PtxOptsFromConfig(hlo_module_config);
+        if (relocatable) {
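+          // "-c" (--compile-only) makes ptxas emit relocatable device code
+          // that can later be linked with the other pieces.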
+          ptxas_config.extra_flags.push_back("-c");
+        }
+        StatusOr<std::vector<uint8>> maybe_cubin = se::CompileGpuAsm(
+            stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);
+
         if (maybe_cubin.ok()) {
           cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
           VLOG(2) << "Compiled PTX size:" << ptx.size()
@@ -445,5 +451,17 @@ std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
   return cache_value->cubin_data;
 }
 
+StatusOr<std::vector<uint8>> NVPTXCompiler::LinkModules(
+    se::StreamExecutor* stream_exec, std::vector<std::vector<uint8>> modules) {
+  std::vector<stream_executor::CubinOrPTXImage> images;
+  images.reserve(modules.size());
+  for (auto& module : modules) {
+    images.push_back({"", std::move(module)});
+  }
+  return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
+                        stream_exec->implementation()->GpuContextHack()),
+                    images);
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index 3e19b35af19..5c78b48b9c6 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -52,9 +52,14 @@ class NVPTXCompiler : public GpuCompiler {
 
   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+      bool relocatable) override;
 
  private:
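+  // Links relocatable cubins (compiled with "ptxas -c") into a single cubin
+  // via the CUDA driver linking API.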
+  StatusOr<std::vector<uint8>> LinkModules(
+      se::StreamExecutor* stream_exec,
+      std::vector<std::vector<uint8>> modules) override;
+
   tensorflow::mutex mutex_;
 
   // When compiling an HLO module, we need to find a path to the nvvm libdevice
@@ -71,7 +76,7 @@ class NVPTXCompiler : public GpuCompiler {
   // compiled cubin.  If compilation was unsuccessful, returns an empty vector.
   std::vector<uint8> CompileGpuAsmOrGetCachedResult(
       se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
-      int cc_minor, const HloModuleConfig& hlo_module_config);
+      int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable);
 
-  // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor}
-  // -> cubin so we don't recompile the same ptx twice.  This is important for
+  // The compilation_cache_ map is a cache from {ptx string, cc_major,
+  // cc_minor, relocatable} -> cubin so we don't recompile the same ptx
+  // twice.  This is important for
@@ -86,24 +91,32 @@ class NVPTXCompiler : public GpuCompiler {
   // If compiling the ptx fails, we return an empty cubin, cross our fingers,
   // and leave compilation up to the driver.
   struct CompilationCacheKey {
-    CompilationCacheKey(std::string ptx, int cc_major, int cc_minor)
-        : ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {}
+    CompilationCacheKey(std::string ptx, int cc_major, int cc_minor,
+                        bool relocatable)
+        : ptx(std::move(ptx)),
+          cc_major(cc_major),
+          cc_minor(cc_minor),
+          relocatable(relocatable) {}
     string ptx;
     int cc_major;
     int cc_minor;
+    bool relocatable;
   };
   struct CompilationCacheHash {
     size_t operator()(const CompilationCacheKey& key) const {
       return tensorflow::Hash64Combine(
-          tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major),
-          key.cc_minor);
+          tensorflow::Hash64Combine(
+              tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx),
+                                        key.cc_major),
+              key.cc_minor),
+          key.relocatable);
     }
   };
   struct CompilationCacheEq {
     size_t operator()(const CompilationCacheKey& a,
                       const CompilationCacheKey& b) const {
       return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
-             a.ptx == b.ptx;
+             a.ptx == b.ptx && a.relocatable == b.relocatable;
     }
   };
   struct CompilationCacheValue {
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 0eff81c9a0d..ea8c45d3d46 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -190,11 +190,12 @@ LocalService::CompileExecutables(
   // single partition computations are built using `BuildExecutables`, fix it,
   // and remove this special case (provided the performance if similar).
   if (build_options.num_partitions() == 1) {
-    TF_ASSIGN_OR_RETURN(
-        std::unique_ptr<Executable> executable,
-        BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
-                        executor, build_options.device_allocator(),
-                        build_options.run_backend_only()));
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                        BuildExecutable(proto, std::move(module_config),
+                                        execute_backend_.get(), executor,
+                                        {build_options.device_allocator(),
+                                         build_options.compile_thread_pool()},
+                                        build_options.run_backend_only()));
     std::vector<std::unique_ptr<Executable>> executables;
     executables.push_back(std::move(executable));
     return executables;
@@ -206,10 +207,12 @@ LocalService::CompileExecutables(
     std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
                                                executor);
 
-    return BuildExecutables({&proto}, std::move(module_configs),
-                            execute_backend_.get(), {executors},
-                            build_options.device_allocator(),
-                            build_options.run_backend_only());
+    return BuildExecutables(
+        /*module_protos=*/{&proto}, std::move(module_configs),
+        execute_backend_.get(), {executors},
+        Compiler::CompileOptions{build_options.device_allocator(),
+                                 build_options.compile_thread_pool()},
+        build_options.run_backend_only());
   }
 }
 
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index a6d23c18797..cf781b4fcdd 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -357,7 +357,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
     const std::vector<const HloModuleProto*>& module_protos,
     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
     Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
-    se::DeviceMemoryAllocator* device_allocator, bool run_backend_only) {
+    const Compiler::CompileOptions& options, bool run_backend_only) {
   VLOG(1) << StrFormat("BuildExecutable on service %p", this);
 
   // Dump computation proto state if flag is set.
@@ -387,17 +387,15 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
 
   std::vector<std::unique_ptr<Executable>> executables;
   if (!run_backend_only) {
-    TF_ASSIGN_OR_RETURN(
-        executables,
-        backend->compiler()->Compile(std::move(module_group),
-                                     std::move(executors), device_allocator));
+    TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
+                                         std::move(module_group),
+                                         std::move(executors), options));
   } else {
     auto modules = module_group->ConsumeModules();
     for (std::unique_ptr<HloModule>& module : modules) {
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<Executable> executable,
-          backend->compiler()->RunBackend(std::move(module), executors[0][0],
-                                          device_allocator));
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                          backend->compiler()->RunBackend(
+                              std::move(module), executors[0][0], options));
       executables.push_back(std::move(executable));
     }
   }
@@ -710,7 +708,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                       BuildExecutables(module_protos, std::move(module_configs),
                                        execute_backend_.get(), all_executors,
-                                       /*device_allocator=*/nullptr));
+                                       {/*device_allocator=*/nullptr}));
   std::vector<Executable*> executable_ptrs;
   executable_ptrs.reserve(executables.size());
   for (const auto& executable : executables) {
@@ -810,7 +808,7 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
     const HloModuleProto& module_proto,
     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
-    se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator,
+    se::StreamExecutor* executor, const Compiler::CompileOptions& options,
     bool run_backend_only) {
   VLOG(1) << StrFormat(
       "BuildExecutable on service %p with serialized module proto: %s", this,
@@ -822,14 +820,13 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
 
   if (!run_backend_only) {
-    TF_ASSIGN_OR_RETURN(
-        module, backend->compiler()->RunHloPasses(std::move(module), executor,
-                                                  device_allocator));
+    TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
+                                    std::move(module), executor, options));
   }
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
-                      backend->compiler()->RunBackend(
-                          std::move(module), executor, device_allocator));
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Executable> executable,
+      backend->compiler()->RunBackend(std::move(module), executor, options));
 
   const auto& debug_opts = module_config->debug_options();
   if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
@@ -875,7 +872,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
       BuildExecutable(arg->computation(), std::move(module_config),
                       execute_backend_.get(),
                       execute_backend_->default_stream_executor(),
-                      /*device_allocator=*/nullptr));
+                      {/*device_allocator=*/nullptr}));
 
   *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));
 
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 712ccc44d91..02288bba475 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -235,8 +235,7 @@ class Service : public ServiceInterface {
   StatusOr<std::unique_ptr<Executable>> BuildExecutable(
       const HloModuleProto& module_proto,
       std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
-      se::StreamExecutor* executor,
-      se::DeviceMemoryAllocator* device_allocator = nullptr,
+      se::StreamExecutor* executor, const Compiler::CompileOptions& options,
       bool run_backend_only = false);
 
   // Same as BuildExecutable() above, but builds a list of Executables for the
@@ -245,8 +244,7 @@ class Service : public ServiceInterface {
       const std::vector<const HloModuleProto*>& module_protos,
       std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
       Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
-      se::DeviceMemoryAllocator* device_allocator,
-      bool run_backend_only = false);
+      const Compiler::CompileOptions& options, bool run_backend_only = false);
 
   // Runs the given executable with the given arguments and register the result
   // in the allocation tracker. The handle of the result from the tracker is
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 0ddb01fc6ab..49e7560d2a8 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -57,7 +57,8 @@ class GpuDummyCompiler : public GpuCompiler {
 
   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec,
+      bool relocatable) {
     if (user_post_optimization_hook_) {
       user_post_optimization_hook_(*llvm_module);
     }
diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD
index 0ee227d51f2..839950b1021 100644
--- a/tensorflow/stream_executor/cuda/BUILD
+++ b/tensorflow/stream_executor/cuda/BUILD
@@ -598,6 +598,16 @@ cc_library(
     ]),
 )
 
+cc_library(
+    name = "cuda_asm_compiler",
+    srcs = if_cuda_is_configured(["cuda_asm_compiler.cc"]),
+    deps = if_cuda_is_configured([
+        "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/stream_executor/gpu:asm_compiler",
+        "//tensorflow/stream_executor/gpu:gpu_driver_header",
+    ]),
+)
+
 cc_library(
     name = "cuda_gpu_executor",
     srcs = if_cuda_is_configured(["cuda_gpu_executor.cc"]),
@@ -611,6 +621,7 @@ cc_library(
         ":cuda_platform_id",
         ":cuda_stream",
         ":cuda_timer",
+        ":cuda_asm_compiler",
         "@com_google_absl//absl/strings",
         "//tensorflow/stream_executor:event",
         "//tensorflow/stream_executor:plugin_registry",
diff --git a/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc b/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
new file mode 100644
index 00000000000..f92d3c487d0
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/stream_executor/gpu/asm_compiler.h"
+#include "tensorflow/stream_executor/gpu/gpu_driver.h"
+
+namespace stream_executor {
+
+#define RETURN_IF_CUDA_ERROR(expr)                                            \
+  do {                                                                        \
+    CUresult _status = expr;                                                  \
+    if (!SE_PREDICT_TRUE(_status == CUDA_SUCCESS)) {                          \
+      const char* error_string;                                               \
+      cuGetErrorString(_status, &error_string);                               \
+      std::ostringstream oss;                                                 \
+      oss << error_string << "\nin " << __FILE__ << "(" << __LINE__ << "): '" \
+          << #expr << "'";                                                    \
+      return port::Status(port::error::UNKNOWN, oss.str().c_str());           \
+    }                                                                         \
+  } while (false)
+
+port::StatusOr<std::vector<uint8>> LinkGpuAsm(
+    gpu::GpuContext* context, std::vector<CubinOrPTXImage> images) {
+  gpu::ScopedActivateContext activation(context);
+
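+  // Feed each relocatable cubin into a CUDA driver link session and copy the
+  // fully linked cubin out before destroying the link state, which owns the
+  // output buffer.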
+  CUlinkState link_state;
+  RETURN_IF_CUDA_ERROR(cuLinkCreate(0, nullptr, nullptr, &link_state));
+  for (auto& image : images) {
+    RETURN_IF_CUDA_ERROR(cuLinkAddData(
+        link_state, CU_JIT_INPUT_CUBIN, static_cast<void*>(image.bytes.data()),
+        image.bytes.size(), "", 0, nullptr, nullptr));
+  }
+  void* cubin_out;
+  size_t cubin_size;
+  RETURN_IF_CUDA_ERROR(cuLinkComplete(link_state, &cubin_out, &cubin_size));
+  std::vector<uint8> cubin(static_cast<uint8*>(cubin_out),
+                           static_cast<uint8*>(cubin_out) + cubin_size);
+  RETURN_IF_CUDA_ERROR(cuLinkDestroy(link_state));
+  return std::move(cubin);
+}
+
+}  // namespace stream_executor
diff --git a/tensorflow/stream_executor/gpu/asm_compiler.h b/tensorflow/stream_executor/gpu/asm_compiler.h
index 1ac58aaddf3..388f919a3c3 100644
--- a/tensorflow/stream_executor/gpu/asm_compiler.h
+++ b/tensorflow/stream_executor/gpu/asm_compiler.h
@@ -24,6 +24,9 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
 
 namespace stream_executor {
+namespace gpu {
+class GpuContext;
+}
 
 // Compiles the given PTX string using ptxas and returns the resulting machine
 // code (i.e. a cubin) as a byte array. The generated cubin matches the compute
@@ -72,6 +75,11 @@ struct HsacoImage {
 port::StatusOr<std::vector<uint8>> BundleGpuAsm(
     std::vector<HsacoImage> images, const std::string rocm_root_dir);
 
+// Links multiple relocatable GPU images (e.g. results of ptxas -c) into a
+// single image.
+port::StatusOr<std::vector<uint8>> LinkGpuAsm(
+    gpu::GpuContext* context, std::vector<CubinOrPTXImage> images);
+
 }  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_