diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 033d6a60e12..8dff06345fa 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -466,7 +466,7 @@ cc_library(
         ":gpu_autotuning_proto",
         ":gpu_executable",
         ":ir_emission_utils",
-        ":scratch_allocator",
+        ":redzone_allocator",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service:device_memory_allocator",
@@ -497,6 +497,39 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "redzone_allocator",
+    srcs = ["redzone_allocator.cc"],
+    hdrs = ["redzone_allocator.h"],
+    deps = [
+        ":gpu_constants",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:device_memory_allocator",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
+tf_cc_test(
+    name = "redzone_allocator_test",
+    srcs = ["redzone_allocator_test.cc"],
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":redzone_allocator",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/service:device_memory_allocator",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core:test",
+        "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
+        "//tensorflow/stream_executor:event",
+        "//tensorflow/stream_executor:kernel",
+        "//tensorflow/stream_executor/cuda:cuda_activation",
+        "//tensorflow/stream_executor/cuda:cuda_gpu_executor",
+    ],
+)
+
 cc_library(
     name = "cudnn_conv_runner",
     srcs = ["cudnn_conv_runner.cc"],
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc
index 5a6c6ed08e1..272793629db 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc
@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h"
+#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/logger.h"
@@ -122,6 +122,56 @@ tensorflow::ComputeCapability GetComputeCapability(
   return cc;
 }
 
+// Returns true if the redzones in `allocator`'s allocations are unmodified.
+//
+// If the redzones are modified, logs an error, sets the appropriate failure
+// bits on `result`, and returns false.
+//
+// `name` is a user-friendly name for the set of redzones being checked, e.g.
+// "input/output" or "scratch".
+bool CheckRedzones(const RedzoneAllocator& allocator, se::Stream* stream,
+                   absl::string_view name, const HloInstruction* instr,
+                   AutotuneResult* result) {
+  Status status = allocator.CheckRedzones(stream);
+  if (status.ok()) {
+    return true;
+  }
+
+  auto* fail = result->mutable_failure();
+  fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
+  *fail->mutable_msg() = status.ToString();
+
+  LOG(ERROR) << absl::StreamFormat(
+      "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
+      "cudnn bug. We will skip this algorithm in the future, but your GPU "
+      "state may already be corrupted, leading to incorrect results. Within "
+      "Google, no action is needed on your part. Outside of Google, please "
+      "ensure you're running the latest version of cudnn. If that doesn't fix "
+      "the problem, please file a bug with this full error message and we'll "
+      "contact nvidia.",
+      name);
+  LOG(ERROR) << status.ToString();
+  LOG(ERROR) << "HloInstruction " << instr->ToString();
+
+  auto* se = stream->parent();
+  const auto& desc = se->GetDeviceDescription();
+  LOG(ERROR) << "Device: " << desc.name();
+  LOG(ERROR) << "Platform: " << desc.platform_version();
+  LOG(ERROR) << "Driver: " << desc.driver_version();
+  LOG(ERROR) << "Runtime: " << desc.runtime_version();
+
+  auto* dnn = se->AsDnn();
+  if (dnn) {
+    auto dnn_version = dnn->GetVersion();
+    if (dnn_version.ok()) {
+      auto v = dnn_version.ValueOrDie();
+      LOG(ERROR) << "cudnn version: " << v.major_version() << "."
+                 << v.minor_version() << "." << v.patch();
+    }
+  }
+  return false;
+}
+
 }  // anonymous namespace
 
 // We could have caching here so that we don't redo this work for two identical
@@ -208,10 +258,8 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
     }
   };
 
-  // Allocate space for the input, filter, and output of the convolution. We
-  // use a ScratchAllocator for this instead of calling allocator_ directly so
-  // that our allocations don't leak.
-  ScratchAllocator input_output_allocator(device_ordinal, allocator);
+  // Allocate space for the input, filter, and output of the convolution.
+  RedzoneAllocator input_output_allocator(device_ordinal, allocator);
   std::vector<se::DeviceMemoryBase> operand_buffers;
   for (const auto* operand : instr->operands()) {
     TF_ASSIGN_OR_RETURN(auto buffer,
@@ -245,7 +293,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
           .xla_gpu_crash_on_verification_failures();
 
   for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
-    ScratchAllocator scratch_allocator(device_ordinal, allocator);
+    RedzoneAllocator scratch_allocator(device_ordinal, allocator);
     se::dnn::ProfileResult profile_result;
     VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
             << instr->ToString();
@@ -271,11 +319,18 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
     result.mutable_conv()->set_algorithm(alg.algo_id());
     result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());
 
-    int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
-    result.mutable_success()->set_scratch_bytes(scratch_bytes_used);
-    *result.mutable_success()->mutable_run_time() =
-        tensorflow::proto_utils::ToDurationProto(
-            absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+    int64 scratch_bytes_used =
+        scratch_allocator.TotalAllocatedBytesExcludingRedzones();
+    result.set_scratch_bytes(scratch_bytes_used);
+    *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
+        absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+
+    // Check for writes to redzones.
+    if (!CheckRedzones(input_output_allocator, &stream, "input/output", instr,
+                       &result) ||
+        !CheckRedzones(scratch_allocator, &stream, "scratch", instr,
+                       &result)) {
+      continue;
+    }
 
     if (comparator.has_value()) {
       StatusOr<bool> compare_result = comparator->CompareEqual(
@@ -291,15 +346,18 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
         }
         CHECK(!crash_on_checking_failure);
       } else if (!compare_result.ValueOrDie()) {
-        LOG(ERROR) << "Results mismatch between different convolution "
-                      "algorithms. This is likely a bug in convolution, or "
-                      "an excessive loss of precision in convolution. "
-                   << instr->ToString() << " for "
-                   << AlgorithmToString(first_algorithm) << " vs "
-                   << AlgorithmToString(alg);
-        auto* failure = result.mutable_reference_conv();
-        failure->set_algorithm(first_algorithm.algo_id());
-        failure->set_tensor_ops_enabled(first_algorithm.tensor_ops_enabled());
+        LOG(ERROR)
+            << "Results mismatch between different convolution algorithms. "
+               "This is likely a bug/unexpected loss of precision in cudnn.\n"
+            << instr->ToString() << " for "
+            << AlgorithmToString(first_algorithm) << " vs "
+            << AlgorithmToString(alg);
+        auto* fail = result.mutable_failure();
+        fail->set_kind(AutotuneResult::WRONG_RESULT);
+        auto* reference_conv = fail->mutable_reference_conv();
+        reference_conv->set_algorithm(first_algorithm.algo_id());
+        reference_conv->set_tensor_ops_enabled(
+            first_algorithm.tensor_ops_enabled());
       }
     } else {
       auto comp =
@@ -336,35 +394,46 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
     }
     *log.mutable_compute_capability() = GetComputeCapability(stream_exec_);
     *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
+    log.set_device_pci_bus_id(
+        stream_exec_->GetDeviceDescription().pci_bus_id());
     VLOG(2) << "Autotuning result:\n" << log.DebugString();
     tensorflow::Logger::Singleton()->LogProto(log);
   }
 
+  // Crash on miscompares and redzone violations if desired. Do this after
+  // logging the autotuning results, otherwise we won't get any data!
   for (const auto& result : profile_results) {
-    if (result.has_reference_conv()) {
+    if (result.has_failure()) {
       CHECK(!crash_on_checking_failure);
     }
   }
 
-  auto* profile_results_end = profile_results.data() + profile_results.size();
-
-  const AutotuneResult* best_result = std::min_element(
-      profile_results.data(), profile_results_end,
-      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        // The successful one should have a smaller key, since we are doing
-        // min_element. If they are both unsuccessful, keep the earlier one in
-        // the vector by comparing pointers.
-        return std::make_tuple(!lhs.has_success(),
-                               tensorflow::proto_utils::FromDurationProto(
-                                   lhs.success().run_time()),
-                               &lhs) <
-               std::make_tuple(!rhs.has_success(),
-                               tensorflow::proto_utils::FromDurationProto(
-                                   rhs.success().run_time()),
-                               &rhs);
+  // Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
+  // error.
+  //
+  // For now, we ignore WRONG_RESULT failures because false-positives are
+  // possible (e.g. perhaps the reference algorithm is the one that's
+  // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
+  // quite severe and can be detected with high accuracy.
+  //
+  // TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
+  // in the output of the conv and skip those.
+  //
+  // A successful result should have a smaller key, since we are doing
+  // min_element. If several results fail, std::min_element keeps the earliest
+  // one in the vector, since it returns the first of the smallest elements.
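+  //
+  // Concretely (an illustrative reading of the key below, not output from a
+  // real run): a REDZONE_MODIFIED failure keys as (true, run_time) and sorts
+  // after every clean result, while a WRONG_RESULT failure keys as
+  // (false, run_time) and still competes on run time alone.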
+  auto result_comparison_key = [](const AutotuneResult& r) {
+    return std::make_tuple(
+        r.has_failure() && r.failure().kind() != AutotuneResult::WRONG_RESULT,
+        tensorflow::proto_utils::FromDurationProto(r.run_time()));
+  };
+  const auto& best_result = absl::c_min_element(
+      profile_results,
+      [&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
+        return result_comparison_key(lhs) < result_comparison_key(rhs);
       });
 
-  if (best_result != profile_results_end && best_result->has_success()) {
+  if (best_result != profile_results.end() && !best_result->has_failure()) {
     return *best_result;
   }
 
@@ -388,7 +457,7 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
   auto best_algo = std::move(best_algo_or).ValueOrDie();
   VLOG(1) << "Setting cudnn conv to use algorithm "
           << best_algo.conv().algorithm() << " and "
-          << NumBytesToString(best_algo.success().scratch_bytes())
+          << NumBytesToString(best_algo.scratch_bytes())
           << " of scratch memory: " << instr->ToString()
           << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();
 
@@ -397,7 +466,7 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
   HloComputation* computation = instr->parent();
   Shape new_call_shape = ShapeUtil::MakeTupleShape(
       {instr->shape().tuple_shapes(0),
-       ShapeUtil::MakeShape(U8, {best_algo.success().scratch_bytes()})});
+       ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()})});
 
   TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
                       instr->backend_config<CudnnConvBackendConfig>());
diff --git a/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc b/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc
new file mode 100644
index 00000000000..eeb0f3931df
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/redzone_allocator.cc
@@ -0,0 +1,126 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace xla {
+namespace gpu {
+
+// The size of the redzone at the end of the user buffer is rounded up to a
+// multiple of kRhsRedzoneAlign. This simplifies the implementation a bit.
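+//
+// Each user allocation is therefore carved out of one underlying buffer laid
+// out as follows (a sketch of the code below, not a normative spec):
+//
+//   [ lhs redzone | user buffer | rhs slop (0..3 bytes) | rhs redzone ]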
+constexpr int64 kRhsRedzoneAlign = 4;
+
+StatusOr<se::DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
+    se::Stream* stream, int64 byte_size) {
+  CHECK_GE(byte_size, 0) << "byte_size must be nonnegative.";
+  if (byte_size > GetMemoryLimitInBytes(stream)) {
+    return se::port::Status(
+        se::port::error::RESOURCE_EXHAUSTED,
+        absl::StrFormat(
+            "Allocating %d bytes exceeds the memory limit of %d bytes.",
+            byte_size, GetMemoryLimitInBytes(stream)));
+  }
+
+  int64 rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size;
+  TF_ASSIGN_OR_RETURN(
+      OwningDeviceMemory allocated_buffer,
+      memory_allocator_->Allocate(device_ordinal_,
+                                  byte_size + 2 * redzone_size_ + rhs_slop,
+                                  /*retry_on_failure=*/false));
+  allocated_bytes_excluding_redzones_ += byte_size;
+
+  char* addr =
+      reinterpret_cast<char*>(allocated_buffer.AsDeviceMemoryBase().opaque());
+  se::DeviceMemoryBase lhs_redzone(addr, redzone_size_,
+                                   /*is_sub_buffer=*/true);
+
+  // Split up the RHS redzone into two pieces:
+  //  - 0 to kRhsRedzoneAlign bytes adjacent to the user buffer, followed by
+  //  - redzone_size_ bytes.
+  // We do this because Stream::ThenMemset32 requires the buffer address and
+  // size to be aligned to 4 bytes.
+  se::DeviceMemoryBase rhs_redzone_slop(addr + redzone_size_ + byte_size,
+                                        rhs_slop, /*is_sub_buffer=*/true);
+  se::DeviceMemoryBase rhs_redzone_nonslop(
+      addr + redzone_size_ + byte_size + rhs_slop, redzone_size_,
+      /*is_sub_buffer=*/true);
+
+  uint8 pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_,
+                         redzone_pattern_};
+  uint32 pattern32;
+  std::memcpy(&pattern32, pattern_arr, sizeof(pattern32));
+  stream->ThenMemset32(&lhs_redzone, pattern32, redzone_size_);
+  if (rhs_slop != 0) {
+    stream->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop);
+  }
+  stream->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_);
+
+  allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size);
+  return se::DeviceMemory<uint8>(se::DeviceMemoryBase(
+      addr + redzone_size_, byte_size, /*is_sub_buffer=*/true));
+}
+
+Status RedzoneAllocator::CheckRedzones(se::Stream* stream) const {
+  for (const auto& buf_and_size : allocated_buffers_) {
+    const auto& allocated_buf = buf_and_size.first;
+    int64 user_alloc_size = buf_and_size.second;
+    char* addr =
+        reinterpret_cast<char*>(allocated_buf.AsDeviceMemoryBase().opaque());
+    // user_alloc_size isn't necessarily the same as
+    // allocated_buf.size() - 2 * redzone_size_ because if user_alloc_size was
+    // not a multiple of kRhsRedzoneAlign, we rounded it up.
+    se::DeviceMemoryBase buf(addr + redzone_size_, user_alloc_size,
+                             /*is_sub_buffer=*/true);
+    TF_RETURN_IF_ERROR(CheckBufferRedzones(buf, stream));
+  }
+  return Status::OK();
+}
+
+Status RedzoneAllocator::CheckBufferRedzones(se::DeviceMemoryBase buf,
+                                             se::Stream* stream) const {
+  char* buf_start = reinterpret_cast<char*>(buf.opaque());
+  auto check_redzone = [&](int64 offset, int64 size, absl::string_view name) {
+    se::DeviceMemoryBase redzone(buf_start + offset, size,
+                                 /*is_sub_buffer=*/true);
+    auto redzone_data = absl::make_unique<uint8[]>(size);
+    TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size)
+                           .BlockHostUntilDone());
+    for (int64 i = 0; i < size; ++i) {
+      uint8 rz_value = redzone_data[i];
+      if (rz_value != redzone_pattern_) {
+        return InternalError(
+            "Redzone mismatch in %s redzone of buffer %p at offset %d; "
+            "expected %08x but was %08x.",
+            name, buf.opaque(), i, redzone_pattern_, rz_value);
+      }
+    }
+    return Status::OK();
+  };
+
+  // `buf` points to the buffer returned to the user, so the LHS redzone starts
+  // before `buf`.
+  TF_RETURN_IF_ERROR(check_redzone(-redzone_size_, redzone_size_, "LHS"));
+
+  int64 rhs_slop =
+      RoundUpToNearest<int64>(buf.size(), kRhsRedzoneAlign) - buf.size();
+  TF_RETURN_IF_ERROR(
+      check_redzone(buf.size(), redzone_size_ + rhs_slop, "RHS"));
+
+  return Status::OK();
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/redzone_allocator.h b/tensorflow/compiler/xla/service/gpu/redzone_allocator.h
new file mode 100644
index 00000000000..d8b438c399e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/redzone_allocator.h
@@ -0,0 +1,95 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
+#include "tensorflow/compiler/xla/service/owning_device_memory.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// An allocator that allocates a bit of extra memory around the beginning/end
+// of every allocation and can check that this memory is unmodified.
+//
+// This can be used to check for out-of-bounds writes, and, if the redzone is
+// filled with a sufficiently "ugly" pattern, may also be able to check for
+// out-of-bounds reads. The default fill pattern of -1 is an unusual NaN
+// pattern when interpreted as a floating-point number, so hopefully works for
+// out-of-bounds reads and writes in those cases.
+//
+// This class implements se::ScratchAllocator, so can be used to allocate temp
+// memory for cudnn convolutions.
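+//
+// A minimal usage sketch (assumes a live `stream` and `allocator`; names here
+// are illustrative, not part of this change's tests):
+//
+//   RedzoneAllocator rz(/*device_ordinal=*/0, allocator);
+//   TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> buf,
+//                       rz.AllocateBytes(stream, byte_size));
+//   ... run kernels that may write out of bounds of `buf` ...
+//   TF_RETURN_IF_ERROR(rz.CheckRedzones(stream));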
+class RedzoneAllocator : public se::ScratchAllocator {
+ public:
+  RedzoneAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator,
+                   int64 redzone_size = 1 << 23,  // 8MiB per side, 16MiB total
+                   uint8 redzone_pattern = -1)
+      : device_ordinal_(device_ordinal),
+        redzone_size_(
+            RoundUpToNearest(redzone_size, kXlaAllocatedBufferAlignBytes)),
+        redzone_pattern_(redzone_pattern),
+        memory_allocator_(memory_allocator) {}
+
+  // Redzones don't count towards the memory limit.
+  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
+    return 1LL << 32;  // 4GB. TODO(jlebar): Tune this?
+  }
+
+  int64 TotalAllocatedBytesExcludingRedzones() const {
+    return allocated_bytes_excluding_redzones_;
+  }
+
+  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(se::Stream* stream,
+                                                  int64 byte_size) override;
+
+  // Determines whether redzones around all allocated buffers are unmodified.
+  Status CheckRedzones(se::Stream* stream) const;
+
+ private:
+  // Checks that one buffer's redzones are unmodified. `buf` should point to
+  // the user-editable buffer, i.e. it should not include redzones.
+  Status CheckBufferRedzones(se::DeviceMemoryBase buf,
+                             se::Stream* stream) const;
+
+  const int device_ordinal_;
+
+  // Redzone size on *one side* of allocation.
+  //
+  // Must be a multiple of kXlaAllocatedBufferAlignBytes, otherwise the buffers
+  // returned to users will be misaligned.
+  const int64 redzone_size_;
+
+  const uint8 redzone_pattern_;
+  DeviceMemoryAllocator* memory_allocator_;
+
+  // The second element of the pair is the size of the user allocation. This
+  // isn't necessarily just first.size() - 2 * redzone_size_ because when the
+  // user allocation size is not a multiple of 4 bytes, we round up the size of
+  // the RHS redzone.
+  std::vector<std::pair<OwningDeviceMemory, int64>> allocated_buffers_;
+
+  int64 allocated_bytes_excluding_redzones_ = 0;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDZONE_ALLOCATOR_H_
diff --git a/tensorflow/compiler/xla/service/gpu/redzone_allocator_test.cc b/tensorflow/compiler/xla/service/gpu/redzone_allocator_test.cc
new file mode 100644
index 00000000000..4ebe000e6be
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/redzone_allocator_test.cc
@@ -0,0 +1,107 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/redzone_allocator.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/platform.h" + +namespace xla { +namespace gpu { +namespace { + +TEST(RedzoneAllocatorTest, WriteToRedzone) { + constexpr int64 kRedzoneSize = 1 << 23; // 8MiB redzone on each side + // Redzone pattern should not be equal to zero; otherwise modify_redzone will + // break. + constexpr uint8 kRedzonePattern = 0x7e; + + // Allocate 32MiB + 1 byte (to make things misaligned) + constexpr int64 kAllocSize = (1 << 25) + 1; + + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); + se::StreamExecutor* stream_exec = platform->ExecutorForDevice(0).ValueOrDie(); + StreamExecutorMemoryAllocator se_allocator(platform, {stream_exec}); + RedzoneAllocator allocator(/*device_ordinal=*/0, &se_allocator, kRedzoneSize, + kRedzonePattern); + + se::Stream stream(stream_exec); + stream.Init(); + TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemory buf, + allocator.AllocateBytes(&stream, + /*byte_size=*/kAllocSize)); + TF_EXPECT_OK(allocator.CheckRedzones(&stream)); + + char* buf_addr = reinterpret_cast(buf.opaque()); + se::DeviceMemoryBase lhs_redzone(buf_addr - kRedzoneSize, kRedzoneSize, + /*is_sub_buffer=*/true); + se::DeviceMemoryBase rhs_redzone(buf_addr + kAllocSize, kRedzoneSize, + /*is_sub_buffer=*/true); + + // Check that the redzones are in fact filled with kRedzonePattern. + auto check_redzone = [&](se::DeviceMemoryBase redzone, + absl::string_view name) { + std::vector host_buf(kRedzoneSize); + TF_ASSERT_OK(stream.ThenMemcpy(host_buf.data(), redzone, kRedzoneSize) + .BlockHostUntilDone()); + const int64 kMaxMismatches = 16; + int64 mismatches = 0; + for (int64 i = 0; i < host_buf.size(); ++i) { + if (mismatches == kMaxMismatches) { + ADD_FAILURE() << "Hit max number of mismatches; skipping others."; + break; + } + if (host_buf[i] != kRedzonePattern) { + ++mismatches; + EXPECT_EQ(host_buf[i], kRedzonePattern) + << "at index " << i << " of " << name << " redzone"; + } + } + }; + check_redzone(lhs_redzone, "lhs"); + check_redzone(rhs_redzone, "rhs"); + + // Modifies a redzone, checks that RedzonesAreUnmodified returns false, then + // reverts it back to its original value and checks that RedzonesAreUnmodified + // returns true. 
+  auto modify_redzone = [&](se::DeviceMemoryBase redzone, int64 offset,
+                            absl::string_view name) {
+    SCOPED_TRACE(absl::StrCat(name, ", offset=", offset));
+    se::DeviceMemoryBase redzone_at_offset(
+        reinterpret_cast<char*>(redzone.opaque()) + offset, 1,
+        /*is_sub_buffer=*/true);
+    char old_redzone_value = 0;
+    TF_EXPECT_OK(allocator.CheckRedzones(&stream));
+    stream.ThenMemcpy(&old_redzone_value, redzone_at_offset, 1)
+        .ThenMemZero(&redzone_at_offset, 1);
+    EXPECT_FALSE(allocator.CheckRedzones(&stream).ok());
+    stream.ThenMemcpy(&redzone_at_offset, &old_redzone_value, 1);
+    TF_EXPECT_OK(allocator.CheckRedzones(&stream));
+  };
+
+  modify_redzone(lhs_redzone, /*offset=*/0, "lhs");
+  modify_redzone(lhs_redzone, /*offset=*/kRedzoneSize - 1, "lhs");
+  modify_redzone(rhs_redzone, /*offset=*/0, "rhs");
+  modify_redzone(rhs_redzone, /*offset=*/kRedzoneSize - 1, "rhs");
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h
index 9cf071f0d9d..9e630486169 100644
--- a/tensorflow/compiler/xla/service/owning_device_memory.h
+++ b/tensorflow/compiler/xla/service/owning_device_memory.h
@@ -100,8 +100,15 @@ class OwningDeviceMemory {
   // !is_null() is sufficient but not necessary to imply `this` is active.
   bool is_null() const { return mem_.is_null(); }
 
-  se::DeviceMemoryBase AsDeviceMemoryBase() {
-    return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false);
+  se::DeviceMemoryBase AsDeviceMemoryBase() const {
+    // This const_cast is necessary because DeviceMemoryBase's constructor
+    // doesn't accept a const void*. This isn't ideal, but it's better than
+    // the alternative of making AsDeviceMemoryBase a non-const member
+    // function.
+    //
+    // This is safe (i.e. not UB) because the casted pointer is derived from a
+    // non-const pointer, namely mem_.opaque().
+    return se::DeviceMemoryBase(const_cast<void*>(opaque()), size(),
+                                /*is_sub_buffer=*/false);
   }
 
   // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 0b3e3710721..bd901062ce5 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -334,6 +334,7 @@ void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
   *log.mutable_cudnn_version() = internal::GetCudnnVersion(stream_exec);
   *log.mutable_compute_capability() =
       internal::GetComputeCapability(stream_exec);
+  log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
   for (const auto& result : results) {
     *log.add_results() = result;
   }
@@ -342,38 +343,29 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
                               se::dnn::AlgorithmConfig* algo) {
-  // For the "!xhs.has_success()" below, this is because we want successful ones
-  // to order first, therefore they need a smaller key per "min_element".
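+  // With the success/error oneof gone, every result recorded here competes;
+  // the fastest algorithm is simply the one with the smallest run_time.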
   const AutotuneResult* best_result = std::min_element(
       results.begin(), results.end(),
       [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return std::make_tuple(
-                   !lhs.has_success(),
-                   internal::FromDurationProto(lhs.success().run_time())) <
-               std::make_tuple(
-                   !rhs.has_success(),
-                   internal::FromDurationProto(rhs.success().run_time()));
+        return internal::FromDurationProto(lhs.run_time()) <
+               internal::FromDurationProto(rhs.run_time());
       });
 
   const AutotuneResult* best_result_no_scratch = std::min_element(
       results.begin(), results.end(),
       [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return std::make_tuple(
-                   !lhs.has_success(), lhs.success().scratch_bytes(),
-                   internal::FromDurationProto(lhs.success().run_time())) <
-               std::make_tuple(
-                   !rhs.has_success(), rhs.success().scratch_bytes(),
-                   internal::FromDurationProto(rhs.success().run_time()));
+        return std::make_tuple(lhs.scratch_bytes(),
+                               internal::FromDurationProto(lhs.run_time())) <
+               std::make_tuple(rhs.scratch_bytes(),
+                               internal::FromDurationProto(rhs.run_time()));
       });
 
-  if (best_result == results.end() || !best_result->has_success()) {
+  if (best_result == results.end()) {
     return errors::NotFound("No algorithm worked!");
   }
   algo->set_algorithm({best_result->conv().algorithm(),
                        best_result->conv().tensor_ops_enabled()});
   if (best_result_no_scratch != results.end() &&
-      best_result_no_scratch->has_success() &&
-      best_result_no_scratch->success().scratch_bytes() == 0) {
+      best_result_no_scratch->scratch_bytes() == 0) {
     algo->set_algorithm_no_scratch(
         {best_result_no_scratch->conv().algorithm(),
          best_result_no_scratch->conv().tensor_ops_enabled()});
@@ -726,19 +718,15 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
             output_desc, &output_ptr, &scratch_allocator,
             dnn::AlgorithmConfig(profile_algorithm), &profile_result)
             .ok();
-        if (cudnn_launch_status) {
-          if (profile_result.is_valid()) {
-            results.emplace_back();
-            auto& result = results.back();
-            result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
-            result.mutable_conv()->set_tensor_ops_enabled(
-                profile_algorithm.tensor_ops_enabled());
-            result.mutable_success()->set_scratch_bytes(
-                scratch_allocator.TotalByteSize());
-            *result.mutable_success()->mutable_run_time() =
-                internal::ToDurationProto(
-                    absl::Milliseconds(profile_result.elapsed_time_in_ms()));
-          }
+        if (cudnn_launch_status && profile_result.is_valid()) {
+          results.emplace_back();
+          auto& result = results.back();
+          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
+          result.mutable_conv()->set_tensor_ops_enabled(
+              profile_algorithm.tensor_ops_enabled());
+          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+          *result.mutable_run_time() = internal::ToDurationProto(
+              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
         }
       }
       internal::LogFusedConvAutotuneResults(ctx->op_kernel().def(), *conv_input,
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index efd701c7687..b8db71e41ec 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -858,19 +858,15 @@ void LaunchConv2DBackpropFilterOp<GPUDevice, T>::operator()(
             &scratch_allocator, AlgorithmConfig(profile_algorithm),
             &profile_result)
             .ok();
-        if (cudnn_launch_status) {
-          if (profile_result.is_valid()) {
-            results.emplace_back();
-            auto& result = results.back();
-            result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
-            result.mutable_conv()->set_tensor_ops_enabled(
-                profile_algorithm.tensor_ops_enabled());
-            result.mutable_success()->set_scratch_bytes(
-                scratch_allocator.TotalByteSize());
-            *result.mutable_success()->mutable_run_time() =
-                proto_utils::ToDurationProto(
-                    absl::Milliseconds(profile_result.elapsed_time_in_ms()));
-          }
+        if (cudnn_launch_status && profile_result.is_valid()) {
+          results.emplace_back();
+          auto& result = results.back();
+          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
+          result.mutable_conv()->set_tensor_ops_enabled(
+              profile_algorithm.tensor_ops_enabled());
+          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+          *result.mutable_run_time() = proto_utils::ToDurationProto(
+              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
         }
       }
       LogConvAutotuneResults(ctx->op_kernel().def(), transformed_input,
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 730c71e4a75..31bdd6c6c55 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -969,19 +969,15 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
             conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
             AlgorithmConfig(profile_algorithm), &profile_result)
             .ok();
-        if (cudnn_launch_status) {
-          if (profile_result.is_valid()) {
-            results.emplace_back();
-            auto& result = results.back();
-            result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
-            result.mutable_conv()->set_tensor_ops_enabled(
-                profile_algorithm.tensor_ops_enabled());
-            result.mutable_success()->set_scratch_bytes(
-                scratch_allocator.TotalByteSize());
-            *result.mutable_success()->mutable_run_time() =
-                proto_utils::ToDurationProto(
-                    absl::Milliseconds(profile_result.elapsed_time_in_ms()));
-          }
+        if (cudnn_launch_status && profile_result.is_valid()) {
+          results.emplace_back();
+          auto& result = results.back();
+          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
+          result.mutable_conv()->set_tensor_ops_enabled(
+              profile_algorithm.tensor_ops_enabled());
+          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+          *result.mutable_run_time() = proto_utils::ToDurationProto(
+              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
         }
       }
       LogConvAutotuneResults(ctx->op_kernel().def(), pre_transformed_in_backprop,
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 2e6cb006a27..799a99577f8 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -870,19 +870,15 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
             output_desc, &output_ptr, &scratch_allocator,
             AlgorithmConfig(profile_algorithm), &profile_result)
             .ok();
-        if (cudnn_launch_status) {
-          if (profile_result.is_valid()) {
-            results.emplace_back();
-            auto& result = results.back();
-            result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
-            result.mutable_conv()->set_tensor_ops_enabled(
-                profile_algorithm.tensor_ops_enabled());
-            result.mutable_success()->set_scratch_bytes(
-                scratch_allocator.TotalByteSize());
-            *result.mutable_success()->mutable_run_time() =
-                proto_utils::ToDurationProto(
-                    absl::Milliseconds(profile_result.elapsed_time_in_ms()));
-          }
+        if (cudnn_launch_status && profile_result.is_valid()) {
+          results.emplace_back();
+          auto& result = results.back();
+          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
+          result.mutable_conv()->set_tensor_ops_enabled(
+              profile_algorithm.tensor_ops_enabled());
+          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+          *result.mutable_run_time() = proto_utils::ToDurationProto(
+              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
         }
       }
       LogConvAutotuneResults(ctx->op_kernel().def(), input, transformed_filter,
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 3ea4742d206..e968ad4d934 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -467,11 +467,9 @@ struct LaunchConvOp<GPUDevice, T> {
           result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
           result.mutable_conv()->set_tensor_ops_enabled(
               profile_algorithm.tensor_ops_enabled());
-          result.mutable_success()->set_scratch_bytes(
-              scratch_allocator.TotalByteSize());
-          *result.mutable_success()->mutable_run_time() =
-              proto_utils::ToDurationProto(
-                  absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+          *result.mutable_run_time() = proto_utils::ToDurationProto(
+              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
         }
       }
     }
diff --git a/tensorflow/core/kernels/conv_ops_fused_impl.h b/tensorflow/core/kernels/conv_ops_fused_impl.h
index f207af65565..524303199da 100644
--- a/tensorflow/core/kernels/conv_ops_fused_impl.h
+++ b/tensorflow/core/kernels/conv_ops_fused_impl.h
@@ -536,11 +536,9 @@ Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
       result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
       result.mutable_conv()->set_tensor_ops_enabled(
           profile_algorithm.tensor_ops_enabled());
-      result.mutable_success()->set_scratch_bytes(
-          scratch_allocator.TotalByteSize());
-      *result.mutable_success()->mutable_run_time() =
-          proto_utils::ToDurationProto(
-              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+      result.set_scratch_bytes(scratch_allocator.TotalByteSize());
+      *result.mutable_run_time() = proto_utils::ToDurationProto(
+          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
     }
   }
   // Only log on an AutoTuneFusedConv cache miss.
diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc
index 298acfba54d..c0617eb51f1 100644
--- a/tensorflow/core/kernels/gpu_utils.cc
+++ b/tensorflow/core/kernels/gpu_utils.cc
@@ -71,6 +71,7 @@ void LogConvAutotuneResults(const NodeDef& node, const Tensor& input,
   log.mutable_instr()->PackFrom(std::move(instr));
   *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
   *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
+  log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
   for (const auto& result : results) {
     *log.add_results() = result;
   }
@@ -101,6 +102,7 @@ void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
   log.mutable_instr()->PackFrom(std::move(instr));
   *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
   *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
+  log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
   for (const auto& result : results) {
     *log.add_results() = result;
   }
@@ -109,38 +111,32 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
                               se::dnn::AlgorithmConfig* algo) {
-  // For the "!xhs.has_success()" below, this is because we want successful ones
-  // to order first, therefore they need a smaller key per "min_element".
+  // TODO(jlebar): Exclude conv ops with failures, once we have failure checking
+  // and have confidence that it's correct.
+
   const AutotuneResult* best_result = std::min_element(
       results.begin(), results.end(),
       [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return std::make_tuple(
-                   !lhs.has_success(),
-                   proto_utils::FromDurationProto(lhs.success().run_time())) <
-               std::make_tuple(
-                   !rhs.has_success(),
-                   proto_utils::FromDurationProto(rhs.success().run_time()));
+        return proto_utils::FromDurationProto(lhs.run_time()) <
+               proto_utils::FromDurationProto(rhs.run_time());
       });
 
   const AutotuneResult* best_result_no_scratch = std::min_element(
       results.begin(), results.end(),
       [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return std::make_tuple(
-                   !lhs.has_success(), lhs.success().scratch_bytes(),
-                   proto_utils::FromDurationProto(lhs.success().run_time())) <
-               std::make_tuple(
-                   !rhs.has_success(), rhs.success().scratch_bytes(),
-                   proto_utils::FromDurationProto(rhs.success().run_time()));
+        return std::make_tuple(lhs.scratch_bytes(),
+                               proto_utils::FromDurationProto(lhs.run_time())) <
+               std::make_tuple(rhs.scratch_bytes(),
+                               proto_utils::FromDurationProto(rhs.run_time()));
       });
 
-  if (best_result == results.end() || !best_result->has_success()) {
+  if (best_result == results.end()) {
     return errors::NotFound("No algorithm worked!");
   }
   algo->set_algorithm({best_result->conv().algorithm(),
                        best_result->conv().tensor_ops_enabled()});
   if (best_result_no_scratch != results.end() &&
-      best_result_no_scratch->has_success() &&
-      best_result_no_scratch->success().scratch_bytes() == 0) {
+      best_result_no_scratch->scratch_bytes() == 0) {
     algo->set_algorithm_no_scratch(
         {best_result_no_scratch->conv().algorithm(),
          best_result_no_scratch->conv().tensor_ops_enabled()});
diff --git a/tensorflow/core/protobuf/autotuning.proto b/tensorflow/core/protobuf/autotuning.proto
index 29e4d00a85f..2edc70b34c5 100644
--- a/tensorflow/core/protobuf/autotuning.proto
+++ b/tensorflow/core/protobuf/autotuning.proto
@@ -22,9 +22,25 @@ message ComputeCapability {
 }
 
 message AutotuneResult {
-  message SuccessResult {
-    int64 scratch_bytes = 1;
-    google.protobuf.Duration run_time = 2;
+  enum FailureKind {
+    UNKNOWN = 0;
+    REDZONE_MODIFIED = 1;
+    WRONG_RESULT = 2;
+  }
+
+  message FailureResult {
+    FailureKind kind = 1;
+    string msg = 2;
+
+    // For failure_kind == WRONG_RESULT, this field indicates the reference
+    // configuration that we compared against.
+    //
+    // Note that the reference algorithm isn't always correct. However,
+    // empirically it's more correct, as it's "algo 0", less fancy than the
+    // compared one.
+    oneof key {
+      ConvKey reference_conv = 11;
+    }
   }
 
   message ConvKey {
@@ -32,34 +48,16 @@ message AutotuneResult {
     bool tensor_ops_enabled = 2;
   }
 
-  // If the conv runs successfully, success will be populated with the
-  // autotuning result. Otherwise, the error message is propagated.
-  oneof result {
-    SuccessResult success = 3;
-    string error_string = 4;
-  }
+  int64 scratch_bytes = 8;
+  google.protobuf.Duration run_time = 9;
+
+  FailureResult failure = 7;
 
   oneof key {
     ConvKey conv = 5;
   }
 
-  // Sometimes we run a correctness checker during autotuning. It compares the
-  // result buffer content between two algorithms, say, "reference" and "test"
-  // algorithms. The "test" algorithm is the one associated with this
-  // AutotuneResult.
-  //
-  // This field records the reference algorithm used. Notice that naming it
-  // "reference" doesn't mean it's always correct. However, empirically it's
-  // more correct, as it's "algo 0", less fancy than the compared one.
-  //
-  // Notice that the checker_failure may exist even in the success case.
-  // This is because the error string in `result` comes from the underlying
-  // implementation like cuDNN, which isn't aware that it produced an incorrect
-  // result. And even if the checker detects an incorrect result, we can still
-  // retrieve scratch_bytes and runtime_ms.
-  oneof checker_failure {
-    ConvKey reference_conv = 6;
-  }
+  // Next ID: 12
 }
 
 message AutotuningLog {
@@ -70,4 +68,9 @@
 
   CudnnVersion cudnn_version = 3;
   ComputeCapability compute_capability = 4;
+
+  // stream_executor::DeviceDescription::pci_bus_id.
+  string device_pci_bus_id = 5;
+
+  // Next ID: 6
 }