Enabling ROCm parallel logic for gpu_conv_algorithm_picker

This commit is contained in:
jerryyin 2019-09-09 18:16:50 +00:00
parent a3fbb1c352
commit b21f969731
5 changed files with 208 additions and 60 deletions

View File

@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
@ -97,7 +97,7 @@ Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
options.set_is_layout_sensitive(true);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();

View File

@ -48,6 +48,54 @@ using se::DeviceMemoryBase;
using se::dnn::AlgorithmDesc;
using tensorflow::AutotuneResult;
// Scratch-memory allocator used while autotuning convolutions.
//
// Each buffer handed out by AllocateBytes() is also stored in
// `allocated_buffers_` as an se::OwningDeviceMemory, so it stays alive for as
// long as this object does (presumably it is released when the allocator is
// destroyed -- confirm against se::OwningDeviceMemory).
class ScratchAllocator : public se::ScratchAllocator {
 public:
  // `device_ordinal` is forwarded to `memory_allocator` on every allocation.
  // `memory_allocator` is dereferenced by AllocateBytes(), so it must be
  // non-null by the time that is called.
  ScratchAllocator(int device_ordinal,
                   se::DeviceMemoryAllocator* memory_allocator)
      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}

  // Largest size, in bytes, that a single AllocateBytes() request may have.
  int64 GetMemoryLimitInBytes() override {
    return 1LL << 32;  // 4GB.  TODO(jlebar): Tune this?
  }

  // Sum of the sizes of all successful AllocateBytes() requests so far.
  // Pure accessor, hence const.
  int64 TotalAllocatedBytes() const { return total_allocated_bytes_; }

  // Allocates `byte_size` bytes of device memory; see the definition below.
  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;

  // Typed convenience wrapper: allocates room for `num_elements` values of
  // type T via AllocateBytes() and returns it as a DeviceMemory<T>.
  template <typename T>
  StatusOr<se::DeviceMemory<T>> Allocate(int64 num_elements) {
    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> bytes,
                        AllocateBytes(num_elements * sizeof(T)));
    return se::DeviceMemory<T>(bytes);
  }

 private:
  const int device_ordinal_;
  se::DeviceMemoryAllocator* memory_allocator_;
  // Owning handles for every buffer returned so far.
  std::vector<se::OwningDeviceMemory> allocated_buffers_;
  int64 total_allocated_bytes_ = 0;
};
// Allocates `byte_size` bytes on `device_ordinal_` through
// `memory_allocator_` and returns a non-owning view of the new buffer. The
// owning handle is stored in `allocated_buffers_`.
//
// CHECK-fails if `byte_size` is negative. Returns RESOURCE_EXHAUSTED if the
// request exceeds GetMemoryLimitInBytes(), and propagates any error from the
// underlying allocator (which is asked not to retry on failure).
StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
    int64 byte_size) {
  // Zero-sized requests are allowed; only negative sizes are rejected.
  CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";

  const int64 memory_limit = GetMemoryLimitInBytes();
  if (byte_size > memory_limit) {
    return se::port::Status(
        se::port::error::RESOURCE_EXHAUSTED,
        absl::StrFormat(
            "Allocating %d bytes exceeds the memory limit of %d bytes.",
            byte_size, memory_limit));
  }

  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
                      memory_allocator_->Allocate(device_ordinal_, byte_size,
                                                  /*retry_on_failure=*/false));
  total_allocated_bytes_ += byte_size;

  // Take the raw address before the owning handle is moved into the vector.
  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
  allocated_buffers_.push_back(std::move(allocated_buffer));
  return se::DeviceMemory<uint8>(buffer_addr);
}
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
se::StreamExecutor* stream_exec) {
std::vector<AlgorithmDesc> algorithms;
@ -198,7 +246,7 @@ auto& autotune_cache_stats GUARDED_BY(autotune_cache_lock) =
*new ConvCacheStats();
} // anonymous namespace
StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
const HloCustomCallInstruction* instr) {
// Don't run this function concurrently on the same GPU.
//
@ -226,22 +274,6 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
autotune_cache_stats.cache_misses++;
}
StatusOr<AutotuneResult> result_or = PickBestAlgorithmNoCache(instr);
if (result_or.ok()) {
tensorflow::mutex_lock lock(autotune_cache_lock);
CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
}
return result_or;
}
StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
const HloCustomCallInstruction* instr) {
XLA_SCOPED_LOGGING_TIMER(
absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithmImpl for ",
instr->ToString()));
const Shape& result_shape = instr->shape().tuple_shapes(0);
// Make sure any previous activity on this executor is done. We don't want to
// interfere with programs that are still running on the GPU.
if (!stream_exec_->SynchronizeAllActivity()) {
@ -269,6 +301,34 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
return &stream_opt.value();
}();
StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
// Check which platform the StreamExecutor is running on. The ROCm and CUDA
// implementations have diverged. Specifically, we need to make sure that the
// redzone-allocator-related utilities are not used in the ROCm routine.
if (stream_exec_->platform_kind() == se::PlatformKind::kROCm) {
result_or = PickBestAlgorithmNoCacheRocm(*instr, allocator, stream);
} else if (stream_exec_->platform_kind() == se::PlatformKind::kCuda) {
result_or = PickBestAlgorithmNoCacheCuda(*instr, allocator, stream);
}
if (result_or.ok()) {
tensorflow::mutex_lock lock(autotune_cache_lock);
CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
}
return result_or;
}
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
const HloCustomCallInstruction& instr, se::DeviceMemoryAllocator* allocator,
se::Stream* stream) {
// Right now Redzone allocator is available in Cuda target only
XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
"GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr.ToString()));
const Shape& result_shape = instr.shape().tuple_shapes(0);
const auto device_ordinal = stream_exec_->device_ordinal();
int64 rng_state = 0;
const auto initialize_buffer = [&stream, &rng_state](
@ -277,13 +337,13 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
};
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
const HloModuleConfig& hlo_module_config = instr.GetModule()->config();
// Allocate space for the input, filter, and output of the convolution.
se::RedzoneAllocator input_output_allocator(
stream, allocator, PtxOptsFromConfig(hlo_module_config));
std::vector<se::DeviceMemoryBase> operand_buffers;
for (const auto* operand : instr->operands()) {
for (const auto* operand : instr.operands()) {
TF_ASSIGN_OR_RETURN(auto buffer,
input_output_allocator.AllocateBytes(
ShapeUtil::ByteSizeOf(operand->shape())));
@ -296,7 +356,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
initialize_buffer(result_buffer, result_shape);
TF_ASSIGN_OR_RETURN(auto backend_config,
instr->backend_config<CudnnConvBackendConfig>());
instr.backend_config<CudnnConvBackendConfig>());
optional<BufferComparator> comparator;
// Use the first algorithm that's supported as reference. There isn't a
@ -305,17 +365,17 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
se::DeviceMemoryBase reference_result_buffer;
AlgorithmDesc first_algorithm;
TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(&instr));
std::vector<AutotuneResult> profile_results;
const DebugOptions& debug_options =
instr->GetModule()->config().debug_options();
instr.GetModule()->config().debug_options();
const bool crash_on_checking_failure =
debug_options.xla_gpu_crash_on_verification_failures();
const auto canonical_hlo =
std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));
std::get<1>(AutotuneCacheKeyfromInstruction(&instr, stream_exec_));
string blas_version;
if (auto* blas = stream_exec_->AsBlas()) {
@ -335,7 +395,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
if (absl::c_linear_search(blacklisted_algos, alg)) {
LOG(INFO) << "Omitted potentially buggy algorithm "
<< AlgorithmToString(alg) << " for conv " << instr->ToString();
<< AlgorithmToString(alg) << " for conv " << instr.ToString();
continue;
}
@ -343,7 +403,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
stream, allocator, PtxOptsFromConfig(hlo_module_config));
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
<< instr.ToString();
// Use assignment instead of brace-list to make GCC 4.9 happy.
RunConvOptions options;
@ -375,11 +435,11 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
// Check for writes to redzones.
TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
CheckRedzones(input_output_allocator, stream,
"input/output", instr, &result));
"input/output", &instr, &result));
TF_ASSIGN_OR_RETURN(
bool scratch_allocator_redzone_clear,
CheckRedzones(scratch_allocator, stream, "scratch", instr, &result));
CheckRedzones(scratch_allocator, stream, "scratch", &instr, &result));
if (!input_output_allocator_redzone_clear ||
!scratch_allocator_redzone_clear) {
@ -410,7 +470,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
if (!compare_result.ok()) {
LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm)
<< " against " << AlgorithmToString(alg) << " for "
<< instr->ToString() << ": " << compare_result.status();
<< instr.ToString() << ": " << compare_result.status();
if (compare_result.status().code() ==
tensorflow::error::RESOURCE_EXHAUSTED) {
// Possibly OOM. Propagate the error.
@ -421,12 +481,11 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
LOG(ERROR)
<< "Results mismatch between different convolution algorithms. "
"This is likely a bug/unexpected loss of precision in cudnn.\n"
<< instr->ToString() << " for "
<< AlgorithmToString(first_algorithm) << " vs "
<< AlgorithmToString(alg);
<< instr.ToString() << " for " << AlgorithmToString(first_algorithm)
<< " vs " << AlgorithmToString(alg);
PrintPlatformInfo(stream);
VLOG(1) << "Full module on failure: \n"
<< instr->GetModule()->ToString();
<< instr.GetModule()->ToString();
auto* fail = result.mutable_failure();
fail->set_kind(AutotuneResult::WRONG_RESULT);
fail->set_buffer_address(
@ -453,11 +512,11 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
tensorflow::AutotuningLog log;
{
ConvInstructionLog instr_log;
*instr_log.mutable_instruction() = instr->ToProto();
for (int i = 0; i < instr->operand_count(); i++) {
*instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
*instr_log.mutable_instruction() = instr.ToProto();
for (int i = 0; i < instr.operand_count(); i++) {
*instr_log.add_operand_shapes() = instr.operand(i)->shape().ToProto();
instr_log.add_operand_addresses(
reinterpret_cast<uint64>(operand_buffers[i].opaque()));
reinterpret_cast<uint64>((operand_buffers)[i].opaque()));
}
instr_log.set_result_address(
reinterpret_cast<uint64>(result_buffer.opaque()));
@ -523,11 +582,81 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
return InternalError(
"All algorithms tried for convolution %s failed. Falling back to "
"default algorithm.",
instr->ToString());
instr.ToString());
}
StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
HloInstruction* instr) {
// ROCm autotuning path for a single convolution custom call.
//
// Allocates zero-filled buffers for every operand and for the result, then
// runs the convolution once with an empty algorithm override and profiling
// enabled. The algorithm reported back in the profile result is returned
// together with the scratch bytes used and the measured run time.
//
// `allocator` backs both the input/output buffers and the scratch space;
// `stream` is the stream the buffers are zeroed on and the convolution runs
// on. Returns an InternalError if the launch fails or yields no valid profile
// result; allocation errors are propagated as-is.
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
    const HloCustomCallInstruction& instr, se::DeviceMemoryAllocator* allocator,
    se::Stream* stream) {
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr.ToString()));

  const auto device_ordinal = stream_exec_->device_ordinal();
  std::vector<se::DeviceMemoryBase> operand_buffers;

  ScratchAllocator input_output_allocator(device_ordinal, allocator);
  const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
    // Although we don't have evidence this matters, zero out the buffers
    // before autotuning.  It's conceivable that using uninitialized memory as
    // the inputs might affect performance if e.g. the inputs contain
    // denormals, and this is easy enough.
    stream->ThenMemZero(&buffer, buffer.size());
  };

  // Allocate space for the input, filter, and output of the convolution.  We
  // use a ScratchAllocator for this instead of calling allocator_ directly so
  // that our allocations don't leak.
  for (const auto* operand : instr.operands()) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(operand->shape())));
    initialize_buffer(buffer);
    operand_buffers.push_back(buffer);
  }
  TF_ASSIGN_OR_RETURN(
      auto result_buffer,
      input_output_allocator.AllocateBytes(
          ShapeUtil::ByteSizeOf(instr.shape().tuple_shapes(0))));
  initialize_buffer(result_buffer);

  // Separate allocator for scratch space so that its total can be reported
  // on its own in the autotune result.
  ScratchAllocator scratch_allocator(device_ordinal, allocator);
  se::dnn::ProfileResult profile_result;
  VLOG(3) << "Auto-tuning for " << instr.ToString();
  RunConvOptions options;
  options.profile_result = &profile_result;

  // ROCm: Set the overriding algorithm to empty to remind cudnn_conv_runner
  // that the AlgorithmConfig in running convolution needs to be empty
  options.algo_override = se::dnn::AlgorithmDesc();

  // Keep the launch status instead of collapsing it to a bool, so the real
  // failure reason can be reported below.
  const auto launch_status =
      RunCudnnConv(&instr, absl::MakeSpan(operand_buffers), result_buffer,
                   &scratch_allocator, stream, options);

  if (launch_status.ok() && profile_result.is_valid()) {
    AutotuneResult best_result;
    best_result.mutable_conv()->set_algorithm(
        profile_result.algorithm().algo_id());
    best_result.mutable_conv()->set_tensor_ops_enabled(
        profile_result.algorithm().tensor_ops_enabled());
    int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
    best_result.set_scratch_bytes(scratch_bytes_used);
    *best_result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
        absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    return best_result;
  }

  // Surface the underlying launch error; the returned status below is kept
  // unchanged for callers.
  if (!launch_status.ok()) {
    LOG(WARNING) << "Convolution autotuning launch failed for "
                 << instr.ToString() << ": " << launch_status;
  }

  return InternalError(
      "All algorithms tried for convolution %s failed.  Falling back to "
      "default algorithm.",
      instr.ToString());
}
StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) {
CHECK(IsCustomCallToDnnConvolution(*instr));
StatusOr<AutotuneResult> best_algo_or =
@ -577,7 +706,7 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
return true;
}
StatusOr<bool> CudnnConvAlgorithmPicker::RunOnComputation(
StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
HloComputation* computation) {
std::vector<HloInstruction*> convs;
for (auto* instr : computation->instructions()) {
@ -594,11 +723,11 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnComputation(
return changed;
}
StatusOr<bool> CudnnConvAlgorithmPicker::Run(HloModule* module) {
XLA_SCOPED_LOGGING_TIMER("CudnnConvAlgorithmPicker");
StatusOr<bool> GpuConvAlgorithmPicker::Run(HloModule* module) {
XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");
if (module->config().debug_options().xla_gpu_disable_autotune()) {
VLOG(2) << "Convolution auto-tuning disabled, CudnnConvAlgorithmPicker "
VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
"returning early.";
return false;
}

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
#include "absl/time/time.h"
#include "absl/types/optional.h"
@ -32,17 +32,17 @@ namespace gpu {
// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
// each and adding explicit scratch space to the CustomCalls.
class CudnnConvAlgorithmPicker : public HloModulePass {
class GpuConvAlgorithmPicker : public HloModulePass {
public:
// If the `allocator` parameter is not null, we will use it to allocate temp
// memory while timing the various convolution algorithms. If it's null,
// we'll use the default allocator on the StreamExecutor.
CudnnConvAlgorithmPicker(se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* allocator)
GpuConvAlgorithmPicker(se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* allocator)
: stream_exec_(stream_exec), allocator_(allocator) {}
absl::string_view name() const override {
return "cudnn-conv-algorithm-picker";
return "gpu-conv-algorithm-picker";
}
StatusOr<bool> Run(HloModule* module) override;
@ -52,8 +52,14 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithm(
const HloCustomCallInstruction* instr);
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCache(
const HloCustomCallInstruction* instr);
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCacheCuda(
const HloCustomCallInstruction& instr,
se::DeviceMemoryAllocator* allocator, se::Stream* stream);
StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCacheRocm(
const HloCustomCallInstruction& instr,
se::DeviceMemoryAllocator* allocator, se::Stream* stream);
se::StreamExecutor* stream_exec_; // never null
se::DeviceMemoryAllocator* allocator_; // may be null
@ -61,5 +67,4 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_

View File

@ -223,7 +223,16 @@ Status RunGpuConvImpl(const GpuConvParams& params,
auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
AlgorithmConfig algorithm = params.algorithm;
if (options.algo_override) {
// In ROCm mode, the first call to run the convolution needs to trigger the
// code that calls the miopenFind* API. That trigger is implicit: it is based
// on whether or not AlgorithmConfig::algorithm is empty. So for the
// first call we need to ensure that AlgorithmConfig::algorithm is
// empty. For all subsequent calls, we should use the value retrieved from
// the backend_config.
if ((options.algo_override.has_value()) &&
(*options.algo_override == se::dnn::AlgorithmDesc())) {
algorithm = AlgorithmConfig();
} else if (options.algo_override.has_value()) {
algorithm = AlgorithmConfig(*options.algo_override);
}
@ -261,8 +270,13 @@ StatusOr<GpuConvParams> GetGpuConvParams(
const Shape* filter_shape;
const Shape* output_shape;
params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
backend_config.algorithm(), backend_config.tensor_ops_enabled()));
// The third field is scratch size stored from conv_algorithm_picker
// The operand is added to the shape field of the conv instruction
// in GpuConvAlgorithmPicker::RunOnInstruction() call.
params.algorithm = se::dnn::AlgorithmConfig(
se::dnn::AlgorithmDesc(backend_config.algorithm(),
backend_config.tensor_ops_enabled()),
conv->shape().tuple_shapes(1).dimensions(0));
params.conv_result_scale = backend_config.conv_result_scale();
switch (params.kind) {

View File

@ -188,11 +188,11 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
// The new tuple and gte instructions then be simplified away, because
// nobody is expected to use the scratch value.
//
// However, if we were to run CudnnConvAlgorithmPicker after fusion
// However, if we were to run GpuConvAlgorithmPicker after fusion
// the gte(customcall, 0) would probably already be into a fusion node. We
// can't simplify across HloComputation boundaries, so in this case we
// wouldn't be able to simplify away the new_tuple bits.
pipeline.AddPass<CudnnConvAlgorithmPicker>(stream_exec, device_allocator);
pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
// Find the fastest algorithm for GEMMs.
pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);