[XLA/GPU] Make Thunk::Initialize() happen at compile-time, not run-time. This simplifies GpuExecutable for the MLIR transition.

PiperOrigin-RevId: 311178815
Change-Id: Ib9c8b8a2f8719c0cd8b342ab07af6e8cb65d82bf
Author: Tim Shen, 2020-05-12 12:26:39 -07:00; committed by TensorFlower Gardener
parent ac7e71cc3d
commit 9f58e6902c
23 changed files with 90 additions and 78 deletions
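
The gist of the change: Thunk::Initialize() now receives the compiled GpuTargetBinary directly instead of a back-reference to GpuExecutable, and GpuCompiler::RunBackend() calls it once per thunk right after the executable is built, rather than GpuExecutable::ExecuteThunks() doing so lazily on every launch. Below is a minimal, self-contained sketch of that pattern; the types are toy stand-ins, not the real XLA API.

// Toy stand-ins mirroring the shapes in this commit; nothing here is the
// real XLA API.
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct TargetBinary {                     // cf. xla::gpu::GpuTargetBinary
  TargetBinary(std::string t, std::vector<uint8_t> b)
      : text(std::move(t)), binary(std::move(b)) {}
  TargetBinary(const TargetBinary&) = delete;   // move-only, like the real one
  TargetBinary(TargetBinary&&) = default;
  std::string text;                       // e.g. PTX
  std::vector<uint8_t> binary;            // e.g. cubin
};

struct Thunk {                            // cf. xla::gpu::Thunk
  virtual ~Thunk() = default;
  virtual void Initialize(const TargetBinary&) {}  // compile-time hook
  virtual void ExecuteOnStream() {}                // run-time hot path
};

struct Executable {                       // cf. xla::gpu::GpuExecutable
  explicit Executable(TargetBinary b) : target_binary(std::move(b)) {}
  TargetBinary target_binary;
  std::vector<std::unique_ptr<Thunk>> thunks;
  void Execute() {                        // hot path: no initialization left
    for (auto& t : thunks) t->ExecuteOnStream();
  }
};

// Compile time: construct the executable, then initialize each thunk once
// against the binary it owns; this is what RunBackend() now does.
std::unique_ptr<Executable> Compile(TargetBinary b) {
  auto exec = std::make_unique<Executable>(std::move(b));
  for (auto& t : exec->thunks) t->Initialize(exec->target_binary);
  return exec;
}

int main() {
  auto exec = Compile(TargetBinary("ptx-text", {0xCA, 0xFE}));
  exec->Execute();
}

Making the binary a move-only value owned by the executable also removes the circular Thunk -> GpuExecutable dependency that the forward declarations deleted below existed to support.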

tensorflow/compiler/xla/service/gpu/BUILD
@@ -17,15 +17,15 @@ load(
     "tf_cuda_library",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm",
+    "if_rocm_is_configured",
+)
 load(
     "//tensorflow/core/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
-load(
-    "@local_config_rocm//rocm:build_defs.bzl",
-    "if_rocm_is_configured",
-)
 load("//tensorflow:tensorflow.bzl", "if_nccl")
 
 package(
@@ -86,6 +86,7 @@ cc_library(
     name = "gpu_types",
     hdrs = ["gpu_types.h"],
     deps = [
+        "//tensorflow/compiler/xla:types",
        "@com_google_absl//absl/types:variant",
     ],
 )
@@ -405,6 +406,7 @@ cc_library(
     deps = [
         ":buffer_allocations",
         ":gpu_executable_run_options",
+        ":gpu_types",
         ":hlo_execution_profiler",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla/service:hlo",

tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -104,11 +104,9 @@ GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
   return isa_version;
 }
 
-StatusOr<std::pair<std::string, std::vector<uint8>>>
-AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
-                                    llvm::Module* llvm_module,
-                                    GpuVersion gpu_version,
-                                    se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> AMDGPUCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   if (rocdl_dir_.empty()) {
     // Compute rocdl_dir_ just once and cache it in this member.
     rocdl_dir_ = GetROCDLDir(module->config());
@@ -129,7 +127,7 @@ AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
     user_post_optimization_hook_(*llvm_module);
   }
 
-  return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
+  return GpuTargetBinary{"", std::move(hsaco)};
 }
 
 }  // namespace gpu

tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
@@ -39,7 +39,7 @@ class AMDGPUCompiler : public GpuCompiler {
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
 

tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -50,7 +50,7 @@ void ConditionalThunk::ComputeAnnotations() {
   }
 }
 
-Status ConditionalThunk::Initialize(const GpuExecutable& executable,
+Status ConditionalThunk::Initialize(const GpuTargetBinary& target_binary,
                                     se::StreamExecutor* executor) {
   if (branch_index_is_bool_) {
     TF_RET_CHECK(branch_thunks_.size() == 2);
@@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
     TF_RET_CHECK(!branch_thunks_.empty());
   }
   for (auto& branch_thunk : branch_thunks_) {
-    TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(branch_thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }

tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -52,7 +52,7 @@ class ConditionalThunk : public Thunk {
   ConditionalThunk& operator=(const ConditionalThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 

tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -39,9 +39,9 @@ void ForThunk::ComputeAnnotations() {
   body_thunk_sequence_->ComputeAnnotations();
 }
 
-Status ForThunk::Initialize(const GpuExecutable& executable,
+Status ForThunk::Initialize(const GpuTargetBinary& target_binary,
                             se::StreamExecutor* executor) {
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }

tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -38,7 +38,7 @@ class ForThunk : public Thunk {
   ForThunk& operator=(const ForThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 

tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -565,8 +565,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
 
   GpuVersion gpu_version = GetGpuVersion(stream_exec);
 
-  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
-  TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
+  TF_ASSIGN_OR_RETURN(GpuTargetBinary backend_result,
                       CompileTargetBinary(module.get(), &llvm_module,
                                           gpu_version, stream_exec));
 
@@ -578,6 +577,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
                             thunk_schedule->ToString());
   }
 
+  std::vector<Thunk*> thunks;
+  for (Thunk* thunk : thunk_schedule->TotalOrder()) {
+    thunks.push_back(thunk);
+  }
+
   std::unique_ptr<HloProfileIndexMap> profile_index_map;
   std::unique_ptr<HloProfilePrinterData> profile_printer;
 
@@ -597,14 +601,19 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
   }
 
   auto* gpu_executable = new GpuExecutable(
-      backend_result.first, backend_result.second, gpu_version,
-      std::move(thunk_schedule), std::move(module),
-      std::move(buffer_assignment), std::move(profile_printer),
-      std::move(profile_index_map));
+      std::move(backend_result), gpu_version, std::move(thunk_schedule),
+      std::move(module), std::move(buffer_assignment),
+      std::move(profile_printer), std::move(profile_index_map));
   if (embed_ir_in_executable) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
   }
+
+  for (Thunk* thunk : thunks) {
+    TF_RETURN_IF_ERROR(
+        thunk->Initialize(gpu_executable->target_binary(), stream_exec));
+  }
+
   return std::unique_ptr<Executable>(gpu_executable);
 }
 
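
A note on ordering in the hunk above: the raw Thunk pointers are snapshotted from thunk_schedule->TotalOrder() before std::move(thunk_schedule) hands the schedule to the GpuExecutable, after which the compiler can no longer consult it. The resulting compile-time sequence, condensed into comments:

// 1. thunks = snapshot of thunk_schedule->TotalOrder()  (schedule still owned)
// 2. gpu_executable = new GpuExecutable(std::move(backend_result), ...,
//                                       std::move(thunk_schedule), ...)
// 3. for each thunk: thunk->Initialize(gpu_executable->target_binary(),
//                                      stream_exec)
// Step 1 must precede step 2 because the schedule is moved away in step 2;
// step 3 must follow step 2 because the GpuTargetBinary reaches its final
// home inside the executable only once the executable is constructed.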

tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -74,10 +74,9 @@ class GpuCompiler : public LLVMCompiler {
 
   virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;
 
-  virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
-  CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
-                      GpuVersion gpu_version,
-                      se::StreamExecutor* stream_exec) = 0;
+  virtual StatusOr<GpuTargetBinary> CompileTargetBinary(
+      const HloModule* hlo_module, llvm::Module* llvm_module,
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) = 0;
 
   Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
 

tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -52,16 +52,15 @@ using ::tensorflow::profiler::ScopedAnnotation;
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
 GpuExecutable::GpuExecutable(
-    const string& text, const std::vector<uint8>& binary,
-    GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
+    GpuTargetBinary target_binary, GpuVersion gpu_version,
+    std::unique_ptr<const ThunkSchedule> thunk_schedule,
     std::shared_ptr<HloModule> hlo_module,
     std::shared_ptr<const BufferAssignment> assignment,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
     : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                  std::move(hlo_profile_index_map)),
-      text_(text),
-      binary_(binary),
+      target_binary_(std::move(target_binary)),
       gpu_version_(gpu_version),
       thunk_schedule_(std::move(thunk_schedule)),
       assignment_(std::move(assignment)) {
@@ -176,7 +175,6 @@ Status GpuExecutable::ExecuteThunks(
     // module, we won't get any data, but that's probably an OK trade-off.
     ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
 
-    TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
     int32 stream_no =
         thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction());
     se::Stream* stream =
@@ -469,7 +467,7 @@ const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
 int64 GpuExecutable::SizeOfGeneratedCodeInBytes() {
   // Non-empty PTX but empty cubin: compilation must have failed, return
   // "unknown".
-  if (binary().empty() && !text_.empty()) {
+  if (binary().empty() && !text().empty()) {
     return -1;
   }
   return binary().size();

tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -52,8 +52,7 @@ class GpuExecutable : public Executable {
   // We need to share ownership of hlo_module and assignment with profiler to
   // safely keep a reference to these objects during tracing period, thus they
   // are passed as shared pointers.
-  GpuExecutable(const string& text, const std::vector<uint8>& binary,
-                GpuVersion gpu_version,
+  GpuExecutable(GpuTargetBinary target_binary, GpuVersion gpu_version,
                 std::unique_ptr<const ThunkSchedule> thunk_schedule,
                 std::shared_ptr<HloModule> hlo_module,
                 std::shared_ptr<const BufferAssignment> assignment,
@@ -73,12 +72,14 @@ class GpuExecutable : public Executable {
 
   // Returns the compiled code for the computation. The compiled code is PTX in
   // Cuda and unused empty string in ROCm.
-  const string& text() const { return text_; }
+  const string& text() const { return target_binary_.text; }
 
   // Returns the binary stored in this GpuExecutable. The binary is cubin in
   // Cuda, and HSA code object in ROCm. It may be empty, in which case
   // compilation is left up to the GPU driver.
-  const std::vector<uint8>& binary() const { return binary_; }
+  const std::vector<uint8>& binary() const { return target_binary_.binary; }
+
+  const GpuTargetBinary& target_binary() const { return target_binary_; }
 
   // ExecuteAsyncOnStream will fail if the compute capability of the stream
   // doesn't match the compute capability passed to this object's constructor.
@@ -131,14 +132,7 @@ class GpuExecutable : public Executable {
   // This string should be modified only before ExecuteOnStream.
   string ir_module_string_;
 
-  // The compiled code for the computation.
-  const string text_;
-
-  // The GPU machine code for the computation, targeting GPUs at
-  // compute_capability_.
-  //
-  // May be empty, in which case we leave compilation up to the GPU driver.
-  const std::vector<uint8> binary_;
+  const GpuTargetBinary target_binary_;
 
   // The GPU version for compute compatibility check.
   GpuVersion gpu_version_;

tensorflow/compiler/xla/service/gpu/gpu_types.h
@@ -16,7 +16,11 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
 
+#include <string>
+#include <vector>
+
 #include "absl/types/variant.h"
+#include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
 namespace gpu {
@@ -25,6 +29,19 @@ namespace gpu {
 // it comprises a pair of integers denoting major and minor version.
 // On ROCm platform, it comprises one integer for AMD GCN ISA version.
 using GpuVersion = absl::variant<std::pair<int, int>, int>;
+
+// A struct to carry around compiled results by the GPU assembler.
+struct GpuTargetBinary {
+  GpuTargetBinary(const GpuTargetBinary& other) = delete;
+  GpuTargetBinary(GpuTargetBinary&& other) = default;
+
+  // The text format of the compiled result, e.g. PTX.
+  std::string text;
+
+  // The actual compiled binary.
+  std::vector<tensorflow::uint8> binary;
+};
+
 }  // namespace gpu
 }  // namespace xla
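
Because the copy constructor is deleted and the move constructor defaulted, a GpuTargetBinary is move-only: call sites either brace-initialize it in place, as the CompileTargetBinary implementations in this commit do, or std::move it, as RunBackend does. A hypothetical snippet, not part of the commit, showing the allowed and rejected forms:

// Hypothetical illustration only.
GpuTargetBinary MakeBinary(std::string ptx,
                           std::vector<tensorflow::uint8> cubin) {
  return GpuTargetBinary{std::move(ptx), std::move(cubin)};  // in-place init
}

GpuTargetBinary a = MakeBinary("ptx", {});  // OK: move-constructed
// GpuTargetBinary b = a;                   // error: copy constructor deleted
GpuTargetBinary c = std::move(a);           // OK: move constructor defaulted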

tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -42,7 +41,7 @@ KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
       kernel_name_(kernel_name),
       unroll_factor_(unroll_factor) {}
 
-Status KernelThunk::Initialize(const GpuExecutable& executable,
+Status KernelThunk::Initialize(const GpuTargetBinary& target_binary,
                                se::StreamExecutor* executor) {
   tensorflow::mutex_lock lock(mutex_);
 
@@ -55,8 +54,10 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
   if (kernel_cache_.end() == it) {
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<se::KernelBase> kernel,
-        CreateKernel(kernel_name_, args_.size(), executable.text(),
-                     executable.binary(), executor));
+        CreateKernel(kernel_name_, args_.size(), target_binary.text,
+                     target_binary.binary, executor));
+    CHECK(!target_binary.binary.empty());
+    CHECK(kernel);
 
     kernel_cache_.emplace(executor, std::move(kernel));
   }

tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -35,8 +35,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-class GpuExecutable;
-
 // This class stores everything that StreamExecutor needs for launching a
 // kernel. It implements the ExecuteOnStream interface for GpuExecutable to
 // invoke the corresponding kernel.
@@ -58,7 +56,7 @@ class KernelThunk : public Thunk {
   int unroll_factor() const { return unroll_factor_; }
   void SetLaunchDimensions(const LaunchDimensions& launch_dims);
 
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 

tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -295,11 +295,9 @@ GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
   return std::make_pair(cc_major, cc_minor);
 }
 
-StatusOr<std::pair<std::string, std::vector<uint8>>>
-NVPTXCompiler::CompileTargetBinary(const HloModule* module,
-                                   llvm::Module* llvm_module,
-                                   GpuVersion gpu_version,
-                                   se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> NVPTXCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   std::pair<int, int> compute_capability =
       absl::get<std::pair<int, int>>(gpu_version);
 
@@ -340,8 +338,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
       stream_exec, ptx, compute_capability.first, compute_capability.second,
       module->config());
 
-  return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
-                                                    std::move(cubin));
+  return GpuTargetBinary{std::move(ptx), std::move(cubin)};
 }
 
 std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(

tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -48,7 +48,7 @@ class NVPTXCompiler : public GpuCompiler {
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
      const HloModule* hlo_module, llvm::Module* llvm_module,
      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
 

tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -34,10 +34,10 @@ void SequentialThunk::ComputeAnnotations() {
   }
 }
 
-Status SequentialThunk::Initialize(const GpuExecutable& executable,
+Status SequentialThunk::Initialize(const GpuTargetBinary& target_binary,
                                    se::StreamExecutor* executor) {
   for (auto& thunk : thunks_) {
-    TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }

tensorflow/compiler/xla/service/gpu/sequential_thunk.h
@@ -40,7 +40,7 @@ class SequentialThunk : public Thunk {
   const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; }
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 

tensorflow/compiler/xla/service/gpu/thunk.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -30,8 +31,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-class GpuExecutable;
-
 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
 //
@@ -97,7 +96,7 @@ class Thunk {
   // This may be called multiple times. Its main purpose is to give us a chance
   // to do initialization outside of ExecuteOnStream() so that the
   // time spent initializing doesn't count towards our execution profile.
-  virtual Status Initialize(const GpuExecutable& /*executable*/,
+  virtual Status Initialize(const GpuTargetBinary& /*target_binary*/,
                             se::StreamExecutor* /*executor*/) {
     return Status::OK();
   }
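
Since the base-class Initialize() above is a no-op, thunks with no compile-time state need no changes. A hypothetical out-of-tree subclass (MyThunk is illustrative and not part of this commit) would adapt to the new signature like so:

// Hypothetical subclass, for illustration only.
class MyThunk : public Thunk {
 public:
  using Thunk::Thunk;

  Status Initialize(const GpuTargetBinary& target_binary,
                    se::StreamExecutor* executor) override {
    // The PTX text and device binary now arrive directly, rather than being
    // fetched through a GpuExecutable back-reference.
    ptx_ = target_binary.text;
    return Status::OK();
  }

  Status ExecuteOnStream(const ExecuteParams& params) override {
    return Status::OK();  // launch work using state captured in Initialize()
  }

 private:
  std::string ptx_;
};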

tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -45,11 +45,11 @@ void WhileThunk::ComputeAnnotations() {
   body_thunk_sequence_->ComputeAnnotations();
 }
 
-Status WhileThunk::Initialize(const GpuExecutable& executable,
+Status WhileThunk::Initialize(const GpuTargetBinary& target_binary,
                               se::StreamExecutor* executor) {
   TF_RETURN_IF_ERROR(
-      condition_thunk_sequence_->Initialize(executable, executor));
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+      condition_thunk_sequence_->Initialize(target_binary, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }

tensorflow/compiler/xla/service/gpu/while_thunk.h
@@ -47,7 +47,7 @@ class WhileThunk : public Thunk {
   WhileThunk& operator=(const WhileThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 

tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
@@ -549,10 +549,11 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
   }
 
   // TODO(b/137624192): Add profiling support.
   return {absl::make_unique<GpuExecutable>(
-      ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
-      emission_context.releaseHloModule(), std::move(buffer_assignment),
-      nullptr, nullptr)};
+      xla::gpu::GpuTargetBinary{ptx, cubin}, GetGpuVersion(stream_exec),
+      std::move(thunk_schedule), emission_context.releaseHloModule(),
+      std::move(buffer_assignment), nullptr, nullptr)};
 }
 
 StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(

(test helper defining GpuDummyCompiler; file path not shown in this view)
@@ -55,16 +55,15 @@ class GpuDummyCompiler : public GpuCompiler {
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { return 0; }
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override {
     if (user_post_optimization_hook_) {
       user_post_optimization_hook_(*llvm_module);
    }
    std::vector<uint8> compiled_results;
-    return std::pair<std::string, std::vector<uint8>>(
-        "", std::move(compiled_results));
+    return GpuTargetBinary{"", std::move(compiled_results)};
  }
 };
 
 }  // namespace gpu