diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index f58941f2501..a5fc6e80cec 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -7,7 +7,7 @@ load(
     "if_static",
     "tf_cuda_tests_tags",
 )
-load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library", "tf_copts")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts", "tf_cuda_library")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 
 package(
@@ -47,6 +47,14 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "gpu_types",
+    hdrs = ["gpu_types.h"],
+    deps = [
+        "@com_google_absl//absl/types:variant",
+    ],
+)
+
 cc_library(
     name = "partition_assignment",
     srcs = [
@@ -453,6 +461,7 @@ cc_library(
         ":cudnn_conv_runner",
         ":cusolver_context",
         ":gpu_debug_info_manager",
+        ":gpu_types",
         ":hlo_execution_profiler",
         ":infeed_manager",
         ":ir_emission_utils",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index ce559c2c1e7..e4942bd76a6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
@@ -36,8 +37,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/tracing.h"
-#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/stream_executor/platform.h"
 
 namespace xla {
 namespace gpu {
@@ -50,18 +51,17 @@ using tensorflow::tracing::ScopedAnnotation;
 // Implementation note: HLO profiling is always enabled for GPU executables,
 // since we can use timers around thunks.
 GpuExecutable::GpuExecutable(
-    const string& ptx, const std::vector<uint8>& cubin,
-    std::pair<int, int> compute_capability,
-    std::unique_ptr<const ThunkSchedule> thunk_schedule,
+    const string& text, const std::vector<uint8>& binary,
+    GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
     std::shared_ptr<HloModule> hlo_module,
     std::shared_ptr<const BufferAssignment> assignment,
     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
     : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
                  std::move(hlo_profile_index_map)),
-      ptx_(ptx),
-      cubin_(cubin),
-      compute_capability_(compute_capability),
+      text_(text),
+      binary_(binary),
+      gpu_version_(gpu_version),
       thunk_schedule_(std::move(thunk_schedule)),
       assignment_(std::move(assignment)) {
   CHECK(has_module() && assignment_);
@@ -89,10 +89,45 @@ void GpuExecutable::ComputeThunkAnnotations() {
   }
 }
 
+Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions(
+    const ServiceExecutableRunOptions* run_options) {
+  se::Stream* main_stream = run_options->stream();
+
+  stream_executor::PlatformKind platform_kind =
+      main_stream->parent()->platform_kind();
+  if (platform_kind == stream_executor::PlatformKind::kROCm) {
+    int stream_isa_version;
+    main_stream->parent()->GetDeviceDescription().rocm_amdgpu_isa_version(
+        &stream_isa_version);
+    GpuVersion amd_isa_version = stream_isa_version;
+    TF_RET_CHECK(amd_isa_version == gpu_version_)
+        << "AMDGPU GCN ISA version mismatch; expected "
+        << absl::get<int>(gpu_version_) << ", but was " << stream_isa_version;
+  } else if (platform_kind == stream_executor::PlatformKind::kCuda) {
+    std::pair<int, int> stream_compute_compatibility;
+    main_stream->parent()->GetDeviceDescription().cuda_compute_capability(
+        &stream_compute_compatibility.first,
+        &stream_compute_compatibility.second);
+    GpuVersion nvidia_compute_compatibility = stream_compute_compatibility;
+    TF_RET_CHECK(nvidia_compute_compatibility == gpu_version_)
+        << "Compute capability mismatch; expected {"
+        << absl::get<std::pair<int, int>>(gpu_version_).first << ", "
+        << absl::get<std::pair<int, int>>(gpu_version_).second << "}, but was {"
+        << stream_compute_compatibility.first << ", "
+        << stream_compute_compatibility.second << "}";
+  } else {
+    return InternalError("Unknown platform: %d", platform_kind);
+  }
+
+  return Status::OK();
+}
+
 Status GpuExecutable::ExecuteThunks(
     const ServiceExecutableRunOptions* run_options,
     const BufferAllocations& buffer_allocations, bool block_host_until_done,
     HloExecutionProfile* hlo_execution_profile) {
+  TF_RETURN_IF_ERROR(
+      CheckCompatibilityWithServiceExecutableRunOptions(run_options));
   GpuDebugInfoManager::Get()->OnModuleStart(module().name());
   auto cleanup = MakeCleanup(
       [&]() { GpuDebugInfoManager::Get()->OnModuleStop(module().name()); });
@@ -100,16 +135,6 @@ Status GpuExecutable::ExecuteThunks(
   se::Stream* main_stream = run_options->stream();
   se::StreamExecutor* executor = main_stream->parent();
 
-  std::pair<int, int> stream_compute_compatibility;
-  executor->GetDeviceDescription().cuda_compute_capability(
-      &stream_compute_compatibility.first,
-      &stream_compute_compatibility.second);
-  TF_RET_CHECK(stream_compute_compatibility == compute_capability_)
-      << "Compute capability mismatch; expected {" << compute_capability_.first
-      << ", " << compute_capability_.second << "}, but was {"
-      << stream_compute_compatibility.first << ", "
-      << stream_compute_compatibility.second << "}";
-
   bool do_profile = hlo_execution_profile != nullptr;
   if (do_profile) {
     LOG(WARNING) << "PROFILING: profiling is enabled";
@@ -210,10 +235,10 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
   }
 
   se::MultiModuleLoaderSpec module_spec;
-  if (!cubin().empty()) {
-    module_spec.AddCudaCubinInMemory(cubin());
+  if (!binary().empty()) {
+    module_spec.AddCudaCubinInMemory(binary());
   }
-  module_spec.AddCudaPtxInMemory(ptx().c_str());
+  module_spec.AddCudaPtxInMemory(text().c_str());
 
   absl::flat_hash_map<int64, se::DeviceMemoryBase> globals;
   se::ModuleHandle module_handle;
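The compatibility check added above relies on absl::variant equality: two GpuVersion values compare equal only when they hold the same alternative and the held values are equal, so a CUDA {major, minor} pair can never match a ROCm ISA integer by accident. A minimal standalone sketch of that behavior follows; the variable names are illustrative and not part of this change.

    #include <cassert>
    #include <utility>

    #include "absl/types/variant.h"

    // Mirrors the GpuVersion alias introduced in gpu_types.h below.
    using GpuVersion = absl::variant<std::pair<int, int>, int>;

    int main() {
      GpuVersion cuda_version = std::make_pair(7, 0);  // CUDA compute capability 7.0
      GpuVersion rocm_version = 906;                   // AMDGPU GCN ISA version, e.g. gfx906

      // Different alternatives never compare equal: a CUDA pair cannot match a
      // ROCm ISA integer, which is what the TF_RET_CHECKs above depend on.
      assert(cuda_version != rocm_version);

      // absl::get<T> extracts the platform-specific value, as done when the
      // error messages are formatted.
      assert(absl::holds_alternative<std::pair<int, int>>(cuda_version));
      assert(absl::get<int>(rocm_version) == 906);
      return 0;
    }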
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index fd1a17c0a92..5f9fe3e71ef 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
@@ -34,7 +35,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/stream_executor/device_memory_allocator.h"
@@ -44,18 +44,16 @@ namespace gpu {
 
 // GPU-targeting implementation of the XLA Executable interface.
 //
-// Launches the given CUDA kernel via the StreamExecutor.
+// Launches the given GPU kernel via the StreamExecutor.
 //
 // This is an immutable data type after initialization, and thus thread safe.
 class GpuExecutable : public Executable {
  public:
-  // cubin (i.e. the compiled ptx) may be empty, in which case we leave
-  // compilation up to the GPU driver.
   // We need to share ownership of hlo_module and assignment with profiler to
   // safely keep a reference to these objects during tracing period, thus they
   // are passed as shared pointers.
-  GpuExecutable(const string& ptx, const std::vector<uint8>& cubin,
-                std::pair<int, int> compute_capability,
+  GpuExecutable(const string& text, const std::vector<uint8>& binary,
+                GpuVersion gpu_version,
                 std::unique_ptr<const ThunkSchedule> thunk_schedule,
                 std::shared_ptr<HloModule> hlo_module,
                 std::shared_ptr<const BufferAssignment> assignment,
@@ -71,12 +69,14 @@ class GpuExecutable : public Executable {
     ir_module_string_ = ir_module_string;
   }
 
-  // Returns the compiled PTX for the computation.
-  const string& ptx() const { return ptx_; }
+  // Returns the compiled code for the computation. The compiled code is PTX
+  // in Cuda and an unused empty string in ROCm.
+  const string& text() const { return text_; }
 
-  // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be
-  // empty, in which case compilation is left up to the GPU driver.
-  const std::vector<uint8>& cubin() const { return cubin_; }
+  // Returns the binary stored in this GpuExecutable. The binary is a cubin in
+  // Cuda and an HSA code object in ROCm. It may be empty, in which case
+  // compilation is left up to the GPU driver.
+  const std::vector<uint8>& binary() const { return binary_; }
 
   // ExecuteOnStream will fail if the compute capability of the stream doesn't
   // match the compute capability passed to this object's constructor.
@@ -125,6 +125,11 @@ class GpuExecutable : public Executable {
   // Computes annotations for each thunk and store them in thunk_annotations_.
   void ComputeThunkAnnotations();
 
+  // Checks either AMD's ISA version or Nvidia's major/minor compute
+  // capability against this executable's GpuVersion, depending on the hardware.
+  Status CheckCompatibilityWithServiceExecutableRunOptions(
+      const ServiceExecutableRunOptions* run_options);
+
   // The LLVM IR, in string format, of the unoptimized module generated for this
   // GpuExecutable. We save a string instead of an llvm::Module* because leaving
   // llvm::Module* in a singleton can cause the heap checker to emit false
@@ -133,17 +138,17 @@ class GpuExecutable : public Executable {
   // This string should be modified only before ExecuteOnStream.
   string ir_module_string_;
 
-  // The PTX for the computation.
-  const string ptx_;
+  // The compiled code for the computation.
+  const string text_;
 
   // The GPU machine code for the computation, targeting GPUs at
   // compute_capability_.
   //
   // May be empty, in which case we leave compilation up to the GPU driver.
-  const std::vector<uint8> cubin_;
+  const std::vector<uint8> binary_;
 
-  // The compute capability of the GPU we're targeting with this GpuExecutable.
-  std::pair<int, int> compute_capability_;
+  // The GPU version used for compute compatibility checks.
+  GpuVersion gpu_version_;
 
   // The thunks to be invoked by this GpuExecutable. They are generated by the
   // IrEmitter.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_types.h b/tensorflow/compiler/xla/service/gpu/gpu_types.h
new file mode 100644
index 00000000000..1c51040fb82
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_types.h
@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
+
+#include "absl/types/variant.h"
+
+namespace xla {
+namespace gpu {
+
+// GpuVersion is used to abstract the GPU hardware version. On the Cuda
+// platform it comprises a pair of integers denoting major and minor version;
+// on the ROCm platform it comprises one integer, the AMD GCN ISA version.
+using GpuVersion = absl::variant<std::pair<int, int>, int>;
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
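Because GpuVersion is a plain absl::variant, call sites that need a human-readable form can use absl::visit instead of branching on the platform kind. The helper below is a hedged sketch; GpuVersionToString is hypothetical and not part of this change, and it assumes only the standard Abseil string and variant headers.

    #include <string>
    #include <utility>

    #include "absl/strings/str_cat.h"
    #include "absl/types/variant.h"

    using GpuVersion = absl::variant<std::pair<int, int>, int>;

    // Hypothetical helper: formats either a CUDA {major, minor} pair or a
    // ROCm AMDGPU GCN ISA version.
    std::string GpuVersionToString(const GpuVersion& version) {
      struct Visitor {
        std::string operator()(const std::pair<int, int>& cc) const {
          return absl::StrCat("compute capability ", cc.first, ".", cc.second);
        }
        std::string operator()(int isa_version) const {
          return absl::StrCat("AMDGPU GCN ISA gfx", isa_version);
        }
      };
      return absl::visit(Visitor{}, version);
    }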
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index f1d1c0cf3e0..d976b5d8d4d 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -55,8 +55,8 @@ Status KernelThunk::Initialize(const GpuExecutable& executable,
   if (kernel_cache_.end() == it) {
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<se::KernelBase> kernel,
-        CreateKernel(kernel_name_, args_.size(), executable.ptx(),
-                     executable.cubin(), executor));
+        CreateKernel(kernel_name_, args_.size(), executable.text(),
+                     executable.binary(), executor));
     kernel_cache_.emplace(executor, std::move(kernel));
   }
 
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index ae1c169c960..20b3d64c417 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -672,7 +672,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
   }
 
   auto* gpu_executable = new GpuExecutable(
-      ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule),
+      ptx, cubin, std::make_pair(cc_major, cc_minor), std::move(thunk_schedule),
       std::move(module), std::move(buffer_assignment),
       std::move(profile_printer), std::move(profile_index_map));
   if (embed_ir_in_executable) {
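Note the change from the braced initializer {cc_major, cc_minor} to std::make_pair(cc_major, cc_minor): a braced list cannot be deduced to a variant alternative, so the argument has to be a concrete std::pair that then converts to GpuVersion. A ROCm backend would presumably pass the bare ISA version instead. The two call shapes, sketched under that assumption (the isa_version parameter and both helper names are hypothetical):

    #include <utility>

    #include "absl/types/variant.h"

    using GpuVersion = absl::variant<std::pair<int, int>, int>;

    // CUDA path, matching the nvptx_compiler.cc hunk above.
    GpuVersion CudaVersion(int cc_major, int cc_minor) {
      return std::make_pair(cc_major, cc_minor);
    }

    // Presumed ROCm path: a single AMDGPU GCN ISA version, e.g. 906 for gfx906.
    GpuVersion RocmVersion(int isa_version) {
      return isa_version;
    }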
""" return select({ - str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["//third_party/intel_mkl_dnn"], + str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["@mkl_dnn"], str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_v1_only")): ["@mkl_dnn_v1//:mkl_dnn"], str(Label("//third_party/mkl:build_with_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"], str(Label("//third_party/mkl:build_with_mkl")): [ "//third_party/mkl:intel_binary_blob", - "//third_party/intel_mkl_dnn", + "@mkl_dnn", ], "//conditions:default": [], })