[XLA/GPU] Make Thunk::Initialize() happen at compile-time, not run-time. This simplifies GpuExecutable for the MLIR transition.
PiperOrigin-RevId: 311178815 Change-Id: Ib9c8b8a2f8719c0cd8b342ab07af6e8cb65d82bf
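At a glance, the change moves Thunk::Initialize() off the execution hot path: GpuCompiler::RunBackend now initializes every thunk once, right after constructing the GpuExecutable, instead of GpuExecutable::ExecuteThunks re-running initialization on every execution. The following is a minimal sketch of that control-flow difference; Binary, Executor, and the bare Thunk here are stand-ins, not the real XLA classes.

#include <memory>
#include <vector>

struct Binary {};    // stand-in for xla::gpu::GpuTargetBinary
struct Executor {};  // stand-in for se::StreamExecutor

struct Thunk {
  virtual ~Thunk() = default;
  virtual void Initialize(const Binary&, Executor*) {}
  virtual void ExecuteOnStream() {}
};

// Before this change: ExecuteThunks re-ran Initialize at run-time,
// on the hot path of every execution.
void ExecuteThunksOld(const std::vector<Thunk*>& thunks, const Binary& bin,
                      Executor* exec) {
  for (Thunk* t : thunks) {
    t->Initialize(bin, exec);  // run-time work on every execution
    t->ExecuteOnStream();
  }
}

// After this change: the backend initializes each thunk once at the end
// of compilation, and execution only executes.
void InitializeAtCompileTime(const std::vector<Thunk*>& thunks,
                             const Binary& bin, Executor* exec) {
  for (Thunk* t : thunks) t->Initialize(bin, exec);
}

void ExecuteThunksNew(const std::vector<Thunk*>& thunks) {
  for (Thunk* t : thunks) t->ExecuteOnStream();  // no init on the hot path
}

int main() {
  std::vector<std::unique_ptr<Thunk>> owned;
  owned.push_back(std::make_unique<Thunk>());
  std::vector<Thunk*> thunks = {owned[0].get()};
  Binary bin;
  Executor exec;
  InitializeAtCompileTime(thunks, bin, &exec);
  ExecuteThunksNew(thunks);
  return 0;
}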
@@ -17,15 +17,15 @@ load(
     "tf_cuda_library",
 )
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load(
+    "@local_config_rocm//rocm:build_defs.bzl",
+    "if_rocm",
+    "if_rocm_is_configured",
+)
 load(
     "//tensorflow/core/platform/default:cuda_build_defs.bzl",
     "if_cuda_is_configured",
 )
-load(
-    "@local_config_rocm//rocm:build_defs.bzl",
-    "if_rocm_is_configured",
-)
 load("//tensorflow:tensorflow.bzl", "if_nccl")

 package(
@@ -86,6 +86,7 @@ cc_library(
     name = "gpu_types",
     hdrs = ["gpu_types.h"],
     deps = [
+        "//tensorflow/compiler/xla:types",
         "@com_google_absl//absl/types:variant",
     ],
 )
@@ -405,6 +406,7 @@ cc_library(
     deps = [
         ":buffer_allocations",
         ":gpu_executable_run_options",
+        ":gpu_types",
         ":hlo_execution_profiler",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla/service:hlo",
@@ -104,11 +104,9 @@ GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
   return isa_version;
 }

-StatusOr<std::pair<std::string, std::vector<uint8>>>
-AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
-                                    llvm::Module* llvm_module,
-                                    GpuVersion gpu_version,
-                                    se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> AMDGPUCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   if (rocdl_dir_.empty()) {
     // Compute rocdl_dir_ just once and cache it in this member.
     rocdl_dir_ = GetROCDLDir(module->config());
@@ -129,7 +127,7 @@ AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
     user_post_optimization_hook_(*llvm_module);
   }

-  return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
+  return GpuTargetBinary{"", std::move(hsaco)};
 }

 }  // namespace gpu
@@ -39,7 +39,7 @@ class AMDGPUCompiler : public GpuCompiler {

   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;

-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;

@@ -50,7 +50,7 @@ void ConditionalThunk::ComputeAnnotations() {
   }
 }

-Status ConditionalThunk::Initialize(const GpuExecutable& executable,
+Status ConditionalThunk::Initialize(const GpuTargetBinary& target_binary,
                                     se::StreamExecutor* executor) {
   if (branch_index_is_bool_) {
     TF_RET_CHECK(branch_thunks_.size() == 2);
@@ -58,7 +58,7 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
     TF_RET_CHECK(!branch_thunks_.empty());
   }
   for (auto& branch_thunk : branch_thunks_) {
-    TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(branch_thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }
@@ -52,7 +52,7 @@ class ConditionalThunk : public Thunk {
   ConditionalThunk& operator=(const ConditionalThunk&) = delete;

   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

@@ -39,9 +39,9 @@ void ForThunk::ComputeAnnotations() {
   body_thunk_sequence_->ComputeAnnotations();
 }

-Status ForThunk::Initialize(const GpuExecutable& executable,
+Status ForThunk::Initialize(const GpuTargetBinary& target_binary,
                             se::StreamExecutor* executor) {
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }

@@ -38,7 +38,7 @@ class ForThunk : public Thunk {
   ForThunk& operator=(const ForThunk&) = delete;

   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

@@ -565,8 +565,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(

   GpuVersion gpu_version = GetGpuVersion(stream_exec);

-  using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
-  TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
+  TF_ASSIGN_OR_RETURN(GpuTargetBinary backend_result,
                       CompileTargetBinary(module.get(), &llvm_module,
                                           gpu_version, stream_exec));

@@ -578,6 +577,11 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
                 thunk_schedule->ToString());
   }

+  std::vector<Thunk*> thunks;
+  for (Thunk* thunk : thunk_schedule->TotalOrder()) {
+    thunks.push_back(thunk);
+  }
+
   std::unique_ptr<HloProfileIndexMap> profile_index_map;
   std::unique_ptr<HloProfilePrinterData> profile_printer;

@@ -597,14 +601,19 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
   }

   auto* gpu_executable = new GpuExecutable(
-      backend_result.first, backend_result.second, gpu_version,
-      std::move(thunk_schedule), std::move(module),
-      std::move(buffer_assignment), std::move(profile_printer),
-      std::move(profile_index_map));
+      std::move(backend_result), gpu_version, std::move(thunk_schedule),
+      std::move(module), std::move(buffer_assignment),
+      std::move(profile_printer), std::move(profile_index_map));
   if (embed_ir_in_executable) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
   }

+  for (Thunk* thunk : thunks) {
+    TF_RETURN_IF_ERROR(
+        thunk->Initialize(gpu_executable->target_binary(), stream_exec));
+  }
+
   return std::unique_ptr<Executable>(gpu_executable);
 }

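A note on ordering in the hunk above: the raw Thunk* list is captured from thunk_schedule->TotalOrder() before std::move(thunk_schedule) hands the schedule to the GpuExecutable constructor, since the local schedule pointer is null afterwards. Below is a self-contained sketch of that constraint, using stand-in types rather than the real ThunkSchedule/GpuExecutable.

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Stand-ins for the real XLA types, just to illustrate the ordering.
struct Thunk { bool initialized = false; };

struct ThunkSchedule {
  std::vector<std::unique_ptr<Thunk>> thunks;
  std::vector<Thunk*> TotalOrder() const {
    std::vector<Thunk*> order;
    for (const auto& t : thunks) order.push_back(t.get());
    return order;
  }
};

struct Executable {
  explicit Executable(std::unique_ptr<const ThunkSchedule> s)
      : schedule(std::move(s)) {}
  std::unique_ptr<const ThunkSchedule> schedule;
};

int main() {
  auto schedule = std::make_unique<ThunkSchedule>();
  schedule->thunks.push_back(std::make_unique<Thunk>());

  // Capture raw pointers *before* the schedule is moved into the
  // executable; afterwards `schedule` is empty and TotalOrder() is gone.
  std::vector<Thunk*> thunks = schedule->TotalOrder();

  Executable exe(std::move(schedule));
  for (Thunk* t : thunks) t->initialized = true;  // compile-time init
  assert(thunks[0]->initialized);
  return 0;
}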
@@ -74,10 +74,9 @@ class GpuCompiler : public LLVMCompiler {

   virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;

-  virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
-  CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
-                      GpuVersion gpu_version,
-                      se::StreamExecutor* stream_exec) = 0;
+  virtual StatusOr<GpuTargetBinary> CompileTargetBinary(
+      const HloModule* hlo_module, llvm::Module* llvm_module,
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) = 0;

   Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);

@@ -52,16 +52,15 @@ using ::tensorflow::profiler::ScopedAnnotation;
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
 GpuExecutable::GpuExecutable(
-    const string& text, const std::vector<uint8>& binary,
-    GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
+    GpuTargetBinary target_binary, GpuVersion gpu_version,
+    std::unique_ptr<const ThunkSchedule> thunk_schedule,
     std::shared_ptr<HloModule> hlo_module,
     std::shared_ptr<const BufferAssignment> assignment,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
     : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                  std::move(hlo_profile_index_map)),
-      text_(text),
-      binary_(binary),
+      target_binary_(std::move(target_binary)),
       gpu_version_(gpu_version),
       thunk_schedule_(std::move(thunk_schedule)),
       assignment_(std::move(assignment)) {
@@ -176,7 +175,6 @@ Status GpuExecutable::ExecuteThunks(
     // module, we won't get any data, but that's probably an OK trade-off.
     ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });

-    TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
     int32 stream_no =
         thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction());
     se::Stream* stream =
@@ -469,7 +467,7 @@ const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
 int64 GpuExecutable::SizeOfGeneratedCodeInBytes() {
   // Non-empty PTX but empty cubin: compilation must have failed, return
   // "unknown".
-  if (binary().empty() && !text_.empty()) {
+  if (binary().empty() && !text().empty()) {
     return -1;
   }
   return binary().size();
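The size heuristic above reads: non-empty text (PTX) with an empty binary means device compilation did not produce machine code, so the size is reported as unknown (-1). A minimal stand-alone restatement of that logic:

#include <cstdint>
#include <string>
#include <vector>

// Sketch of SizeOfGeneratedCodeInBytes: text present but binary empty
// means compilation failed or was deferred, so the size is unknown.
std::int64_t SizeOfGeneratedCode(const std::string& text,
                                 const std::vector<std::uint8_t>& binary) {
  if (binary.empty() && !text.empty()) return -1;  // "unknown"
  return static_cast<std::int64_t>(binary.size());
}

int main() { return SizeOfGeneratedCode("fake-ptx", {}) == -1 ? 0 : 1; }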
@@ -52,8 +52,7 @@ class GpuExecutable : public Executable {
   // We need to share ownership of hlo_module and assignment with profiler to
   // safely keep a reference to these objects during tracing period, thus they
   // are passed as shared pointers.
-  GpuExecutable(const string& text, const std::vector<uint8>& binary,
-                GpuVersion gpu_version,
+  GpuExecutable(GpuTargetBinary target_binary, GpuVersion gpu_version,
                 std::unique_ptr<const ThunkSchedule> thunk_schedule,
                 std::shared_ptr<HloModule> hlo_module,
                 std::shared_ptr<const BufferAssignment> assignment,
@@ -73,12 +72,14 @@ class GpuExecutable : public Executable {

   // Returns the compiled code for the computation. The compiled code is PTX in
   // Cuda and unused empty string in ROCm.
-  const string& text() const { return text_; }
+  const string& text() const { return target_binary_.text; }

   // Returns the binary stored in this GpuExecutable. The binary is cubin in
   // Cuda, and HSA code object in ROCm. It may be empty, in which case
   // compilation is left up to the GPU driver.
-  const std::vector<uint8>& binary() const { return binary_; }
+  const std::vector<uint8>& binary() const { return target_binary_.binary; }
+
+  const GpuTargetBinary& target_binary() const { return target_binary_; }

   // ExecuteAsyncOnStream will fail if the compute capability of the stream
   // doesn't match the compute capability passed to this object's constructor.
@@ -131,14 +132,7 @@ class GpuExecutable : public Executable {
   // This string should be modified only before ExecuteOnStream.
   string ir_module_string_;

-  // The compiled code for the computation.
-  const string text_;
-
-  // The GPU machine code for the computation, targeting GPUs at
-  // compute_capability_.
-  //
-  // May be empty, in which case we leave compilation up to the GPU driver.
-  const std::vector<uint8> binary_;
+  const GpuTargetBinary target_binary_;

   // The GPU version for compute compatibility check.
   GpuVersion gpu_version_;
@@ -16,7 +16,11 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_

+#include <string>
+#include <vector>
+
 #include "absl/types/variant.h"
+#include "tensorflow/compiler/xla/types.h"

 namespace xla {
 namespace gpu {
@@ -25,6 +29,19 @@ namespace gpu {
 // it comprises a pair of integers denoting major and minor version.
 // On ROCm platform, it comprises one integer for AMD GCN ISA version.
 using GpuVersion = absl::variant<std::pair<int, int>, int>;

+// A struct to carry around compiled results by the GPU assembler.
+struct GpuTargetBinary {
+  GpuTargetBinary(const GpuTargetBinary& other) = delete;
+  GpuTargetBinary(GpuTargetBinary&& other) = default;
+
+  // The text format of the compiled result, e.g. PTX.
+  std::string text;
+
+  // The actual compiled binary.
+  std::vector<tensorflow::uint8> binary;
+};
+
 }  // namespace gpu
 }  // namespace xla
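The struct above is deliberately move-only: the copy constructor is deleted and the move constructor defaulted, so compiled PTX/cubin buffers are handed from the compiler to GpuExecutable without duplication. Here is a compilable sketch of the same shape; the two-argument constructor is an addition of this sketch so it builds as a non-aggregate, and is not part of the struct in the diff.

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Stand-in mirroring the move-only shape of GpuTargetBinary above.
struct TargetBinary {
  TargetBinary(std::string text, std::vector<std::uint8_t> binary)
      : text(std::move(text)), binary(std::move(binary)) {}
  TargetBinary(const TargetBinary&) = delete;  // copies forbidden
  TargetBinary(TargetBinary&&) = default;      // moves allowed

  std::string text;                  // e.g. PTX for CUDA, empty for ROCm
  std::vector<std::uint8_t> binary;  // e.g. cubin or HSA code object
};

TargetBinary Compile() {
  // Compiled artifacts are handed over by move, never duplicated.
  return TargetBinary{"fake-ptx", {0x01, 0x02}};
}

int main() {
  TargetBinary bin = Compile();         // OK: move (or elided)
  TargetBinary taken = std::move(bin);  // OK: explicit move
  // TargetBinary copy = taken;         // would not compile: copy deleted
  return taken.binary.size() == 2 ? 0 : 1;
}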
@@ -18,7 +18,6 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -42,7 +41,7 @@ KernelThunk::KernelThunk(absl::Span<const BufferAllocation* const> args,
       kernel_name_(kernel_name),
       unroll_factor_(unroll_factor) {}

-Status KernelThunk::Initialize(const GpuExecutable& executable,
+Status KernelThunk::Initialize(const GpuTargetBinary& target_binary,
                                se::StreamExecutor* executor) {
   tensorflow::mutex_lock lock(mutex_);

@@ -55,8 +54,10 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
   if (kernel_cache_.end() == it) {
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<se::KernelBase> kernel,
-        CreateKernel(kernel_name_, args_.size(), executable.text(),
-                     executable.binary(), executor));
+        CreateKernel(kernel_name_, args_.size(), target_binary.text,
+                     target_binary.binary, executor));
+    CHECK(!target_binary.binary.empty());
+    CHECK(kernel);

     kernel_cache_.emplace(executor, std::move(kernel));
   }
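KernelThunk::Initialize above caches one loaded kernel per StreamExecutor under a mutex, so repeated initialization is cheap and thread-safe. Below is a self-contained sketch of that caching pattern with stand-in StreamExecutor/KernelBase types; the real CreateKernel call is replaced by a plain allocation here.

#include <map>
#include <memory>
#include <mutex>

// Stand-ins for se::StreamExecutor and a loaded kernel handle.
struct StreamExecutor {};
struct KernelBase {};

// Sketch of the per-executor kernel cache: load once under a lock,
// reuse on every later call.
class KernelCache {
 public:
  KernelBase* GetOrCreate(StreamExecutor* executor) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(executor);
    if (it == cache_.end()) {
      // The real code builds this via CreateKernel(name, arity, text,
      // binary, executor); a bare allocation stands in for it here.
      it = cache_.emplace(executor, std::make_unique<KernelBase>()).first;
    }
    return it->second.get();
  }

 private:
  std::mutex mutex_;
  std::map<StreamExecutor*, std::unique_ptr<KernelBase>> cache_;
};

int main() {
  KernelCache cache;
  StreamExecutor exec;
  KernelBase* k1 = cache.GetOrCreate(&exec);
  KernelBase* k2 = cache.GetOrCreate(&exec);
  return k1 == k2 ? 0 : 1;  // second lookup hits the cache
}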
@@ -35,8 +35,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {

-class GpuExecutable;
-
 // This class stores everything that StreamExecutor needs for launching a
 // kernel. It implements the ExecuteOnStream interface for GpuExecutable to
 // invoke the corresponding kernel.
@@ -58,7 +56,7 @@ class KernelThunk : public Thunk {
   int unroll_factor() const { return unroll_factor_; }
   void SetLaunchDimensions(const LaunchDimensions& launch_dims);

-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

@@ -295,11 +295,9 @@ GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
   return std::make_pair(cc_major, cc_minor);
 }

-StatusOr<std::pair<std::string, std::vector<uint8>>>
-NVPTXCompiler::CompileTargetBinary(const HloModule* module,
-                                   llvm::Module* llvm_module,
-                                   GpuVersion gpu_version,
-                                   se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> NVPTXCompiler::CompileTargetBinary(
+    const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+    se::StreamExecutor* stream_exec) {
   std::pair<int, int> compute_capability =
       absl::get<std::pair<int, int>>(gpu_version);

@@ -340,8 +338,7 @@ NVPTXCompiler::CompileTargetBinary(const HloModule* module,
       stream_exec, ptx, compute_capability.first, compute_capability.second,
       module->config());

-  return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
-                                                    std::move(cubin));
+  return GpuTargetBinary{std::move(ptx), std::move(cubin)};
 }

 std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
@@ -48,7 +48,7 @@ class NVPTXCompiler : public GpuCompiler {

   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;

-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
       GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;

@@ -34,10 +34,10 @@ void SequentialThunk::ComputeAnnotations() {
   }
 }

-Status SequentialThunk::Initialize(const GpuExecutable& executable,
+Status SequentialThunk::Initialize(const GpuTargetBinary& target_binary,
                                    se::StreamExecutor* executor) {
   for (auto& thunk : thunks_) {
-    TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor));
+    TF_RETURN_IF_ERROR(thunk->Initialize(target_binary, executor));
   }
   return Status::OK();
 }
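SequentialThunk shows the composite pattern used throughout this change: container thunks (Sequential, For, While, Conditional) simply forward Initialize to their children, which is what lets the single compile-time loop in RunBackend reach every nested KernelThunk. A sketch with stand-in types, not the real Thunk hierarchy:

#include <memory>
#include <vector>

// Stand-ins; the real signatures take const GpuTargetBinary& and
// se::StreamExecutor*.
struct TargetBinary {};
struct Executor {};

struct Thunk {
  virtual ~Thunk() = default;
  virtual void Initialize(const TargetBinary&, Executor*) {}  // default no-op
};

// A container thunk forwards Initialize recursively, so one call from the
// compiler initializes the whole nested tree.
struct CompositeThunk : Thunk {
  std::vector<std::unique_ptr<Thunk>> children;
  void Initialize(const TargetBinary& bin, Executor* exec) override {
    for (auto& child : children) child->Initialize(bin, exec);
  }
};

int main() {
  CompositeThunk seq;
  seq.children.push_back(std::make_unique<Thunk>());
  TargetBinary bin;
  Executor exec;
  seq.Initialize(bin, &exec);  // reaches every child
  return 0;
}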
@@ -40,7 +40,7 @@ class SequentialThunk : public Thunk {
   const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; }

   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -30,8 +31,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {

-class GpuExecutable;
-
 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
 //
@@ -97,7 +96,7 @@ class Thunk {
   // This may be called multiple times. Its main purpose is to give us a chance
   // to do initialization outside of ExecuteOnStream() so that the
   // time spent initializing doesn't count towards our execution profile.
-  virtual Status Initialize(const GpuExecutable& /*executable*/,
+  virtual Status Initialize(const GpuTargetBinary& /*target_binary*/,
                             se::StreamExecutor* /*executor*/) {
     return Status::OK();
   }
@@ -45,11 +45,11 @@ void WhileThunk::ComputeAnnotations() {
   body_thunk_sequence_->ComputeAnnotations();
 }

-Status WhileThunk::Initialize(const GpuExecutable& executable,
+Status WhileThunk::Initialize(const GpuTargetBinary& target_binary,
                               se::StreamExecutor* executor) {
   TF_RETURN_IF_ERROR(
-      condition_thunk_sequence_->Initialize(executable, executor));
-  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+      condition_thunk_sequence_->Initialize(target_binary, executor));
+  TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
   return Status::OK();
 }

@@ -47,7 +47,7 @@ class WhileThunk : public Thunk {
   WhileThunk& operator=(const WhileThunk&) = delete;

   void ComputeAnnotations() override;
-  Status Initialize(const GpuExecutable& executable,
+  Status Initialize(const GpuTargetBinary& target_binary,
                     se::StreamExecutor* executor) override;
   Status ExecuteOnStream(const ExecuteParams& params) override;

@@ -549,10 +549,11 @@ StatusOr<std::unique_ptr<Executable>> MlirCompilerImpl::RunBackend(
   }

   // TODO(b/137624192): Add profiling support.

   return {absl::make_unique<GpuExecutable>(
-      ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
-      emission_context.releaseHloModule(), std::move(buffer_assignment),
-      nullptr, nullptr)};
+      xla::gpu::GpuTargetBinary{ptx, cubin}, GetGpuVersion(stream_exec),
+      std::move(thunk_schedule), emission_context.releaseHloModule(),
+      std::move(buffer_assignment), nullptr, nullptr)};
 }

 StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
@@ -55,16 +55,15 @@ class GpuDummyCompiler : public GpuCompiler {

   GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { return 0; }

-  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+  StatusOr<GpuTargetBinary> CompileTargetBinary(
       const HloModule* hlo_module, llvm::Module* llvm_module,
-      GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override {
     if (user_post_optimization_hook_) {
       user_post_optimization_hook_(*llvm_module);
     }

     std::vector<uint8> compiled_results;
-    return std::pair<std::string, std::vector<uint8>>(
-        "", std::move(compiled_results));
+    return GpuTargetBinary{"", std::move(compiled_results)};
   }
 };
 }  // namespace gpu