From ed10b372133870c9ab57d76a20faed77fa0de97d Mon Sep 17 00:00:00 2001
From: "Wen-Heng (Jack) Chung" <whchung@gmail.com>
Date: Fri, 9 Aug 2019 17:32:46 +0000
Subject: [PATCH 1/2] Introduce amdgpu_compiler to XLA.

---
 tensorflow/compiler/xla/service/gpu/BUILD     |  15 +-
 .../xla/service/gpu/amdgpu_compiler.cc        | 155 ++++++++++++++++++
 .../xla/service/gpu/amdgpu_compiler.h         |  70 ++++++++
 .../gpu/amdgpu_compiler_registration.cc       |  26 +++
 4 files changed, 256 insertions(+), 10 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
 create mode 100644 tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
 create mode 100644 tensorflow/compiler/xla/service/gpu/amdgpu_compiler_registration.cc

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 1b41d2ffc97..36f82089a69 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1133,8 +1133,7 @@ cc_library(
 cc_library(
     name = "amdgpu_compiler",
     srcs = [
-        # TODO(whchung@gmail.com): Enable in the subsequent PR.
-        # "amdgpu_compiler_registration.cc",
+        "amdgpu_compiler_registration.cc",
     ],
     deps = [
         ":amdgpu_compiler_impl",
@@ -1145,18 +1144,14 @@ cc_library(
 cc_library(
     name = "amdgpu_compiler_impl",
     srcs = [
-        # TODO(whchung@gmail.com) : enable in the subsequent PR.
-        #"amdgpu_compiler.cc",
+        "amdgpu_compiler.cc",
     ],
     hdrs = [
-        # TODO(whchung@gmail.com): enable in the subsequent PR.
-        #"amdgpu_compiler.h"
+        "amdgpu_compiler.h"
     ],
     deps = [
-        # TODO(whchung@gmail.com): Enable these after pending PRs get merged.
-        #":gpu_compiler_impl",
-        #":miopen_conv_algorithm_picker",
-        #"//tensorflow/core:rocm_rocdl_path",
+        ":gpu_compiler_impl",
+        "//tensorflow/core/platform:rocm_rocdl_path",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
new file mode 100644
index 00000000000..844acced034
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
+
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
+// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
+#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
+#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/core/platform/rocm_rocdl_path.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Returns the directory containing ROCm-Device-Libs files. This function
+// can't return an error, so as a last resort it falls back to the current
+// directory; AMDGPUCompiler::Compile will then return an error when the
+// wanted rocdl file doesn't exist in the folder this function returns.
+string GetROCDLDir(const HloModuleConfig& config) {
+  std::vector<string> potential_rocdl_dirs;
+  const string datadir = config.debug_options().xla_gpu_cuda_data_dir();
+  if (!datadir.empty()) {
+    potential_rocdl_dirs.push_back(datadir);
+  }
+  potential_rocdl_dirs.push_back(tensorflow::RocdlRoot());
+
+  // Tries all potential ROCDL directories in the order they are inserted.
+  // Returns the first directory that exists in the file system.
+  for (const string& potential_rocdl_dir : potential_rocdl_dirs) {
+    if (tensorflow::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) {
+      VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir;
+      return potential_rocdl_dir;
+    }
+    VLOG(2) << "Unable to find potential ROCm-Device-Libs dir "
+            << potential_rocdl_dir;
+  }
+
+  // Last resort: maybe in the current folder.
+  return ".";
+}
+
+}  // namespace
+
+Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
+    HloModule* hlo_module, se::StreamExecutor* stream_exec,
+    se::DeviceMemoryAllocator* device_allocator) {
+  // Convert convolutions into CustomCalls to MIOpen, then canonicalize them
+  // (PadInsertion).
+  HloPassPipeline pipeline("conv_canonicalization");
+  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+                                            /*allow_mixed_precision=*/false);
+  pipeline.AddPass<CudnnConvRewriter>();
+  pipeline.AddPass<CudnnConvPaddingLegalization>();
+
+  pipeline.AddPass<HloConstantFolding>();
+  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+
+  return Status::OK();
+}
+
+Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
+    HloModule* hlo_module, se::StreamExecutor* stream_exec,
+    se::DeviceMemoryAllocator* device_allocator) {
+  HloPassPipeline pipeline("post-layout_assignment");
+  pipeline.AddInvariantChecker<HloVerifier>(
+      /*layout_sensitive=*/true,
+      /*allow_mixed_precision=*/false,
+      LayoutAssignment::InstructionCanChangeLayout);
+
+  // The LayoutAssignment pass may leave behind kCopy instructions which are
+  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
+  AlgebraicSimplifierOptions options;
+  options.set_is_layout_sensitive(true);
+  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
+
+  // TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
+
+  // Clean up new_tuple described above.
+  pipeline.AddPass<TupleSimplifier>();
+
+  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+
+  return Status::OK();
+}
+
+AMDGPUCompiler::AMDGPUCompiler()
+    : GpuCompiler(stream_executor::rocm::kROCmPlatformId, amdgpu::kTargetTriple, amdgpu::kDataLayout) {}
+
+GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
+  int isa_version = 0;
+  if (!stream_exec->GetDeviceDescription().rocm_amdgpu_isa_version(
+          &isa_version)) {
+    LOG(WARNING)
+        << "Couldn't get AMDGPU ISA version for device; assuming gfx803.";
+    isa_version = 803;
+  }
+
+  return isa_version;
+}
+
+StatusOr<std::pair<std::string, std::vector<uint8>>>
+AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
+                                    llvm::Module* llvm_module,
+                                    GpuVersion gpu_version,
+                                    se::StreamExecutor* stream_exec) {
+  if (rocdl_dir_.empty()) {
+    // Compute rocdl_dir_ just once and cache it in this member.
+    rocdl_dir_ = GetROCDLDir(module->config());
+  }
+
+  std::vector<uint8> hsaco;
+  {
+    XLA_SCOPED_LOGGING_TIMER(
+        "AMDGPUCompiler::CompileTargetBinary - CompileToHsaco");
+    TF_ASSIGN_OR_RETURN(hsaco,
+                        amdgpu::CompileToHsaco(llvm_module, gpu_version,
+                                               module->config(), rocdl_dir_));
+  }
+
+  llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
+
+  if (user_post_optimization_hook_) {
+    user_post_optimization_hook_(*llvm_module);
+  }
+
+  return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
new file mode 100644
index 00000000000..b8a3bad47b0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
@@ -0,0 +1,70 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+namespace gpu {
+
+// AMDGPUCompiler generates efficient GPU executables for the AMDGPU target.
+class AMDGPUCompiler : public GpuCompiler {
+ public:
+  AMDGPUCompiler();
+  ~AMDGPUCompiler() override {}
+
+  Status OptimizeHloConvolutionCanonicalization(
+      HloModule* hlo_module, se::StreamExecutor* stream_exec,
+      se::DeviceMemoryAllocator* device_allocator) override;
+
+  Status OptimizeHloPostLayoutAssignment(
+      HloModule* hlo_module, se::StreamExecutor* stream_exec,
+      se::DeviceMemoryAllocator* device_allocator) override;
+
+  GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
+
+  StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+      const HloModule* hlo_module, llvm::Module* llvm_module,
+      GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
+
+ private:
+  // The parent directory of ROCm-Device-Libs IR libraries.
+  string rocdl_dir_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(AMDGPUCompiler);
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler_registration.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler_registration.cc
new file mode 100644
index 00000000000..706cffa3cc0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler_registration.cc
@@ -0,0 +1,26 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
+
+static bool InitModule() {
+  xla::Compiler::RegisterCompilerFactory(
+      stream_executor::rocm::kROCmPlatformId, []() {
+        return absl::make_unique<xla::gpu::AMDGPUCompiler>();
+      });
+  return true;
+}
+static bool module_initialized = InitModule();
+

From 3edd97fa35b1ca16da2c64fb391b891d622150a1 Mon Sep 17 00:00:00 2001
From: "Wen-Heng (Jack) Chung" <whchung@gmail.com>
Date: Wed, 14 Aug 2019 13:19:36 -0500
Subject: [PATCH 2/2] Address issues found in sanity checks.

---
 tensorflow/compiler/xla/service/gpu/BUILD | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 36f82089a69..fdae0c7e508 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1133,7 +1133,7 @@ cc_library(
 cc_library(
     name = "amdgpu_compiler",
     srcs = [
-        "amdgpu_compiler_registration.cc",
+        "amdgpu_compiler_registration.cc",
     ],
     deps = [
         ":amdgpu_compiler_impl",
@@ -1147,10 +1147,9 @@ cc_library(
         "amdgpu_compiler.cc",
     ],
     hdrs = [
-        "amdgpu_compiler.h"
+        "amdgpu_compiler.h",
     ],
     deps = [
-        ":gpu_compiler_impl",
         "//tensorflow/core/platform:rocm_rocdl_path",
     ],
 )