Merge pull request #30238 from ROCmSoftwarePlatform:google-upstream-pr-merge_to_gpu_executable

PiperOrigin-RevId: 258865825

commit fe8a48d81c

tensorflow/compiler/xla/service/gpu/BUILD

@@ -7,7 +7,7 @@ load(
     "if_static",
     "tf_cuda_tests_tags",
 )
-load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library", "tf_copts")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts", "tf_cuda_library")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 
 package(
@@ -47,6 +47,14 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "gpu_types",
+    hdrs = ["gpu_types.h"],
+    deps = [
+        "@com_google_absl//absl/types:variant",
+    ],
+)
+
 cc_library(
     name = "partition_assignment",
     srcs = [
@@ -453,6 +461,7 @@ cc_library(
         ":cudnn_conv_runner",
         ":cusolver_context",
        ":gpu_debug_info_manager",
+       ":gpu_types",
        ":hlo_execution_profiler",
        ":infeed_manager",
        ":ir_emission_utils",

tensorflow/compiler/xla/service/gpu/gpu_executable.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
@@ -36,8 +37,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/stream_executor/platform.h"
 
 namespace xla {
 namespace gpu {
@@ -50,18 +51,17 @@ using tensorflow::tracing::ScopedAnnotation;
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
 GpuExecutable::GpuExecutable(
-    const string& ptx, const std::vector<uint8>& cubin,
-    std::pair<int, int> compute_capability,
-    std::unique_ptr<const ThunkSchedule> thunk_schedule,
+    const string& text, const std::vector<uint8>& binary,
+    GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
     std::shared_ptr<HloModule> hlo_module,
     std::shared_ptr<const BufferAssignment> assignment,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
     : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                  std::move(hlo_profile_index_map)),
-      ptx_(ptx),
-      cubin_(cubin),
-      compute_capability_(compute_capability),
+      text_(text),
+      binary_(binary),
+      gpu_version_(gpu_version),
       thunk_schedule_(std::move(thunk_schedule)),
       assignment_(std::move(assignment)) {
   CHECK(has_module() && assignment_);
@@ -89,10 +89,45 @@ void GpuExecutable::ComputeThunkAnnotations() {
   }
 }
 
+Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
+    const ServiceExecutableRunOptions* run_options) {
+  se::Stream* main_stream = run_options->stream();
+
+  stream_executor::PlatformKind platform_kind =
+      main_stream->parent()->platform_kind();
+  if (platform_kind == stream_executor::PlatformKind::kROCm) {
+    int stream_isa_version;
+    main_stream->parent()->GetDeviceDescription().rocm_amdgpu_isa_version(
+        &stream_isa_version);
+    GpuVersion amd_isa_version = stream_isa_version;
+    TF_RET_CHECK(amd_isa_version == gpu_version_)
+        << "AMDGPU GCN ISA version mismatch; expected "
+        << absl::get<int>(gpu_version_) << ", but was " << stream_isa_version;
+  } else if (platform_kind == stream_executor::PlatformKind::kCuda) {
+    std::pair<int, int> stream_compute_compatibility;
+    main_stream->parent()->GetDeviceDescription().cuda_compute_capability(
+        &stream_compute_compatibility.first,
+        &stream_compute_compatibility.second);
+    GpuVersion nvidia_compute_compatibility = stream_compute_compatibility;
+    TF_RET_CHECK(nvidia_compute_compatibility == gpu_version_)
+        << "Compute capability mismatch; expected {"
+        << absl::get<std::pair<int, int>>(gpu_version_).first << ", "
+        << absl::get<std::pair<int, int>>(gpu_version_).second << "}, but was {"
+        << stream_compute_compatibility.first << ", "
+        << stream_compute_compatibility.second << "}";
+  } else {
+    return InternalError("Unknown platform: %d", platform_kind);
+  }
+
+  return Status::OK();
+}
+
 Status GpuExecutable::ExecuteThunks(
     const ServiceExecutableRunOptions* run_options,
     const BufferAllocations& buffer_allocations, bool block_host_until_done,
     HloExecutionProfile* hlo_execution_profile) {
+  TF_RETURN_IF_ERROR(
+      CheckCompatibilityWithServiceExecutableRunOptions(run_options));
   GpuDebugInfoManager::Get()->OnModuleStart(module().name());
   auto cleanup = MakeCleanup(
       [&]() { GpuDebugInfoManager::Get()->OnModuleStop(module().name()); });
@@ -100,16 +135,6 @@ Status GpuExecutable::ExecuteThunks(
   se::Stream* main_stream = run_options->stream();
   se::StreamExecutor* executor = main_stream->parent();
 
-  std::pair<int, int> stream_compute_compatibility;
-  executor->GetDeviceDescription().cuda_compute_capability(
-      &stream_compute_compatibility.first,
-      &stream_compute_compatibility.second);
-  TF_RET_CHECK(stream_compute_compatibility == compute_capability_)
-      << "Compute capability mismatch; expected {" << compute_capability_.first
-      << ", " << compute_capability_.second << "}, but was {"
-      << stream_compute_compatibility.first << ", "
-      << stream_compute_compatibility.second << "}";
-
   bool do_profile = hlo_execution_profile != nullptr;
   if (do_profile) {
     LOG(WARNING) << "PROFILING: profiling is enabled";
@@ -210,10 +235,10 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
   }
 
   se::MultiModuleLoaderSpec module_spec;
-  if (!cubin().empty()) {
-    module_spec.AddCudaCubinInMemory(cubin());
+  if (!binary().empty()) {
+    module_spec.AddCudaCubinInMemory(binary());
   }
-  module_spec.AddCudaPtxInMemory(ptx().c_str());
+  module_spec.AddCudaPtxInMemory(text().c_str());
 
   absl::flat_hash_map<int64, se::DeviceMemoryBase> globals;
   se::ModuleHandle module_handle;
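
A note on the comparison pattern in CheckCompatibilityWithServiceExecutableRunOptions: absl::variant's operator== is true only when both operands hold the same alternative and the held values compare equal, so wrapping the queried stream version in a GpuVersion before comparing it with gpu_version_ also guards against comparing a CUDA pair with a ROCm integer. A minimal sketch of that semantics (illustrative only; the helper is hypothetical and not part of this commit):

    #include <utility>

    #include "absl/types/variant.h"

    using GpuVersion = absl::variant<std::pair<int, int>, int>;

    // True only when both variants hold the same alternative (CUDA pair vs.
    // ROCm int) and the held values compare equal.
    bool VersionsMatch(const GpuVersion& compiled, const GpuVersion& stream) {
      return compiled == stream;
    }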

tensorflow/compiler/xla/service/gpu/gpu_executable.h

@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
@@ -34,7 +35,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/stream_executor/device_memory_allocator.h"
@@ -44,18 +44,16 @@ namespace gpu {
 
 // GPU-targeting implementation of the XLA Executable interface.
 //
-// Launches the given CUDA kernel via the StreamExecutor.
+// Launches the given GPU kernel via the StreamExecutor.
 //
 // This is an immutable data type after initialization, and thus thread safe.
 class GpuExecutable : public Executable {
  public:
-  // cubin (i.e. the compiled ptx) may be empty, in which case we leave
-  // compilation up to the GPU driver.
   // We need to share ownership of hlo_module and assignment with profiler to
   // safely keep a reference to these objects during tracing period, thus they
   // are passed as shared pointers.
-  GpuExecutable(const string& ptx, const std::vector<uint8>& cubin,
-                std::pair<int, int> compute_capability,
+  GpuExecutable(const string& text, const std::vector<uint8>& binary,
+                GpuVersion gpu_version,
                 std::unique_ptr<const ThunkSchedule> thunk_schedule,
                 std::shared_ptr<HloModule> hlo_module,
                 std::shared_ptr<const BufferAssignment> assignment,
@@ -71,12 +69,14 @@ class GpuExecutable : public Executable {
     ir_module_string_ = ir_module_string;
   }
 
-  // Returns the compiled PTX for the computation.
-  const string& ptx() const { return ptx_; }
+  // Returns the compiled code for the computation. The compiled code is PTX on
+  // CUDA and an unused empty string on ROCm.
+  const string& text() const { return text_; }
 
-  // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be
-  // empty, in which case compilation is left up to the GPU driver.
-  const std::vector<uint8>& cubin() const { return cubin_; }
+  // Returns the binary stored in this GpuExecutable. The binary is a cubin on
+  // CUDA and an HSA code object on ROCm. It may be empty, in which case
+  // compilation is left up to the GPU driver.
+  const std::vector<uint8>& binary() const { return binary_; }
 
   // ExecuteOnStream will fail if the compute capability of the stream doesn't
   // match the compute capability passed to this object's constructor.
@@ -125,6 +125,11 @@ class GpuExecutable : public Executable {
   // Computes annotations for each thunk and stores them in thunk_annotations_.
   void ComputeThunkAnnotations();
 
+  // Checks the stream's GPU against either AMD's ISA version or Nvidia's
+  // major.minor compute capability, depending on the hardware.
+  Status CheckCompatibilityWithServiceExecutableRunOptions(
+      const ServiceExecutableRunOptions* run_options);
+
   // The LLVM IR, in string format, of the unoptimized module generated for this
   // GpuExecutable. We save a string instead of an llvm::Module* because leaving
   // llvm::Module* in a singleton can cause the heap checker to emit false
@@ -133,17 +138,17 @@ class GpuExecutable : public Executable {
   // This string should be modified only before ExecuteOnStream.
   string ir_module_string_;
 
-  // The PTX for the computation.
-  const string ptx_;
+  // The compiled code for the computation.
+  const string text_;
 
   // The GPU machine code for the computation, targeting GPUs at
   // compute_capability_.
   //
   // May be empty, in which case we leave compilation up to the GPU driver.
-  const std::vector<uint8> cubin_;
+  const std::vector<uint8> binary_;
 
-  // The compute capability of the GPU we're targeting with this GpuExecutable.
-  std::pair<int, int> compute_capability_;
+  // The GPU version for compute compatibility check.
+  GpuVersion gpu_version_;
 
   // The thunks to be invoked by this GpuExecutable. They are generated by the
   // IrEmitter.
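
The text()/binary() comments above encode the loading contract: prefer the device binary when present, and fall back to the compiled text, which the driver can JIT. A simplified sketch of that pattern (illustrative only; it paraphrases the ResolveConstantGlobals change in gpu_executable.cc above, and the helper name is hypothetical):

    #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"

    namespace se = ::stream_executor;

    void AddModules(const xla::gpu::GpuExecutable& executable,
                    se::MultiModuleLoaderSpec* module_spec) {
      if (!executable.binary().empty()) {
        // cubin on CUDA, HSA code object on ROCm.
        module_spec->AddCudaCubinInMemory(executable.binary());
      }
      // PTX on CUDA (empty and unused on ROCm); the driver JITs it on demand.
      module_spec->AddCudaPtxInMemory(executable.text().c_str());
    }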

tensorflow/compiler/xla/service/gpu/gpu_types.h (new file, 31 lines)

@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
+
+#include "absl/types/variant.h"
+
+namespace xla {
+namespace gpu {
+
+// GpuVersion is used to abstract the GPU hardware version. On the CUDA
+// platform, it comprises a pair of integers denoting major and minor version.
+// On the ROCm platform, it comprises one integer for the AMD GCN ISA version.
+using GpuVersion = absl::variant<std::pair<int, int>, int>;
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
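
For context, a short usage sketch of GpuVersion (illustrative only, not from this commit; the function name and sample version numbers are hypothetical):

    #include <utility>

    #include "absl/types/variant.h"
    #include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
    #include "tensorflow/core/platform/logging.h"

    void InspectVersion(const xla::gpu::GpuVersion& version) {
      if (absl::holds_alternative<std::pair<int, int>>(version)) {
        // CUDA stores {major, minor}, e.g. {7, 0} for compute capability 7.0.
        const std::pair<int, int>& cc =
            absl::get<std::pair<int, int>>(version);
        LOG(INFO) << "CUDA compute capability " << cc.first << "." << cc.second;
      } else {
        // ROCm stores a single AMDGPU ISA version integer, e.g. 906 for gfx906.
        LOG(INFO) << "AMDGPU ISA version " << absl::get<int>(version);
      }
    }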

tensorflow/compiler/xla/service/gpu/kernel_thunk.cc

@@ -55,8 +55,8 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
   if (kernel_cache_.end() == it) {
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<se::KernelBase> kernel,
-        CreateKernel(kernel_name_, args_.size(), executable.ptx(),
-                     executable.cubin(), executor));
+        CreateKernel(kernel_name_, args_.size(), executable.text(),
+                     executable.binary(), executor));
 
     kernel_cache_.emplace(executor, std::move(kernel));
   }

tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc

@@ -672,7 +672,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
   }
 
   auto* gpu_executable = new GpuExecutable(
-      ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule),
+      ptx, cubin, std::make_pair(cc_major, cc_minor), std::move(thunk_schedule),
       std::move(module), std::move(buffer_assignment),
       std::move(profile_printer), std::move(profile_index_map));
   if (embed_ir_in_executable) {
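
The switch from the braced list {cc_major, cc_minor} to std::make_pair matters because the constructor parameter is now the GpuVersion variant: a braced initializer list carries no type from which absl::variant's converting constructor could deduce an alternative, while std::make_pair produces a concrete std::pair<int, int> that converts cleanly. A minimal sketch (illustrative, not from this commit):

    #include <utility>

    #include "absl/types/variant.h"

    using GpuVersion = absl::variant<std::pair<int, int>, int>;

    GpuVersion ok = std::make_pair(7, 5);  // deduces the std::pair alternative
    // GpuVersion bad = {7, 5};  // ill-formed: the braced list has no type, so
    //                           // no variant alternative can be deduced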

tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc

@@ -46,7 +46,7 @@ void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr<HloModule> hlo_module,
                                          absl::string_view pattern) {
   std::unique_ptr<Executable> executable =
       std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
-  string ptx_str(static_cast<GpuExecutable*>(executable.get())->ptx());
+  string ptx_str(static_cast<GpuExecutable*>(executable.get())->text());
   StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
   ASSERT_TRUE(filecheck_result.ok());
   EXPECT_TRUE(filecheck_result.ValueOrDie());

tensorflow/compiler/xla/types.h

@@ -16,14 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_
 #define TENSORFLOW_COMPILER_XLA_TYPES_H_
 
+#include <Eigen/Core>
 #include <complex>
 
-#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/platform/types.h"
 
-#include <Eigen/Core>
 
 namespace xla {
 
 using ::tensorflow::string;
@@ -44,7 +43,6 @@ using complex64 = std::complex<float>;
 using complex128 = std::complex<double>;
 
 using ::Eigen::half;
-
 }  // namespace xla
 
 // Alias namespace ::stream_executor as ::xla::se.

third_party/mkl/build_defs.bzl (vendored, 4 changed lines)

@@ -106,12 +106,12 @@ def mkl_deps():
     inclusion in the deps attribute of rules.
     """
     return select({
-        str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["//third_party/intel_mkl_dnn"],
+        str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["@mkl_dnn"],
         str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_v1_only")): ["@mkl_dnn_v1//:mkl_dnn"],
         str(Label("//third_party/mkl:build_with_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
         str(Label("//third_party/mkl:build_with_mkl")): [
             "//third_party/mkl:intel_binary_blob",
-            "//third_party/intel_mkl_dnn",
+            "@mkl_dnn",
         ],
         "//conditions:default": [],
     })