Merge pull request #30238 from ROCmSoftwarePlatform:google-upstream-pr-merge_to_gpu_executable

PiperOrigin-RevId: 258865825
TensorFlower Gardener 2019-07-18 16:39:15 -07:00
commit fe8a48d81c
9 changed files with 114 additions and 46 deletions

View File

@@ -7,7 +7,7 @@ load(
"if_static",
"tf_cuda_tests_tags",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts", "tf_cuda_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
package(
@@ -47,6 +47,14 @@ cc_library(
],
)
cc_library(
name = "gpu_types",
hdrs = ["gpu_types.h"],
deps = [
"@com_google_absl//absl/types:variant",
],
)
cc_library(
name = "partition_assignment",
srcs = [
@@ -453,6 +461,7 @@ cc_library(
":cudnn_conv_runner",
":cusolver_context",
":gpu_debug_info_manager",
":gpu_types",
":hlo_execution_profiler",
":infeed_manager",
":ir_emission_utils",

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
@@ -36,8 +37,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
namespace gpu {
@@ -50,18 +51,17 @@ using tensorflow::tracing::ScopedAnnotation;
// Implementation note: HLO profiling is always enabled for GPU executables,
// since we can use timers around thunks.
GpuExecutable::GpuExecutable(
const string& ptx, const std::vector<uint8>& cubin,
std::pair<int, int> compute_capability,
std::unique_ptr<const ThunkSchedule> thunk_schedule,
const string& text, const std::vector<uint8>& binary,
GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
ptx_(ptx),
cubin_(cubin),
compute_capability_(compute_capability),
text_(text),
binary_(binary),
gpu_version_(gpu_version),
thunk_schedule_(std::move(thunk_schedule)),
assignment_(std::move(assignment)) {
CHECK(has_module() && assignment_);
@@ -89,10 +89,45 @@ void GpuExecutable::ComputeThunkAnnotations() {
}
}
Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
const ServiceExecutableRunOptions* run_options) {
se::Stream* main_stream = run_options->stream();
stream_executor::PlatformKind platform_kind =
main_stream->parent()->platform_kind();
if (platform_kind == stream_executor::PlatformKind::kROCm) {
int stream_isa_version;
main_stream->parent()->GetDeviceDescription().rocm_amdgpu_isa_version(
&stream_isa_version);
GpuVersion amd_isa_version = stream_isa_version;
TF_RET_CHECK(amd_isa_version == gpu_version_)
<< "AMDGPU GCN ISA version mismatch; expected {"
<< absl::get<int>(gpu_version_) << ", but was " << stream_isa_version;
} else if (platform_kind == stream_executor::PlatformKind::kCuda) {
std::pair<int, int> stream_compute_compatibility;
main_stream->parent()->GetDeviceDescription().cuda_compute_capability(
&stream_compute_compatibility.first,
&stream_compute_compatibility.second);
GpuVersion nvidia_compute_compatibility = stream_compute_compatibility;
TF_RET_CHECK(nvidia_compute_compatibility == gpu_version_)
<< "Compute capability mismatch; expected {"
<< absl::get<std::pair<int, int>>(gpu_version_).first << ", "
<< absl::get<std::pair<int, int>>(gpu_version_).second << "}, but was {"
<< stream_compute_compatibility.first << ", "
<< stream_compute_compatibility.second << "}";
} else {
return InternalError("Unknown platform: %d", platform_kind);
}
return Status::OK();
}
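The TF_RET_CHECKs above lean on absl::variant equality: two GpuVersion values compare equal only when they hold the same alternative and the same value, so an executable compiled for one platform can never pass the check on the other. A minimal, hypothetical illustration (not part of this change; the function and version numbers are made up):

#include <utility>

#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"

namespace xla {
namespace gpu {

// Illustrative only: shows how variant equality behaves for GpuVersion.
inline void GpuVersionEqualityExamples() {
  GpuVersion compiled_for = std::make_pair(7, 0);  // e.g. CUDA sm_70
  bool same_version = compiled_for == GpuVersion(std::make_pair(7, 0));  // true
  bool other_arch = compiled_for == GpuVersion(std::make_pair(6, 1));    // false: same alternative, different value
  bool other_platform = compiled_for == GpuVersion(906);                 // false: different alternative (int vs pair)
  (void)same_version;
  (void)other_arch;
  (void)other_platform;
}

}  // namespace gpu
}  // namespace xla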
Status GpuExecutable::ExecuteThunks(
const ServiceExecutableRunOptions* run_options,
const BufferAllocations& buffer_allocations, bool block_host_until_done,
HloExecutionProfile* hlo_execution_profile) {
TF_RETURN_IF_ERROR(
CheckCompatibilityWithServiceExecutableRunOptions(run_options));
GpuDebugInfoManager::Get()->OnModuleStart(module().name());
auto cleanup = MakeCleanup(
[&]() { GpuDebugInfoManager::Get()->OnModuleStop(module().name()); });
@@ -100,16 +135,6 @@ Status GpuExecutable::ExecuteThunks(
se::Stream* main_stream = run_options->stream();
se::StreamExecutor* executor = main_stream->parent();
std::pair<int, int> stream_compute_compatibility;
executor->GetDeviceDescription().cuda_compute_capability(
&stream_compute_compatibility.first,
&stream_compute_compatibility.second);
TF_RET_CHECK(stream_compute_compatibility == compute_capability_)
<< "Compute capability mismatch; expected {" << compute_capability_.first
<< ", " << compute_capability_.second << "}, but was {"
<< stream_compute_compatibility.first << ", "
<< stream_compute_compatibility.second << "}";
bool do_profile = hlo_execution_profile != nullptr;
if (do_profile) {
LOG(WARNING) << "PROFILING: profiling is enabled";
@@ -210,10 +235,10 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
}
se::MultiModuleLoaderSpec module_spec;
if (!cubin().empty()) {
module_spec.AddCudaCubinInMemory(cubin());
if (!binary().empty()) {
module_spec.AddCudaCubinInMemory(binary());
}
module_spec.AddCudaPtxInMemory(ptx().c_str());
module_spec.AddCudaPtxInMemory(text().c_str());
absl::flat_hash_map<int64, se::DeviceMemoryBase> globals;
se::ModuleHandle module_handle;

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
@@ -44,18 +44,16 @@ namespace gpu {
// GPU-targeting implementation of the XLA Executable interface.
//
// Launches the given CUDA kernel via the StreamExecutor.
// Launches the given GPU kernel via the StreamExecutor.
//
// This is an immutable data type after initialization, and thus thread safe.
class GpuExecutable : public Executable {
public:
// cubin (i.e. the compiled ptx) may be empty, in which case we leave
// compilation up to the GPU driver.
// We need to share ownership of hlo_module and assignment with profiler to
// safely keep a reference to these objects during tracing period, thus they
// are passed as shared pointers.
GpuExecutable(const string& ptx, const std::vector<uint8>& cubin,
std::pair<int, int> compute_capability,
GpuExecutable(const string& text, const std::vector<uint8>& binary,
GpuVersion gpu_version,
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> assignment,
@@ -71,12 +69,14 @@ class GpuExecutable : public Executable {
ir_module_string_ = ir_module_string;
}
// Returns the compiled PTX for the computation.
const string& ptx() const { return ptx_; }
// Returns the compiled code for the computation. The compiled code is PTX on
// CUDA and an unused, empty string on ROCm.
const string& text() const { return text_; }
// Returns the cubin (compiled PTX) stored in this GpuExecutable. May be
// empty, in which case compilation is left up to the GPU driver.
const std::vector<uint8>& cubin() const { return cubin_; }
// Returns the binary stored in this GpuExecutable. The binary is a cubin on
// CUDA and an HSA code object on ROCm. It may be empty, in which case
// compilation is left up to the GPU driver.
const std::vector<uint8>& binary() const { return binary_; }
// ExecuteOnStream will fail if the compute capability of the stream doesn't
// match the compute capability passed to this object's constructor.
@@ -125,6 +125,11 @@ class GpuExecutable : public Executable {
// Computes annotations for each thunk and store them in thunk_annotations_.
void ComputeThunkAnnotations();
// Checks the GPU version this executable was compiled for against the device
// the stream runs on: AMD's GCN ISA version on ROCm, or Nvidia's major.minor
// compute capability on CUDA.
Status CheckCompatibilityWithServiceExecutableRunOptions(
const ServiceExecutableRunOptions* run_options);
// The LLVM IR, in string format, of the unoptimized module generated for this
// GpuExecutable. We save a string instead of an llvm::Module* because leaving
// llvm::Module* in a singleton can cause the heap checker to emit false
@@ -133,17 +138,17 @@ class GpuExecutable : public Executable {
// This string should be modified only before ExecuteOnStream.
string ir_module_string_;
// The PTX for the computation.
const string ptx_;
// The compiled code for the computation.
const string text_;
// The GPU machine code for the computation, targeting GPUs at
// gpu_version_.
//
// May be empty, in which case we leave compilation up to the GPU driver.
const std::vector<uint8> cubin_;
const std::vector<uint8> binary_;
// The compute capability of the GPU we're targeting with this GpuExecutable.
std::pair<int, int> compute_capability_;
// The GPU version used for the compatibility check against the stream's device.
GpuVersion gpu_version_;
// The thunks to be invoked by this GpuExecutable. They are generated by the
// IrEmitter.

View File

@@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
#include "absl/types/variant.h"
namespace xla {
namespace gpu {
// GpuVersion abstracts the GPU hardware version. On the CUDA platform it is a
// pair of integers denoting the major and minor compute capability; on the
// ROCm platform it is a single integer denoting the AMD GCN ISA version.
using GpuVersion = absl::variant<std::pair<int, int>, int>;
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
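A short sketch of how code might construct and inspect this variant; the helper below and the version numbers are illustrative, not part of this change:

#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"

namespace xla {
namespace gpu {

// Illustrative only: renders whichever alternative the variant holds.
inline std::string GpuVersionToString(const GpuVersion& version) {
  if (absl::holds_alternative<std::pair<int, int>>(version)) {
    const auto& cc = absl::get<std::pair<int, int>>(version);
    return absl::StrCat("CUDA compute capability ", cc.first, ".", cc.second);
  }
  return absl::StrCat("AMDGPU GCN ISA version ", absl::get<int>(version));
}

}  // namespace gpu
}  // namespace xla

// Usage with made-up values:
//   GpuVersionToString(GpuVersion(std::make_pair(7, 0)));  // "CUDA compute capability 7.0"
//   GpuVersionToString(GpuVersion(906));                    // "AMDGPU GCN ISA version 906"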

View File

@@ -55,8 +55,8 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
if (kernel_cache_.end() == it) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<se::KernelBase> kernel,
CreateKernel(kernel_name_, args_.size(), executable.ptx(),
executable.cubin(), executor));
CreateKernel(kernel_name_, args_.size(), executable.text(),
executable.binary(), executor));
kernel_cache_.emplace(executor, std::move(kernel));
}

View File

@@ -672,7 +672,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
}
auto* gpu_executable = new GpuExecutable(
ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule),
ptx, cubin, std::make_pair(cc_major, cc_minor), std::move(thunk_schedule),
std::move(module), std::move(buffer_assignment),
std::move(profile_printer), std::move(profile_index_map));
if (embed_ir_in_executable) {
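The switch from a braced initializer to std::make_pair is forced by the new parameter type: a braced-init-list such as {cc_major, cc_minor} cannot bind to absl::variant's templated converting constructor, while a concrete std::pair selects the pair alternative directly. A minimal sketch of the rule (identifiers are illustrative):

#include <utility>

#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"

namespace xla {
namespace gpu {

inline void GpuVersionConstructionExamples(int cc_major, int cc_minor) {
  // GpuVersion bad = {cc_major, cc_minor};  // does not compile: a braced list
  //                                         // cannot deduce a variant alternative.
  GpuVersion cuda = std::make_pair(cc_major, cc_minor);  // holds std::pair<int, int>
  GpuVersion rocm = cc_major;                            // holds int (ROCm-style ISA version)
  (void)cuda;
  (void)rocm;
}

}  // namespace gpu
}  // namespace xla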

View File

@@ -46,7 +46,7 @@ void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr<HloModule> hlo_module,
absl::string_view pattern) {
std::unique_ptr<Executable> executable =
std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie());
string ptx_str(static_cast<GpuExecutable*>(executable.get())->ptx());
string ptx_str(static_cast<GpuExecutable*>(executable.get())->text());
StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern);
ASSERT_TRUE(filecheck_result.ok());
EXPECT_TRUE(filecheck_result.ValueOrDie());

View File

@@ -16,14 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_
#define TENSORFLOW_COMPILER_XLA_TYPES_H_
#include <Eigen/Core>
#include <complex>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
#include <Eigen/Core>
namespace xla {
using ::tensorflow::string;
@@ -44,7 +43,6 @@ using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
using ::Eigen::half;
} // namespace xla
// Alias namespace ::stream_executor as ::xla::se.

View File

@@ -106,12 +106,12 @@ def mkl_deps():
inclusion in the deps attribute of rules.
"""
return select({
str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["//third_party/intel_mkl_dnn"],
str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["@mkl_dnn"],
str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_v1_only")): ["@mkl_dnn_v1//:mkl_dnn"],
str(Label("//third_party/mkl:build_with_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
str(Label("//third_party/mkl:build_with_mkl")): [
"//third_party/mkl:intel_binary_blob",
"//third_party/intel_mkl_dnn",
"@mkl_dnn",
],
"//conditions:default": [],
})