[XLA] Add a RedzoneAllocator, and use it for checking conv correctness.

Specifically, this checks for out-of-bounds reads and writes in cudnn.

This is a BACKWARDS-INCOMPATIBLE CHANGE to autotuning.proto.  We agreed this is
necessary to clean it up.

Incidentally this cleans up some code in the use of the autotuning proto.  Code
to sort autotuning results always checked whether the run was successful, but
in fact runs were never marked as failing.  (Runs that failed were simply never
entered into the vector of results.)  So I removed the checks for success.

PiperOrigin-RevId: 240903636
This commit is contained in:
Justin Lebar 2019-03-28 19:31:51 -07:00 committed by TensorFlower Gardener
parent d75adcad0d
commit 937bf0a2e6
14 changed files with 573 additions and 165 deletions

View File

@ -466,7 +466,7 @@ cc_library(
":gpu_autotuning_proto",
":gpu_executable",
":ir_emission_utils",
":scratch_allocator",
":redzone_allocator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
@ -497,6 +497,39 @@ cc_library(
],
)
cc_library(
name = "redzone_allocator",
srcs = ["redzone_allocator.cc"],
hdrs = ["redzone_allocator.h"],
deps = [
":gpu_constants",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:stream_executor_no_cuda",
],
)
tf_cc_test(
name = "redzone_allocator_test",
srcs = ["redzone_allocator_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":redzone_allocator",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:kernel",
"//tensorflow/stream_executor/cuda:cuda_activation",
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",
],
)
cc_library(
name = "cudnn_conv_runner",
srcs = ["cudnn_conv_runner.cc"],

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
@ -122,6 +122,56 @@ tensorflow::ComputeCapability GetComputeCapability(
return cc;
}
// Returns true if the redzones in `allocator`'s allocations are unmodified.
//
// If the redzones are modified, logs an error, sets the appropriate failure
// bits on `result`, and returns false.
//
// `name` is a user-friendly name for the set of redzones being checked, e.g.
// "input/output" or "scratch".
bool CheckRedzones(const RedzoneAllocator& allocator, se::Stream* stream,
absl::string_view name, const HloInstruction* instr,
AutotuneResult* result) {
Status status = allocator.CheckRedzones(stream);
if (status.ok()) {
return true;
}
auto* fail = result->mutable_failure();
fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
*fail->mutable_msg() = status.ToString();
LOG(ERROR) << absl::StreamFormat(
"Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
"cudnn bug. We will skip this algorithm in the future, but your GPU "
"state may already be corrupted, leading to incorrect results. Within "
"Google, no action is needed on your part. Outside of Google, please "
"ensure you're running the latest version of cudnn. If that doesn't fix "
"the problem, please file a bug with this full error message and we'll "
"contact nvidia.",
name);
LOG(ERROR) << status.ToString();
LOG(ERROR) << "HloInstruction " << instr->ToString();
auto* se = stream->parent();
const auto& desc = se->GetDeviceDescription();
LOG(ERROR) << "Device: " << desc.name();
LOG(ERROR) << "Platform: " << desc.platform_version();
LOG(ERROR) << "Driver: " << desc.driver_version();
LOG(ERROR) << "Runtime: " << desc.runtime_version();
auto* dnn = se->AsDnn();
if (dnn) {
auto dnn_version = dnn->GetVersion();
if (dnn_version.ok()) {
auto v = dnn_version.ValueOrDie();
LOG(ERROR) << "cudnn version: " << v.major_version() << "."
<< v.minor_version() << "." << v.patch();
}
}
return false;
}
} // anonymous namespace
// We could have caching here so that we don't redo this work for two identical
@ -208,10 +258,8 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
}
};
// Allocate space for the input, filter, and output of the convolution. We
// use a ScratchAllocator for this instead of calling allocator_ directly so
// that our allocations don't leak.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
// Allocate space for the input, filter, and output of the convolution.
RedzoneAllocator input_output_allocator(device_ordinal, allocator);
std::vector<se::DeviceMemoryBase> operand_buffers;
for (const auto* operand : instr->operands()) {
TF_ASSIGN_OR_RETURN(auto buffer,
@ -245,7 +293,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
.xla_gpu_crash_on_verification_failures();
for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
RedzoneAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
@ -271,12 +319,19 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
result.mutable_conv()->set_algorithm(alg.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
result.mutable_success()->set_scratch_bytes(scratch_bytes_used);
*result.mutable_success()->mutable_run_time() =
tensorflow::proto_utils::ToDurationProto(
int64 scratch_bytes_used =
scratch_allocator.TotalAllocatedBytesExcludingRedzones();
result.set_scratch_bytes(scratch_bytes_used);
*result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
// Check for writes to redzones.
if (!CheckRedzones(input_output_allocator, &stream, "input/output", instr,
&result) ||
!CheckRedzones(scratch_allocator, &stream, "scratch", instr, &result)) {
continue;
}
if (comparator.has_value()) {
StatusOr<bool> compare_result = comparator->CompareEqual(
&stream, allocator, reference_result_buffer, result_buffer);
@ -291,15 +346,18 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
}
CHECK(!crash_on_checking_failure);
} else if (!compare_result.ValueOrDie()) {
LOG(ERROR) << "Results mismatch between different convolution "
"algorithms. This is likely a bug in convolution, or "
"an excessive loss of precision in convolution. "
LOG(ERROR)
<< "Results mismatch between different convolution algorithms. "
"This is likely a bug/unexpected loss of precision in cudnn.\n"
<< instr->ToString() << " for "
<< AlgorithmToString(first_algorithm) << " vs "
<< AlgorithmToString(alg);
auto* failure = result.mutable_reference_conv();
failure->set_algorithm(first_algorithm.algo_id());
failure->set_tensor_ops_enabled(first_algorithm.tensor_ops_enabled());
auto* fail = result.mutable_failure();
fail->set_kind(AutotuneResult::WRONG_RESULT);
auto* reference_conv = fail->mutable_reference_conv();
reference_conv->set_algorithm(first_algorithm.algo_id());
reference_conv->set_tensor_ops_enabled(
first_algorithm.tensor_ops_enabled());
}
} else {
auto comp =
@ -336,35 +394,46 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
}
*log.mutable_compute_capability() = GetComputeCapability(stream_exec_);
*log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
log.set_device_pci_bus_id(
stream_exec_->GetDeviceDescription().pci_bus_id());
VLOG(2) << "Autotuning result:\n" << log.DebugString();
tensorflow::Logger::Singleton()->LogProto(log);
}
// Crash on miscompares and redzone violations if desired. Do this after
// logging the autotuning results, otherwise we won't get any data!
for (const auto& result : profile_results) {
if (result.has_reference_conv()) {
if (result.has_failure()) {
CHECK(!crash_on_checking_failure);
}
}
auto* profile_results_end = profile_results.data() + profile_results.size();
const AutotuneResult* best_result = std::min_element(
profile_results.data(), profile_results_end,
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
// Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
// error.
//
// For now, we ignore WRONG_RESULT failures because false-positives are
// possible (e.g. perhaps the reference algorithm is the one that's
// incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
// quite severe and can be detected with high accuracy.
//
// TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
// in the output of the conv and skip those.
//
// The successful one should have a smaller key, since we are doing
// min_element. If they are both unsuccessful, keep the earlier one in
// the vector by comparing pointers.
return std::make_tuple(!lhs.has_success(),
tensorflow::proto_utils::FromDurationProto(
lhs.success().run_time()),
&lhs) <
std::make_tuple(!rhs.has_success(),
tensorflow::proto_utils::FromDurationProto(
rhs.success().run_time()),
&rhs);
auto result_comparison_key = [](const AutotuneResult& r) {
return std::make_tuple(
r.has_failure() && r.failure().kind() != AutotuneResult::WRONG_RESULT,
tensorflow::proto_utils::FromDurationProto(r.run_time()));
};
const auto& best_result = absl::c_min_element(
profile_results,
[&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return result_comparison_key(lhs) < result_comparison_key(rhs);
});
if (best_result != profile_results_end && best_result->has_success()) {
if (best_result != profile_results.end() && !best_result->has_failure()) {
return *best_result;
}
@ -388,7 +457,7 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
auto best_algo = std::move(best_algo_or).ValueOrDie();
VLOG(1) << "Setting cudnn conv to use algorithm "
<< best_algo.conv().algorithm() << " and "
<< NumBytesToString(best_algo.success().scratch_bytes())
<< NumBytesToString(best_algo.scratch_bytes())
<< " of scratch memory: " << instr->ToString()
<< " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();
@ -397,7 +466,7 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
HloComputation* computation = instr->parent();
Shape new_call_shape = ShapeUtil::MakeTupleShape(
{instr->shape().tuple_shapes(0),
ShapeUtil::MakeShape(U8, {best_algo.success().scratch_bytes()})});
ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()})});
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
instr->backend_config<CudnnConvBackendConfig>());

View File

@ -0,0 +1,126 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
namespace gpu {
// The size of the redzone at the end of the user buffer is rounded up to a
// multiple of kRhsRedzoneAlign. This simplifies the implementation a bit.
constexpr int64 kRhsRedzoneAlign = 4;
StatusOr<se::DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
se::Stream* stream, int64 byte_size) {
CHECK_GE(byte_size, 0) << "byte_size must be positive.";
if (byte_size > GetMemoryLimitInBytes(stream)) {
return se::port::Status(
se::port::error::RESOURCE_EXHAUSTED,
absl::StrFormat(
"Allocating %d bytes exceeds the memory limit of %d bytes.",
byte_size, GetMemoryLimitInBytes(stream)));
}
int64 rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size;
TF_ASSIGN_OR_RETURN(
OwningDeviceMemory allocated_buffer,
memory_allocator_->Allocate(device_ordinal_,
byte_size + 2 * redzone_size_ + rhs_slop,
/*retry_on_failure=*/false));
allocated_bytes_excluding_redzones_ += byte_size;
char* addr =
reinterpret_cast<char*>(allocated_buffer.AsDeviceMemoryBase().opaque());
se::DeviceMemoryBase lhs_redzone(addr, redzone_size_,
/*is_sub_buffer=*/true);
// Split up the RHS redzone into two pieces:
// - 0 to kRhsRedzoneAlign bytes adjacent to the user buffer, followed by
// - redzone_size_ bytes.
// We do this because Stream::ThenMemset32 requires the buffer address and
// size to be aligned to 4 bytes.
se::DeviceMemoryBase rhs_redzone_slop(addr + redzone_size_ + byte_size,
rhs_slop, /*is_sub_buffer=*/true);
se::DeviceMemoryBase rhs_redzone_nonslop(
addr + redzone_size_ + byte_size + rhs_slop, redzone_size_,
/*is_sub_buffer=*/true);
uint8 pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_,
redzone_pattern_};
uint32 pattern32;
std::memcpy(&pattern32, pattern_arr, sizeof(pattern32));
stream->ThenMemset32(&lhs_redzone, pattern32, redzone_size_);
if (rhs_slop != 0) {
stream->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop);
}
stream->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_);
allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size);
return se::DeviceMemory<uint8>(se::DeviceMemoryBase(
addr + redzone_size_, byte_size, /*is_sub_buffer=*/true));
}
Status RedzoneAllocator::CheckRedzones(se::Stream* stream) const {
for (const auto& buf_and_size : allocated_buffers_) {
const auto& allocated_buf = buf_and_size.first;
int64 user_alloc_size = buf_and_size.second;
char* addr =
reinterpret_cast<char*>(allocated_buf.AsDeviceMemoryBase().opaque());
// user_alloc_size isn't necessarily the same as
// allocated_buf.size() - 2 * redzone_size_ because if user_alloc_size was
// not a multiple of kRhsRedzoneAlign, we rounded it up.
se::DeviceMemoryBase buf(addr + redzone_size_, user_alloc_size,
/*is_sub_buffer=*/true);
TF_RETURN_IF_ERROR(CheckBufferRedzones(buf, stream));
}
return Status::OK();
}
Status RedzoneAllocator::CheckBufferRedzones(se::DeviceMemoryBase buf,
se::Stream* stream) const {
char* buf_start = reinterpret_cast<char*>(buf.opaque());
auto check_redzone = [&](int64 offset, int64 size, absl::string_view name) {
se::DeviceMemoryBase redzone(buf_start + offset, size,
/*is_sub_buffer=*/true);
auto redzone_data = absl::make_unique<uint8[]>(size);
TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size)
.BlockHostUntilDone());
for (int64 i = 0; i < size; ++i) {
uint8 rz_value = redzone_data[i];
if (rz_value != redzone_pattern_) {
return InternalError(
"Redzone mismatch in %s redzone of buffer %p at offset %d; "
"expected %08x but was %08x.",
name, buf.opaque(), i, redzone_pattern_, rz_value);
}
}
return Status::OK();
};
// `buf` points to the buffer returned to the user, so the LHS redzone starts
// before `buf`.
TF_RETURN_IF_ERROR(check_redzone(-redzone_size_, redzone_size_, "LHS"));
int64 rhs_slop =
RoundUpToNearest<int64>(buf.size(), kRhsRedzoneAlign) - buf.size();
TF_RETURN_IF_ERROR(
check_redzone(buf.size(), redzone_size_ + rhs_slop, "RHS"));
return Status::OK();
}
} // namespace gpu
} // namespace xla

View File

@ -0,0 +1,95 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_
#include <vector>
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
// An allocator that allocates a bit of extra memory around the beginning/end of
// every allocation and can check that this memory is unmodified.
//
// This can be used to check for out-of-bounds writes, and, if the redzone is
// filled with a sufficiently "ugly" pattern, may also be able to check for
// out-of-bounds reads. The default fill pattern of -1 is an unusual NaN
// pattern when interpreted as a floating-point number, so hopefully works for
// out-of-bounds reads and writes in those cases.
//
// This class implements se::ScratchAllocator, so can be used to allocate temp
// memory for cudnn convolutions.
class RedzoneAllocator : public se::ScratchAllocator {
public:
RedzoneAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator,
int64 redzone_size = 1 << 23, // 8MiB per side, 16MiB total
uint8 redzone_pattern = -1)
: device_ordinal_(device_ordinal),
redzone_size_(
RoundUpToNearest(redzone_size, kXlaAllocatedBufferAlignBytes)),
redzone_pattern_(redzone_pattern),
memory_allocator_(memory_allocator) {}
// Redzones don't count towards the memory limit.
int64 GetMemoryLimitInBytes(se::Stream* stream) override {
return 1LL << 32; // 4GB. TODO(jlebar): Tune this?
}
int64 TotalAllocatedBytesExcludingRedzones() const {
return allocated_bytes_excluding_redzones_;
}
StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
int64 byte_size) override;
// Determines whether redzones around all allocated buffers are unmodified.
Status CheckRedzones(se::Stream* stream) const;
private:
// Checks that one buffer's redzones are unmodified. buf should point to the
// user-editable buffer, i.e. it should not include redzones.
Status CheckBufferRedzones(se::DeviceMemoryBase buf,
se::Stream* stream) const;
const int device_ordinal_;
// Redzone size on *one side* of allocation.
//
// Must be a multiple of kXlaAllocatedBufferAlignBytes, otherwise the buffers
// returned to users will be misaligned.
const int64 redzone_size_;
const uint8 redzone_pattern_;
DeviceMemoryAllocator* memory_allocator_;
// The second element of the pair is the size of the user allocation. This
// isn't necessarily just first.size() - 2 * redzone_size_ because when the
// user allocation size is not a multiple of 4 bytes, we round up the size of
// the RHS redzone.
std::vector<std::pair<OwningDeviceMemory, int64>> allocated_buffers_;
int64 allocated_bytes_excluding_redzones_ = 0;
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_

View File

@ -0,0 +1,107 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
namespace gpu {
namespace {
TEST(RedzoneAllocatorTest, WriteToRedzone) {
constexpr int64 kRedzoneSize = 1 << 23; // 8MiB redzone on each side
// Redzone pattern should not be equal to zero; otherwise modify_redzone will
// break.
constexpr uint8 kRedzonePattern = 0x7e;
// Allocate 32MiB + 1 byte (to make things misaligned)
constexpr int64 kAllocSize = (1 << 25) + 1;
se::Platform* platform =
se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
se::StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie();
StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec});
RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, kRedzoneSize,
kRedzonePattern);
se::Stream stream(stream_exec);
stream.Init();
TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemory<uint8> buf,
allocator.AllocateBytes(&stream,
/*byte_size=*/kAllocSize));
TF_EXPECT_OK(allocator.CheckRedzones(&stream));
char* buf_addr = reinterpret_cast<char*>(buf.opaque());
se::DeviceMemoryBase lhs_redzone(buf_addr - kRedzoneSize, kRedzoneSize,
/*is_sub_buffer=*/true);
se::DeviceMemoryBase rhs_redzone(buf_addr + kAllocSize, kRedzoneSize,
/*is_sub_buffer=*/true);
// Check that the redzones are in fact filled with kRedzonePattern.
auto check_redzone = [&](se::DeviceMemoryBase redzone,
absl::string_view name) {
std::vector<uint8> host_buf(kRedzoneSize);
TF_ASSERT_OK(stream.ThenMemcpy(host_buf.data(), redzone, kRedzoneSize)
.BlockHostUntilDone());
const int64 kMaxMismatches = 16;
int64 mismatches = 0;
for (int64 i = 0; i < host_buf.size(); ++i) {
if (mismatches == kMaxMismatches) {
ADD_FAILURE() << "Hit max number of mismatches; skipping others.";
break;
}
if (host_buf[i] != kRedzonePattern) {
++mismatches;
EXPECT_EQ(host_buf[i], kRedzonePattern)
<< "at index " << i << " of " << name << " redzone";
}
}
};
check_redzone(lhs_redzone, "lhs");
check_redzone(rhs_redzone, "rhs");
// Modifies a redzone, checks that RedzonesAreUnmodified returns false, then
// reverts it back to its original value and checks that RedzonesAreUnmodified
// returns true.
auto modify_redzone = [&](se::DeviceMemoryBase redzone, int64 offset,
absl::string_view name) {
SCOPED_TRACE(absl::StrCat(name, ", offset=", offset));
se::DeviceMemoryBase redzone_at_offset(
reinterpret_cast<char*>(redzone.opaque()) + offset, 1,
/*is_sub_buffer=*/true);
char old_redzone_value = 0;
TF_EXPECT_OK(allocator.CheckRedzones(&stream));
stream.ThenMemcpy(&old_redzone_value, redzone_at_offset, 1)
.ThenMemZero(&redzone_at_offset, 1);
EXPECT_FALSE(allocator.CheckRedzones(&stream).ok());
stream.ThenMemcpy(&redzone_at_offset, &old_redzone_value, 1);
TF_EXPECT_OK(allocator.CheckRedzones(&stream));
};
modify_redzone(lhs_redzone, /*offset=*/0, "lhs");
modify_redzone(lhs_redzone, /*offset=*/kRedzoneSize - 1, "lhs");
modify_redzone(rhs_redzone, /*offset=*/0, "rhs");
modify_redzone(rhs_redzone, /*offset=*/kRedzoneSize - 1, "rhs");
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -100,8 +100,15 @@ class OwningDeviceMemory {
// !is_null() is sufficient but not necessary to imply `this` is active.
bool is_null() const { return mem_.is_null(); }
se::DeviceMemoryBase AsDeviceMemoryBase() {
return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false);
se::DeviceMemoryBase AsDeviceMemoryBase() const {
// This const_cast is necessary because DeviceMemoryBase's constructor
// doesn't accept a const void*. This isn't ideal, but it's better than the
// alternative of making a AsDeviceMemoryBase non-const member function.
//
// This is safe (i.e. not UB) because the casted pointer is derived from a
// non-const pointer, namely mem_.opaque().
return se::DeviceMemoryBase(const_cast<void*>(opaque()), size(),
/*is_sub_buffer=*/false);
}
// Returns the wrapped DeviceMemoryBase without freeing it, and deactivates

View File

@ -334,6 +334,7 @@ void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
*log.mutable_cudnn_version() = internal::GetCudnnVersion(stream_exec);
*log.mutable_compute_capability() =
internal::GetComputeCapability(stream_exec);
log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
for (const auto& result : results) {
*log.add_results() = result;
}
@ -342,38 +343,29 @@ void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
se::dnn::AlgorithmConfig* algo) {
// For the "!xhs.has_success()" below, this is because we want successful ones
// to order first, therefore they need a smaller key per "min_element".
const AutotuneResult* best_result = std::min_element(
results.begin(), results.end(),
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return std::make_tuple(
!lhs.has_success(),
internal::FromDurationProto(lhs.success().run_time())) <
std::make_tuple(
!rhs.has_success(),
internal::FromDurationProto(rhs.success().run_time()));
return internal::FromDurationProto(lhs.run_time()) <
internal::FromDurationProto(rhs.run_time());
});
const AutotuneResult* best_result_no_scratch = std::min_element(
results.begin(), results.end(),
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return std::make_tuple(
!lhs.has_success(), lhs.success().scratch_bytes(),
internal::FromDurationProto(lhs.success().run_time())) <
std::make_tuple(
!rhs.has_success(), rhs.success().scratch_bytes(),
internal::FromDurationProto(rhs.success().run_time()));
return std::make_tuple(lhs.scratch_bytes(),
internal::FromDurationProto(lhs.run_time())) <
std::make_tuple(rhs.scratch_bytes(),
internal::FromDurationProto(rhs.run_time()));
});
if (best_result == results.end() || !best_result->has_success()) {
if (best_result == results.end()) {
return errors::NotFound("No algorithm worked!");
}
algo->set_algorithm({best_result->conv().algorithm(),
best_result->conv().tensor_ops_enabled()});
if (best_result_no_scratch != results.end() &&
best_result_no_scratch->has_success() &&
best_result_no_scratch->success().scratch_bytes() == 0) {
best_result_no_scratch->scratch_bytes() == 0) {
algo->set_algorithm_no_scratch(
{best_result_no_scratch->conv().algorithm(),
best_result_no_scratch->conv().tensor_ops_enabled()});
@ -726,21 +718,17 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
output_desc, &output_ptr, &scratch_allocator,
dnn::AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
if (cudnn_launch_status && profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.mutable_success()->set_scratch_bytes(
scratch_allocator.TotalByteSize());
*result.mutable_success()->mutable_run_time() =
internal::ToDurationProto(
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = internal::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
}
internal::LogFusedConvAutotuneResults(ctx->op_kernel().def(), *conv_input,
*filter, *output, bias, side_input,
stream->parent(), results);

View File

@ -858,21 +858,17 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
&scratch_allocator, AlgorithmConfig(profile_algorithm),
&profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
if (cudnn_launch_status && profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.mutable_success()->set_scratch_bytes(
scratch_allocator.TotalByteSize());
*result.mutable_success()->mutable_run_time() =
proto_utils::ToDurationProto(
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
}
LogConvAutotuneResults(ctx->op_kernel().def(), transformed_input,
pre_transformed_filter_backprop,
transformed_out_backprop, stream->parent(), results);

View File

@ -969,21 +969,17 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
if (cudnn_launch_status && profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.mutable_success()->set_scratch_bytes(
scratch_allocator.TotalByteSize());
*result.mutable_success()->mutable_run_time() =
proto_utils::ToDurationProto(
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
}
LogConvAutotuneResults(ctx->op_kernel().def(), pre_transformed_in_backprop,
transformed_filter, transformed_out_backprop,
stream->parent(), results);

View File

@ -870,21 +870,17 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
output_desc, &output_ptr, &scratch_allocator,
AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
if (cudnn_launch_status && profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.mutable_success()->set_scratch_bytes(
scratch_allocator.TotalByteSize());
*result.mutable_success()->mutable_run_time() =
proto_utils::ToDurationProto(
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
}
LogConvAutotuneResults(ctx->op_kernel().def(), input, transformed_filter,
transformed_output, stream->parent(), results);
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));

View File

@ -467,10 +467,8 @@ struct LaunchConvOp<GPUDevice, T> {
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.mutable_success()->set_scratch_bytes(
scratch_allocator.TotalByteSize());
*result.mutable_success()->mutable_run_time() =
proto_utils::ToDurationProto(
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}

View File

@ -536,10 +536,8 @@ Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.mutable_success()->set_scratch_bytes(
scratch_allocator.TotalByteSize());
*result.mutable_success()->mutable_run_time() =
proto_utils::ToDurationProto(
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}

View File

@ -71,6 +71,7 @@ void LogConvAutotuneResults(const NodeDef& node, const Tensor& input,
log.mutable_instr()->PackFrom(std::move(instr));
*log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
*log.mutable_compute_capability() = GetComputeCapability(stream_exec);
log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
for (const auto& result : results) {
*log.add_results() = result;
}
@ -101,6 +102,7 @@ void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
log.mutable_instr()->PackFrom(std::move(instr));
*log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
*log.mutable_compute_capability() = GetComputeCapability(stream_exec);
log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
for (const auto& result : results) {
*log.add_results() = result;
}
@ -109,38 +111,32 @@ void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
se::dnn::AlgorithmConfig* algo) {
// For the "!xhs.has_success()" below, this is because we want successful ones
// to order first, therefore they need a smaller key per "min_element".
// TODO(jlebar): Exclude conv ops with failures, once we have failure checking
// and have confidence that it's correct.
const AutotuneResult* best_result = std::min_element(
results.begin(), results.end(),
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return std::make_tuple(
!lhs.has_success(),
proto_utils::FromDurationProto(lhs.success().run_time())) <
std::make_tuple(
!rhs.has_success(),
proto_utils::FromDurationProto(rhs.success().run_time()));
return proto_utils::FromDurationProto(lhs.run_time()) <
proto_utils::FromDurationProto(rhs.run_time());
});
const AutotuneResult* best_result_no_scratch = std::min_element(
results.begin(), results.end(),
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return std::make_tuple(
!lhs.has_success(), lhs.success().scratch_bytes(),
proto_utils::FromDurationProto(lhs.success().run_time())) <
std::make_tuple(
!rhs.has_success(), rhs.success().scratch_bytes(),
proto_utils::FromDurationProto(rhs.success().run_time()));
return std::make_tuple(lhs.scratch_bytes(),
proto_utils::FromDurationProto(lhs.run_time())) <
std::make_tuple(rhs.scratch_bytes(),
proto_utils::FromDurationProto(rhs.run_time()));
});
if (best_result == results.end() || !best_result->has_success()) {
if (best_result == results.end()) {
return errors::NotFound("No algorithm worked!");
}
algo->set_algorithm({best_result->conv().algorithm(),
best_result->conv().tensor_ops_enabled()});
if (best_result_no_scratch != results.end() &&
best_result_no_scratch->has_success() &&
best_result_no_scratch->success().scratch_bytes() == 0) {
best_result_no_scratch->scratch_bytes() == 0) {
algo->set_algorithm_no_scratch(
{best_result_no_scratch->conv().algorithm(),
best_result_no_scratch->conv().tensor_ops_enabled()});

View File

@ -22,9 +22,25 @@ message ComputeCapability {
}
message AutotuneResult {
message SuccessResult {
int64 scratch_bytes = 1;
google.protobuf.Duration run_time = 2;
enum FailureKind {
UNKNOWN = 0;
REDZONE_MODIFIED = 1;
WRONG_RESULT = 2;
}
message FailureResult {
FailureKind kind = 1;
string msg = 2;
// For failure_kind == WRONG_RESULT, this field indicates the reference
// configuration that we compared against.
//
// Note that the reference algorithm isn't always correct. However,
// empirically it's more correct, as it's "algo 0", less fancy than the
// compared one.
oneof key {
ConvKey reference_conv = 11;
}
}
message ConvKey {
@ -32,34 +48,16 @@ message AutotuneResult {
bool tensor_ops_enabled = 2;
}
// If the conv runs successfully, success will be populated with the
// autotuning result. Otherwise, the error message is propagated.
oneof result {
SuccessResult success = 3;
string error_string = 4;
}
int64 scratch_bytes = 8;
google.protobuf.Duration run_time = 9;
FailureResult failure = 7;
oneof key {
ConvKey conv = 5;
}
// Sometimes we run a correctness checker during autotuning. It compares the
// result buffer content between two algorithms, say, "reference" and "test"
// algorithms. The "test" algorithm is the one associated with this
// AutotuneResult.
//
// This field records the reference algorithm used. Notice that naming it
// "reference" doesn't mean it's always correct. However, empirically it's
// more correct, as it's "algo 0", less fancy than the compared one.
//
// Notice that the checker_failure may exist even in the success case.
// This is because the error string in `result` comes from the underlying
// implementation like cuDNN, which isn't aware that it produced an incorrect
// result. And even if the checker detects an incorrect result, we can still
// retrieve scratch_bytes and runtime_ms.
oneof checker_failure {
ConvKey reference_conv = 6;
}
// Next ID: 12
}
message AutotuningLog {
@ -70,4 +68,9 @@ message AutotuningLog {
CudnnVersion cudnn_version = 3;
ComputeCapability compute_capability = 4;
// stream_executor::DeviceDescription::pci_bus_id.
string device_pci_bus_id = 5;
// Next ID: 6
}