From 9f58e6902cea5d26e68635d1c766c2dc6125577a Mon Sep 17 00:00:00 2001
From: Tim Shen
Date: Tue, 12 May 2020 12:26:39 -0700
Subject: [PATCH] [XLA/GPU] Make Thunk::Initialize() happen at compile-time,
 not run-time.

This simplifies GpuExecutable for MLIR transition.

PiperOrigin-RevId: 311178815
Change-Id: Ib9c8b8a2f8719c0cd8b342ab07af6e8cb65d82bf
---
 tensorflow/compiler/xla/service/gpu/BUILD     | 12 ++++++-----
 .../xla/service/gpu/amdgpu_compiler.cc        | 10 ++++-----
 .../xla/service/gpu/amdgpu_compiler.h         |  2 +-
 .../xla/service/gpu/conditional_thunk.cc      |  4 ++--
 .../xla/service/gpu/conditional_thunk.h       |  2 +-
 .../compiler/xla/service/gpu/for_thunk.cc     |  4 ++--
 .../compiler/xla/service/gpu/for_thunk.h      |  2 +-
 .../compiler/xla/service/gpu/gpu_compiler.cc  | 21 +++++++++++++------
 .../compiler/xla/service/gpu/gpu_compiler.h   |  7 +++----
 .../xla/service/gpu/gpu_executable.cc         | 10 ++++-----
 .../compiler/xla/service/gpu/gpu_executable.h | 18 ++++++----------
 .../compiler/xla/service/gpu/gpu_types.h      | 17 +++++++++++++++
 .../compiler/xla/service/gpu/kernel_thunk.cc  |  9 ++++----
 .../compiler/xla/service/gpu/kernel_thunk.h   |  4 +---
 .../xla/service/gpu/nvptx_compiler.cc         | 11 ++++------
 .../compiler/xla/service/gpu/nvptx_compiler.h |  2 +-
 .../xla/service/gpu/sequential_thunk.cc       |  4 ++--
 .../xla/service/gpu/sequential_thunk.h        |  2 +-
 tensorflow/compiler/xla/service/gpu/thunk.h   |  5 ++---
 .../compiler/xla/service/gpu/while_thunk.cc   |  6 +++---
 .../compiler/xla/service/gpu/while_thunk.h    |  2 +-
 .../service/mlir_gpu/mlir_compiler_impl.cc    |  7 ++++---
 .../compiler/xla/tests/llvm_compiler_test.cc  |  7 +++----
 23 files changed, 90 insertions(+), 78 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 61bc41283e1..8f8263a85f9 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -17,15 +17,15 @@ load(
     "tf_cuda_library",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm",
+    "if_rocm_is_configured",
+)
 load(
     "//tensorflow/core/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
-load(
-    "@local_config_rocm//rocm:build_defs.bzl",
-    "if_rocm_is_configured",
-)
 load("//tensorflow:tensorflow.bzl", "if_nccl")

 package(
@@ -86,6 +86,7 @@ cc_library(
     name = "gpu_types",
     hdrs = ["gpu_types.h"],
     deps = [
+        "//tensorflow/compiler/xla:types",
        "@com_google_absl//absl/types:variant",
     ],
 )
@@ -405,6 +406,7 @@ cc_library(
     deps = [
         ":buffer_allocations",
         ":gpu_executable_run_options",
+        ":gpu_types",
         ":hlo_execution_profiler",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla/service:hlo",
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
index 974db02b1b3..485aff0c4d8 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -104,11 +104,9 @@ GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
   return isa_version;
 }

-StatusOr<std::pair<std::string, std::vector<uint8>>>
-AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
-                                    llvm::Module* llvm_module,
-                                    GpuVersion gpu_version,
-                                    se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> AMDGPUCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   if (rocdl_dir_.empty()) {
     // Compute rocdl_dir_ just once and cache it in this member.
     rocdl_dir_ = GetROCDLDir(module->config());
@@ -129,7 +127,7 @@ AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
     user_post_optimization_hook_(*llvm_module);
   }

-  return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
+  return GpuTargetBinary{"", std::move(hsaco)};
 }

 }  // namespace gpu
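
The hunks above replace the ad-hoc std::pair<std::string, std::vector<uint8>> result of CompileTargetBinary with the named GpuTargetBinary struct. A minimal, self-contained sketch of that difference follows (simplified stand-in types, not the actual XLA headers; uint8_t stands in for TF's uint8):

// Sketch only: a named struct documents the two compilation artifacts better
// than pair's .first/.second.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct GpuTargetBinary {
  std::string text;             // e.g. PTX on CUDA, empty string on ROCm
  std::vector<uint8_t> binary;  // e.g. cubin / HSA code object
};

// Old style: the meaning of .first/.second lives only in comments.
std::pair<std::string, std::vector<uint8_t>> CompileOld() {
  return {"ptx-text", {0xDE, 0xAD}};
}

// New style: the field names document themselves.
GpuTargetBinary CompileNew() {
  return GpuTargetBinary{"ptx-text", {0xDE, 0xAD}};
}

int main() {
  auto old_result = CompileOld();
  GpuTargetBinary new_result = CompileNew();
  std::cout << old_result.first << " vs " << new_result.text << "\n";
  std::cout << old_result.second.size() << " vs " << new_result.binary.size()
            << " bytes\n";
}
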
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
index acc5e021e3d..9033585763b 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
@@ -39,7 +39,7 @@ class AMDGPUCompiler : public GpuCompiler {

   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;

-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;

diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index e31f45942b1..5e7d89c7aee 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -50,7 +50,7 @@ void ConditionalThunk::ComputeAnnotations() {
   }
 }

-Status ConditionalThunk::Initialize(const GpuExecutable& executable,
+Status ConditionalThunk::Initialize(const GpuTargetBinary& target_binary,
                                     se::StreamExecutor* executor) {
   if (branch_index_is_bool_) {
     TF_RET_CHECK(branch_thunks_.size() == 2);
@@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
     TF_RET_CHECK(!branch_thunks_.empty());
   }
   for (auto& branch_thunk : branch_thunks_) {
-    TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(branch_thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
index 404e2131eff..ba69e1a38ec 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -52,7 +52,7 @@ class ConditionalThunk : public Thunk {
   ConditionalThunk& operator=(const ConditionalThunk&) = delete;

   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index 0a97f668b38..aacc9deb739 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -39,9 +39,9 @@ void ForThunk::ComputeAnnotations() {
   body_thunk_sequence_->ComputeAnnotations();
 }

-Status ForThunk::Initialize(const GpuExecutable& executable,
+Status ForThunk::Initialize(const GpuTargetBinary& target_binary,
                             se::StreamExecutor* executor) {
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }

diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h
index 57402f70627..57657b6825f 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -38,7 +38,7 @@ class ForThunk : public Thunk {
   ForThunk& operator=(const ForThunk&) = delete;

   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5f6dfd7d3a5..533ff52a90d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -565,8 +565,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(

   GpuVersion gpu_version = GetGpuVersion(stream_exec);

-  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
-  TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
+  TF_ASSIGN_OR_RETURN(GpuTargetBinary backend_result,
                       CompileTargetBinary(module.get(), &llvm_module,
                                           gpu_version, stream_exec));

@@ -578,6 +577,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
             thunk_schedule->ToString());
   }

+  std::vector<Thunk*> thunks;
+  for (Thunk* thunk : thunk_schedule->TotalOrder()) {
+    thunks.push_back(thunk);
+  }
+
   std::unique_ptr<HloProfileIndexMap> profile_index_map;
   std::unique_ptr<HloProfilePrinterData> profile_printer;

@@ -597,14 +601,19 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
   }

   auto* gpu_executable = new GpuExecutable(
-      backend_result.first, backend_result.second, gpu_version,
-      std::move(thunk_schedule), std::move(module),
-      std::move(buffer_assignment), std::move(profile_printer),
-      std::move(profile_index_map));
+      std::move(backend_result), gpu_version, std::move(thunk_schedule),
+      std::move(module), std::move(buffer_assignment),
+      std::move(profile_printer), std::move(profile_index_map));
   if (embed_ir_in_executable) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
   }
+
+  for (Thunk* thunk : thunks) {
+    TF_RETURN_IF_ERROR(
+        thunk->Initialize(gpu_executable->target_binary(), stream_exec));
+  }
+
   return std::unique_ptr<Executable>(gpu_executable);
 }

diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index b52af5392d1..deb5d785777 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -74,10 +74,9 @@ class GpuCompiler : public LLVMCompiler {

   virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;

-  virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
-  CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
-                      GpuVersion gpu_version,
-                      se::StreamExecutor* stream_exec) = 0;
+  virtual StatusOr<GpuTargetBinary> CompileTargetBinary(
+      const HloModule* hlo_module, llvm::Module* llvm_module,
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) = 0;

   Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);

diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 2df6b50d361..ebd3630635b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -52,16 +52,15 @@ using ::tensorflow::profiler::ScopedAnnotation;
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
 GpuExecutable::GpuExecutable(
-    const string& text, const std::vector<uint8>& binary,
-    GpuVersion gpu_version, std::unique_ptr<ThunkSchedule> thunk_schedule,
+    GpuTargetBinary target_binary, GpuVersion gpu_version,
+    std::unique_ptr<ThunkSchedule> thunk_schedule,
     std::shared_ptr<HloModule> hlo_module,
     std::shared_ptr<const BufferAssignment> assignment,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
     : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                  std::move(hlo_profile_index_map)),
-      text_(text),
-      binary_(binary),
+      target_binary_(std::move(target_binary)),
       gpu_version_(gpu_version),
       thunk_schedule_(std::move(thunk_schedule)),
       assignment_(std::move(assignment)) {
@@ -176,7 +175,6 @@ Status GpuExecutable::ExecuteThunks(
     // module, we won't get any data, but that's probably an OK trade-off.
     ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });

-    TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
     int32 stream_no =
         thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction());
     se::Stream* stream =
@@ -469,7 +467,7 @@ const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
 int64 GpuExecutable::SizeOfGeneratedCodeInBytes() {
   // Non-empty PTX but empty cubin: compilation must have failed, return
   // "unknown".
-  if (binary().empty() && !text_.empty()) {
+  if (binary().empty() && !text().empty()) {
     return -1;
   }
   return binary().size();
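
The gpu_compiler.cc and gpu_executable.cc hunks above move thunk initialization out of the execution path: RunBackend() now walks the thunk schedule once, right after the executable is built, and initializes every thunk against the freshly compiled target binary, while ExecuteThunks() no longer calls Initialize() at all. A rough, self-contained sketch of that control flow (simplified stand-ins, not the real XLA classes; Status handling and the se::StreamExecutor argument are omitted):

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct GpuTargetBinary {
  std::string text;
  std::vector<uint8_t> binary;
};

struct Thunk {
  virtual ~Thunk() = default;
  virtual void Initialize(const GpuTargetBinary&) {}
  virtual void ExecuteOnStream() {}
};

class GpuExecutable {
 public:
  GpuExecutable(GpuTargetBinary target_binary,
                std::vector<std::unique_ptr<Thunk>> thunks)
      : target_binary_(std::move(target_binary)), thunks_(std::move(thunks)) {}

  const GpuTargetBinary& target_binary() const { return target_binary_; }
  std::vector<std::unique_ptr<Thunk>>& thunks() { return thunks_; }

  // Run time: the per-thunk Initialize() call that used to live here is gone.
  void ExecuteThunks() {
    for (auto& thunk : thunks_) thunk->ExecuteOnStream();
  }

 private:
  GpuTargetBinary target_binary_;
  std::vector<std::unique_ptr<Thunk>> thunks_;
};

// Compile time, mirroring the new RunBackend(): initialize every thunk against
// the compiled binary before the executable is handed out.
std::unique_ptr<GpuExecutable> RunBackend(
    GpuTargetBinary backend_result,
    std::vector<std::unique_ptr<Thunk>> thunks) {
  auto executable = std::make_unique<GpuExecutable>(std::move(backend_result),
                                                    std::move(thunks));
  for (auto& thunk : executable->thunks()) {
    thunk->Initialize(executable->target_binary());
  }
  return executable;
}

int main() {
  std::vector<std::unique_ptr<Thunk>> thunks;
  thunks.push_back(std::make_unique<Thunk>());
  auto executable = RunBackend(GpuTargetBinary{"ptx", {}}, std::move(thunks));
  executable->ExecuteThunks();
}
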
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 045a36c099b..29441c60b04 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -52,8 +52,7 @@ class GpuExecutable : public Executable {
   // We need to share ownership of hlo_module and assignment with profiler to
   // safely keep a reference to these objects during tracing period, thus they
   // are passed as shared pointers.
-  GpuExecutable(const string& text, const std::vector<uint8>& binary,
-                GpuVersion gpu_version,
+  GpuExecutable(GpuTargetBinary target_binary, GpuVersion gpu_version,
                 std::unique_ptr<ThunkSchedule> thunk_schedule,
                 std::shared_ptr<HloModule> hlo_module,
                 std::shared_ptr<const BufferAssignment> assignment,
@@ -73,12 +72,14 @@ class GpuExecutable : public Executable {

   // Returns the compiled code for the computation. The compiled code is PTX in
   // Cuda and unused empty string in ROCm.
-  const string& text() const { return text_; }
+  const string& text() const { return target_binary_.text; }

   // Returns the binary stored in this GpuExecutable. The binary is cubin in
   // Cuda, and HSA code object in ROCm. It may be empty, in which case
   // compilation is left up to the GPU driver.
-  const std::vector<uint8>& binary() const { return binary_; }
+  const std::vector<uint8>& binary() const { return target_binary_.binary; }
+
+  const GpuTargetBinary& target_binary() const { return target_binary_; }

   // ExecuteAsyncOnStream will fail if the compute capability of the stream
   // doesn't match the compute capability passed to this object's constructor.
@@ -131,14 +132,7 @@ class GpuExecutable : public Executable {
   // This string should be modified only before ExecuteOnStream.
   string ir_module_string_;

-  // The compiled code for the computation.
-  const string text_;
-
-  // The GPU machine code for the computation, targeting GPUs at
-  // compute_capability_.
-  //
-  // May be empty, in which case we leave compilation up to the GPU driver.
-  const std::vector<uint8> binary_;
+  const GpuTargetBinary target_binary_;

   // The GPU version for compute compatibility check.
   GpuVersion gpu_version_;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_types.h b/tensorflow/compiler/xla/service/gpu/gpu_types.h
index 1c51040fb82..5c8b8093d65 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_types.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_types.h
@@ -16,7 +16,11 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_

+#include <string>
+#include <vector>
+
 #include "absl/types/variant.h"
+#include "tensorflow/compiler/xla/types.h"

 namespace xla {
 namespace gpu {
@@ -25,6 +29,19 @@ namespace gpu {
 // it comprises a pair of integers denoting major and minor version.
 // On ROCm platform, it comprises one integer for AMD GCN ISA version.
 using GpuVersion = absl::variant<std::pair<int, int>, int>;
+
+// A struct to carry around compiled results by the GPU assembler.
+struct GpuTargetBinary {
+  GpuTargetBinary(const GpuTargetBinary& other) = delete;
+  GpuTargetBinary(GpuTargetBinary&& other) = default;
+
+  // The text format of the compiled result, e.g. PTX.
+  std::string text;
+
+  // The actual compiled binary.
+  std::vector<uint8> binary;
+};
+
 }  // namespace gpu
 }  // namespace xla

diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index d976b5d8d4d..0b5010ea66b 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -42,7 +41,7 @@ KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
       kernel_name_(kernel_name),
       unroll_factor_(unroll_factor) {}

-Status KernelThunk::Initialize(const GpuExecutable& executable,
+Status KernelThunk::Initialize(const GpuTargetBinary& target_binary,
                                se::StreamExecutor* executor) {
   tensorflow::mutex_lock lock(mutex_);

@@ -55,8 +54,10 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
   if (kernel_cache_.end() == it) {
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<se::KernelBase> kernel,
-        CreateKernel(kernel_name_, args_.size(), executable.text(),
-                     executable.binary(), executor));
+        CreateKernel(kernel_name_, args_.size(), target_binary.text,
+                     target_binary.binary, executor));
+    CHECK(!target_binary.binary.empty());
+    CHECK(kernel);

     kernel_cache_.emplace(executor, std::move(kernel));
   }
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index 88351881f3a..97a1d08a57e 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -35,8 +35,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {

-class GpuExecutable;
-
 // This class stores everything that StreamExecutor needs for launching a
 // kernel. It implements the ExecuteOnStream interface for GpuExecutable to
 // invoke the corresponding kernel.
@@ -58,7 +56,7 @@ class KernelThunk : public Thunk {
   int unroll_factor() const { return unroll_factor_; }
   void SetLaunchDimensions(const LaunchDimensions& launch_dims);

-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

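
The GpuTargetBinary struct added in gpu_types.h above is movable but not copyable, and the return sites in this patch brace-initialize it. A standalone sketch of those semantics is below (compile as C++17, where the deleted/defaulted constructors still permit aggregate initialization; the uint8 alias is a stand-in for TF's typedef):

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

using uint8 = std::uint8_t;  // stand-in for the TF typedef

struct GpuTargetBinary {
  GpuTargetBinary(const GpuTargetBinary& other) = delete;
  GpuTargetBinary(GpuTargetBinary&& other) = default;

  std::string text;
  std::vector<uint8> binary;
};

int main() {
  // Aggregate initialization, as in `return GpuTargetBinary{"", std::move(hsaco)};`.
  GpuTargetBinary a{"ptx", {1, 2, 3}};
  // GpuTargetBinary b = a;          // would not compile: copy is deleted
  GpuTargetBinary c = std::move(a);  // OK: ownership transfers by move
  return c.binary.size() == 3 ? 0 : 1;
}
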
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 0196267d904..cf6fe9292e5 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -295,11 +295,9 @@ GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
   return std::make_pair(cc_major, cc_minor);
 }

-StatusOr<std::pair<std::string, std::vector<uint8>>>
-NVPTXCompiler::CompileTargetBinary(const HloModule* module,
-                                   llvm::Module* llvm_module,
-                                   GpuVersion gpu_version,
-                                   se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> NVPTXCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   std::pair<int, int> compute_capability =
       absl::get<std::pair<int, int>>(gpu_version);

@@ -340,8 +338,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
       stream_exec, ptx, compute_capability.first, compute_capability.second,
       module->config());

-  return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
-                                                    std::move(cubin));
+  return GpuTargetBinary{std::move(ptx), std::move(cubin)};
 }

 std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index e69be947522..ec550b5b2ff 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -48,7 +48,7 @@ class NVPTXCompiler : public GpuCompiler {

   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;

-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;

diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
index 025ca60ef0c..bd260336c28 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -34,10 +34,10 @@ void SequentialThunk::ComputeAnnotations() {
   }
 }

-Status SequentialThunk::Initialize(const GpuExecutable& executable,
+Status SequentialThunk::Initialize(const GpuTargetBinary& target_binary,
                                    se::StreamExecutor* executor) {
   for (auto& thunk : thunks_) {
-    TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
index 3abb82c0b66..b5475664733 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
@@ -40,7 +40,7 @@ class SequentialThunk : public Thunk {
   const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; }

   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index e9be41b74de..7aff9ca47b7 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -30,8 +31,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {

-class GpuExecutable;
-
 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
 //
@@ -97,7 +96,7 @@ class Thunk {
   // This may be called multiple times. Its main purpose is to give us a chance
   // to do initialization outside of ExecuteOnStream() so that the
   // time spent initializing doesn't count towards our execution profile.
-  virtual Status Initialize(const GpuExecutable& /*executable*/,
+  virtual Status Initialize(const GpuTargetBinary& /*target_binary*/,
                             se::StreamExecutor* /*executor*/) {
     return Status::OK();
   }
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index 4134cd39832..2650508093e 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -45,11 +45,11 @@ void WhileThunk::ComputeAnnotations() {
   body_thunk_sequence_->ComputeAnnotations();
 }

-Status WhileThunk::Initialize(const GpuExecutable& executable,
+Status WhileThunk::Initialize(const GpuTargetBinary& target_binary,
                               se::StreamExecutor* executor) {
   TF_RETURN_IF_ERROR(
-      condition_thunk_sequence_->Initialize(executable, executor));
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+      condition_thunk_sequence_->Initialize(target_binary, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }

diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h
index 31db01b72ba..77ee0104a1f 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h
@@ -47,7 +47,7 @@ class WhileThunk : public Thunk {
   WhileThunk& operator=(const WhileThunk&) = delete;

   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
index 35ac3b2bf63..667cdef8f6c 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
@@ -549,10 +549,11 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
   }

   // TODO(b/137624192): Add profiling support.
+
   return {absl::make_unique<GpuExecutable>(
-      ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
-      emission_context.releaseHloModule(), std::move(buffer_assignment),
-      nullptr, nullptr)};
+      xla::gpu::GpuTargetBinary{ptx, cubin}, GetGpuVersion(stream_exec),
+      std::move(thunk_schedule), emission_context.releaseHloModule(),
+      std::move(buffer_assignment), nullptr, nullptr)};
 }

 StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 1947f517bd9..16ed02296b7 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -55,16 +55,15 @@ class GpuDummyCompiler : public GpuCompiler {

   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { return 0; }

-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override {
     if (user_post_optimization_hook_) {
       user_post_optimization_hook_(*llvm_module);
     }

     std::vector<uint8> compiled_results;
-    return std::pair<std::string, std::vector<uint8>>(
-        "", std::move(compiled_results));
+    return GpuTargetBinary{"", std::move(compiled_results)};
   }
 };
 }  // namespace gpu
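
After this patch, the contract for a GpuCompiler backend is small: implement CompileTargetBinary() so that it returns a GpuTargetBinary, and rely on the compiler driver to initialize the thunks against that binary before the executable is returned. A self-contained sketch modeled on the GpuDummyCompiler change above (simplified stand-in types; the real method also receives the HloModule, llvm::Module, GpuVersion, and se::StreamExecutor and returns StatusOr):

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

struct GpuTargetBinary {
  std::string text;
  std::vector<uint8_t> binary;
};

struct DummyCompiler {
  // Mirrors the dummy compiler in llvm_compiler_test.cc: empty text, no binary.
  GpuTargetBinary CompileTargetBinary() {
    std::vector<uint8_t> compiled_results;
    return GpuTargetBinary{"", std::move(compiled_results)};
  }
};

int main() {
  DummyCompiler compiler;
  GpuTargetBinary result = compiler.CompileTargetBinary();
  return result.text.empty() && result.binary.empty() ? 0 : 1;
}
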