[XLA/GPU] Make Thunk::Initialize() happen at compile-time, not run-time.

This simplifies GpuExecutable for the MLIR transition.
PiperOrigin-RevId: 311178815
Change-Id: Ib9c8b8a2f8719c0cd8b342ab07af6e8cb65d82bf

commit 9f58e6902c
parent ac7e71cc3d
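In outline: thunks used to receive the whole GpuExecutable and were initialized lazily inside ExecuteThunks(), on the hot path of the first run; after this change they receive only the compiled GpuTargetBinary and are initialized once, at the end of GpuCompiler::RunBackend(). A simplified sketch of the interface change, condensed from the hunks below (not the literal headers):

// Old: run-time initialization, invoked from ExecuteThunks() on every run.
virtual Status Initialize(const GpuExecutable& executable,
                          se::StreamExecutor* executor);

// New: compile-time initialization against only the compiled artifact
// (PTX + cubin on CUDA; empty text + HSA code object on ROCm).
virtual Status Initialize(const GpuTargetBinary& target_binary,
                          se::StreamExecutor* executor);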
@@ -17,15 +17,15 @@ load(
     "tf_cuda_library",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm",
+    "if_rocm_is_configured",
+)
 load(
     "//tensorflow/core/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
-load(
-    "@local_config_rocm//rocm:build_defs.bzl",
-    "if_rocm_is_configured",
-)
 load("//tensorflow:tensorflow.bzl", "if_nccl")
 
 package(
@@ -86,6 +86,7 @@ cc_library(
     name = "gpu_types",
     hdrs = ["gpu_types.h"],
     deps = [
+        "//tensorflow/compiler/xla:types",
         "@com_google_absl//absl/types:variant",
     ],
 )
@@ -405,6 +406,7 @@ cc_library(
     deps = [
         ":buffer_allocations",
         ":gpu_executable_run_options",
+        ":gpu_types",
        ":hlo_execution_profiler",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla/service:hlo",
@@ -104,11 +104,9 @@ GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
   return isa_version;
 }
 
-StatusOr<std::pair<std::string, std::vector<uint8>>>
-AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
-                                    llvm::Module* llvm_module,
-                                    GpuVersion gpu_version,
-                                    se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> AMDGPUCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   if (rocdl_dir_.empty()) {
     // Compute rocdl_dir_ just once and cache it in this member.
     rocdl_dir_ = GetROCDLDir(module->config());
@@ -129,7 +127,7 @@ AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
     user_post_optimization_hook_(*llvm_module);
   }
 
-  return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
+  return GpuTargetBinary{"", std::move(hsaco)};
 }
 
 }  // namespace gpu
@@ -39,7 +39,7 @@ class AMDGPUCompiler : public GpuCompiler {
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
 
@@ -50,7 +50,7 @@ void ConditionalThunk::ComputeAnnotations() {
   }
 }
 
-Status ConditionalThunk::Initialize(const GpuExecutable& executable,
+Status ConditionalThunk::Initialize(const GpuTargetBinary& target_binary,
                                     se::StreamExecutor* executor) {
   if (branch_index_is_bool_) {
     TF_RET_CHECK(branch_thunks_.size() == 2);
@@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
     TF_RET_CHECK(!branch_thunks_.empty());
   }
   for (auto& branch_thunk : branch_thunks_) {
-    TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(branch_thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }
@@ -52,7 +52,7 @@ class ConditionalThunk : public Thunk {
   ConditionalThunk& operator=(const ConditionalThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
@@ -39,9 +39,9 @@ void ForThunk::ComputeAnnotations() {
   body_thunk_sequence_->ComputeAnnotations();
 }
 
-Status ForThunk::Initialize(const GpuExecutable& executable,
+Status ForThunk::Initialize(const GpuTargetBinary& target_binary,
                             se::StreamExecutor* executor) {
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }
 
@@ -38,7 +38,7 @@ class ForThunk : public Thunk {
   ForThunk& operator=(const ForThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
@@ -565,8 +565,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
 
   GpuVersion gpu_version = GetGpuVersion(stream_exec);
 
-  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
-  TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
+  TF_ASSIGN_OR_RETURN(GpuTargetBinary backend_result,
                       CompileTargetBinary(module.get(), &llvm_module,
                                           gpu_version, stream_exec));
 
@@ -578,6 +577,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
                             thunk_schedule->ToString());
   }
 
+  std::vector<Thunk*> thunks;
+  for (Thunk* thunk : thunk_schedule->TotalOrder()) {
+    thunks.push_back(thunk);
+  }
+
   std::unique_ptr<HloProfileIndexMap> profile_index_map;
   std::unique_ptr<HloProfilePrinterData> profile_printer;
 
@@ -597,14 +601,19 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
   }
 
   auto* gpu_executable = new GpuExecutable(
-      backend_result.first, backend_result.second, gpu_version,
-      std::move(thunk_schedule), std::move(module),
-      std::move(buffer_assignment), std::move(profile_printer),
-      std::move(profile_index_map));
+      std::move(backend_result), gpu_version, std::move(thunk_schedule),
+      std::move(module), std::move(buffer_assignment),
+      std::move(profile_printer), std::move(profile_index_map));
   if (embed_ir_in_executable) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
   }
+
+  for (Thunk* thunk : thunks) {
+    TF_RETURN_IF_ERROR(
+        thunk->Initialize(gpu_executable->target_binary(), stream_exec));
+  }
+
   return std::unique_ptr<Executable>(gpu_executable);
 }
 
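Condensed view of the new RunBackend() tail, assembled from the hunks above with the names used in the diff: the thunks collected from the schedule are initialized exactly once against the executable's target binary before the executable is returned, replacing the per-execution Initialize() call that the gpu_executable.cc hunk below deletes.

TF_ASSIGN_OR_RETURN(GpuTargetBinary backend_result,
                    CompileTargetBinary(module.get(), &llvm_module,
                                        gpu_version, stream_exec));
auto* gpu_executable = new GpuExecutable(
    std::move(backend_result), gpu_version, std::move(thunk_schedule),
    std::move(module), std::move(buffer_assignment),
    std::move(profile_printer), std::move(profile_index_map));
// One-time, compile-time thunk initialization: ExecuteThunks() no longer
// pays any initialization cost on the hot path.
for (Thunk* thunk : thunks) {
  TF_RETURN_IF_ERROR(
      thunk->Initialize(gpu_executable->target_binary(), stream_exec));
}
return std::unique_ptr<Executable>(gpu_executable);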
@@ -74,10 +74,9 @@ class GpuCompiler : public LLVMCompiler {
 
   virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;
 
-  virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
-  CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
-                      GpuVersion gpu_version,
-                      se::StreamExecutor* stream_exec) = 0;
+  virtual StatusOr<GpuTargetBinary> CompileTargetBinary(
+      const HloModule* hlo_module, llvm::Module* llvm_module,
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) = 0;
 
   Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
 
@@ -52,16 +52,15 @@ using ::tensorflow::profiler::ScopedAnnotation;
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
 GpuExecutable::GpuExecutable(
-    const string& text, const std::vector<uint8>& binary,
-    GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
+    GpuTargetBinary target_binary, GpuVersion gpu_version,
+    std::unique_ptr<const ThunkSchedule> thunk_schedule,
     std::shared_ptr<HloModule> hlo_module,
     std::shared_ptr<const BufferAssignment> assignment,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
     : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                  std::move(hlo_profile_index_map)),
-      text_(text),
-      binary_(binary),
+      target_binary_(std::move(target_binary)),
       gpu_version_(gpu_version),
       thunk_schedule_(std::move(thunk_schedule)),
       assignment_(std::move(assignment)) {
@@ -176,7 +175,6 @@ Status GpuExecutable::ExecuteThunks(
     // module, we won't get any data, but that's probably an OK trade-off.
     ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
 
-    TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
     int32 stream_no =
         thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction());
     se::Stream* stream =
@@ -469,7 +467,7 @@ const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
 int64 GpuExecutable::SizeOfGeneratedCodeInBytes() {
   // Non-empty PTX but empty cubin: compilation must have failed, return
   // "unknown".
-  if (binary().empty() && !text_.empty()) {
+  if (binary().empty() && !text().empty()) {
     return -1;
   }
   return binary().size();
@@ -52,8 +52,7 @@ class GpuExecutable : public Executable {
   // We need to share ownership of hlo_module and assignment with profiler to
   // safely keep a reference to these objects during tracing period, thus they
   // are passed as shared pointers.
-  GpuExecutable(const string& text, const std::vector<uint8>& binary,
-                GpuVersion gpu_version,
+  GpuExecutable(GpuTargetBinary target_binary, GpuVersion gpu_version,
                 std::unique_ptr<const ThunkSchedule> thunk_schedule,
                 std::shared_ptr<HloModule> hlo_module,
                 std::shared_ptr<const BufferAssignment> assignment,
@@ -73,12 +72,14 @@ class GpuExecutable : public Executable {
 
   // Returns the compiled code for the computation. The compiled code is PTX in
   // Cuda and unused empty string in ROCm.
-  const string& text() const { return text_; }
+  const string& text() const { return target_binary_.text; }
 
   // Returns the binary stored in this GpuExecutable. The binary is cubin in
   // Cuda, and HSA code object in ROCm. It may be empty, in which case
   // compilation is left up to the GPU driver.
-  const std::vector<uint8>& binary() const { return binary_; }
+  const std::vector<uint8>& binary() const { return target_binary_.binary; }
+
+  const GpuTargetBinary& target_binary() const { return target_binary_; }
 
   // ExecuteAsyncOnStream will fail if the compute capability of the stream
   // doesn't match the compute capability passed to this object's constructor.
@@ -131,14 +132,7 @@ class GpuExecutable : public Executable {
   // This string should be modified only before ExecuteOnStream.
   string ir_module_string_;
 
-  // The compiled code for the computation.
-  const string text_;
-
-  // The GPU machine code for the computation, targeting GPUs at
-  // compute_capability_.
-  //
-  // May be empty, in which case we leave compilation up to the GPU driver.
-  const std::vector<uint8> binary_;
+  const GpuTargetBinary target_binary_;
 
   // The GPU version for compute compatibility check.
   GpuVersion gpu_version_;
@@ -16,7 +16,11 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
 
+#include <string>
+#include <vector>
+
 #include "absl/types/variant.h"
+#include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
 namespace gpu {
@@ -25,6 +29,19 @@ namespace gpu {
 // it comprises a pair of integers denoting major and minor version.
 // On ROCm platform, it comprises one integer for AMD GCN ISA version.
 using GpuVersion = absl::variant<std::pair<int, int>, int>;
+
+// A struct to carry around compiled results by the GPU assembler.
+struct GpuTargetBinary {
+  GpuTargetBinary(const GpuTargetBinary& other) = delete;
+  GpuTargetBinary(GpuTargetBinary&& other) = default;
+
+  // The text format of the compiled result, e.g. PTX.
+  std::string text;
+
+  // The actual compiled binary.
+  std::vector<tensorflow::uint8> binary;
+};
 
 }  // namespace gpu
 }  // namespace xla
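GpuTargetBinary replaces the ad-hoc std::pair<std::string, std::vector<uint8>> the backends used to return. Deleting the copy constructor while defaulting the move constructor makes the (possibly multi-megabyte) binary move-only, so it is handed off rather than duplicated. A standalone sketch of the same pattern, using hypothetical names (TargetBinary, Compile) rather than the XLA types:

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Move-only carrier for a compiled artifact; copying is a compile error.
struct TargetBinary {
  TargetBinary(std::string text, std::vector<std::uint8_t> binary)
      : text(std::move(text)), binary(std::move(binary)) {}
  TargetBinary(const TargetBinary&) = delete;
  TargetBinary(TargetBinary&&) = default;

  std::string text;                  // e.g. PTX assembly
  std::vector<std::uint8_t> binary;  // e.g. cubin / HSA code object
};

TargetBinary Compile() {
  std::vector<std::uint8_t> code = {0x7f, 0x45};
  return TargetBinary("ptx text", std::move(code));  // moved, never copied
}

int main() {
  TargetBinary tb = Compile();  // OK: move construction
  // TargetBinary copy = tb;    // would not compile: copy is deleted
  return tb.binary.size() == 2 ? 0 : 1;
}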
@@ -18,7 +18,6 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -42,7 +41,7 @@ KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
       kernel_name_(kernel_name),
       unroll_factor_(unroll_factor) {}
 
-Status KernelThunk::Initialize(const GpuExecutable& executable,
+Status KernelThunk::Initialize(const GpuTargetBinary& target_binary,
                                se::StreamExecutor* executor) {
   tensorflow::mutex_lock lock(mutex_);
 
@@ -55,8 +54,10 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
   if (kernel_cache_.end() == it) {
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<se::KernelBase> kernel,
-        CreateKernel(kernel_name_, args_.size(), executable.text(),
-                     executable.binary(), executor));
+        CreateKernel(kernel_name_, args_.size(), target_binary.text,
+                     target_binary.binary, executor));
+    CHECK(!target_binary.binary.empty());
     CHECK(kernel);
 
     kernel_cache_.emplace(executor, std::move(kernel));
   }
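KernelThunk::Initialize() above memoizes the loaded kernel per StreamExecutor under a mutex, so the repeated calls the Thunk contract allows (e.g. one per device) stay cheap and idempotent. A self-contained sketch of that memoization shape, using stand-in types (Executor, Kernel) rather than StreamExecutor's real API:

#include <map>
#include <memory>
#include <mutex>

struct Executor {};  // stand-in for se::StreamExecutor
struct Kernel {};    // stand-in for se::KernelBase

class KernelCache {
 public:
  // Loads the kernel for `executor` on first use; later calls hit the cache.
  Kernel* GetOrLoad(Executor* executor) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(executor);
    if (it == cache_.end()) {
      it = cache_.emplace(executor, std::make_unique<Kernel>()).first;
    }
    return it->second.get();
  }

 private:
  std::mutex mutex_;  // guards cache_ across concurrent initializations
  std::map<Executor*, std::unique_ptr<Kernel>> cache_;
};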
@@ -35,8 +35,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-class GpuExecutable;
-
 // This class stores everything that StreamExecutor needs for launching a
 // kernel. It implements the ExecuteOnStream interface for GpuExecutable to
 // invoke the corresponding kernel.
@@ -58,7 +56,7 @@ class KernelThunk : public Thunk {
   int unroll_factor() const { return unroll_factor_; }
   void SetLaunchDimensions(const LaunchDimensions& launch_dims);
 
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
@@ -295,11 +295,9 @@ GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
   return std::make_pair(cc_major, cc_minor);
 }
 
-StatusOr<std::pair<std::string, std::vector<uint8>>>
-NVPTXCompiler::CompileTargetBinary(const HloModule* module,
-                                   llvm::Module* llvm_module,
-                                   GpuVersion gpu_version,
-                                   se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> NVPTXCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   std::pair<int, int> compute_capability =
       absl::get<std::pair<int, int>>(gpu_version);
 
@@ -340,8 +338,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
       stream_exec, ptx, compute_capability.first, compute_capability.second,
       module->config());
 
-  return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
-                                                    std::move(cubin));
+  return GpuTargetBinary{std::move(ptx), std::move(cubin)};
 }
 
 std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
@@ -48,7 +48,7 @@ class NVPTXCompiler : public GpuCompiler {
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
 
@@ -34,10 +34,10 @@ void SequentialThunk::ComputeAnnotations() {
   }
 }
 
-Status SequentialThunk::Initialize(const GpuExecutable& executable,
+Status SequentialThunk::Initialize(const GpuTargetBinary& target_binary,
                                    se::StreamExecutor* executor) {
   for (auto& thunk : thunks_) {
-    TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }
@@ -40,7 +40,7 @@ class SequentialThunk : public Thunk {
   const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; }
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -30,8 +31,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-class GpuExecutable;
-
 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
 //
@@ -97,7 +96,7 @@ class Thunk {
   // This may be called multiple times. Its main purpose is to give us a chance
   // to do initialization outside of ExecuteOnStream() so that the
   // time spent initializing doesn't count towards our execution profile.
-  virtual Status Initialize(const GpuExecutable& /*executable*/,
+  virtual Status Initialize(const GpuTargetBinary& /*target_binary*/,
                             se::StreamExecutor* /*executor*/) {
     return Status::OK();
   }
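The base-class comment above keeps the old contract: Initialize() may run more than once (e.g. once per StreamExecutor), and it now runs at compile time, so overrides should be idempotent and leave ExecuteOnStream() free of setup cost. A hypothetical override illustrating that contract (MyThunk is illustrative only; constructor and registration omitted):

class MyThunk : public Thunk {
 public:
  Status Initialize(const GpuTargetBinary& target_binary,
                    se::StreamExecutor* executor) override {
    // Idempotent: tolerate repeat calls, e.g. one per executor.
    if (!initialized_) {
      TF_RET_CHECK(!target_binary.binary.empty());
      // ... one-time setup from target_binary for `executor` ...
      initialized_ = true;
    }
    return Status::OK();
  }

  Status ExecuteOnStream(const ExecuteParams& params) override {
    // Hot path: no initialization cost left here.
    return Status::OK();
  }

 private:
  bool initialized_ = false;
};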
@@ -45,11 +45,11 @@ void WhileThunk::ComputeAnnotations() {
   body_thunk_sequence_->ComputeAnnotations();
 }
 
-Status WhileThunk::Initialize(const GpuExecutable& executable,
+Status WhileThunk::Initialize(const GpuTargetBinary& target_binary,
                               se::StreamExecutor* executor) {
   TF_RETURN_IF_ERROR(
-      condition_thunk_sequence_->Initialize(executable, executor));
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+      condition_thunk_sequence_->Initialize(target_binary, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }
 
@@ -47,7 +47,7 @@ class WhileThunk : public Thunk {
   WhileThunk& operator=(const WhileThunk&) = delete;
 
   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
@@ -549,10 +549,11 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
   }
 
   // TODO(b/137624192): Add profiling support.
 
   return {absl::make_unique<GpuExecutable>(
-      ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
-      emission_context.releaseHloModule(), std::move(buffer_assignment),
-      nullptr, nullptr)};
+      xla::gpu::GpuTargetBinary{ptx, cubin}, GetGpuVersion(stream_exec),
+      std::move(thunk_schedule), emission_context.releaseHloModule(),
+      std::move(buffer_assignment), nullptr, nullptr)};
 }
 
 StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
@@ -55,16 +55,15 @@ class GpuDummyCompiler : public GpuCompiler {
 
   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { return 0; }
 
-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override {
    if (user_post_optimization_hook_) {
       user_post_optimization_hook_(*llvm_module);
     }
 
     std::vector<uint8> compiled_results;
-    return std::pair<std::string, std::vector<uint8>>(
-        "", std::move(compiled_results));
+    return GpuTargetBinary{"", std::move(compiled_results)};
   }
 };
 }  // namespace gpu
|
Loading…
x
Reference in New Issue
Block a user