Provide a flag TF_DISABLE_RZ_CHECK to disable redzone check during TF convolution autotuning
The flag has to be set to "1" for the redzone checking to be disabled. PiperOrigin-RevId: 255295979
This commit is contained in:
parent
5b469d2115
commit
bb2a1d3ee8
@ -73,6 +73,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
|
|||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
struct LaunchGeneric {
|
struct LaunchGeneric {
|
||||||
void operator()(OpKernelContext* ctx, const Tensor& input,
|
void operator()(OpKernelContext* ctx, const Tensor& input,
|
||||||
@ -575,6 +576,11 @@ template struct LaunchConv2DOp<CPUDevice, float>;
|
|||||||
template struct LaunchConv2DOp<CPUDevice, double>;
|
template struct LaunchConv2DOp<CPUDevice, double>;
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
static bool RedzoneCheckDisabled() {
|
||||||
|
const char* disable_rz_str = std::getenv("TF_DISABLE_RZ_CHECK");
|
||||||
|
return disable_rz_str != nullptr && std::strcmp(disable_rz_str, "1") == 0;
|
||||||
|
}
|
||||||
|
|
||||||
int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
|
int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
|
||||||
int64 default_value_in_bytes) {
|
int64 default_value_in_bytes) {
|
||||||
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
|
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
|
||||||
@ -997,6 +1003,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
|||||||
se::cuda::PtxCompilationOptions());
|
se::cuda::PtxCompilationOptions());
|
||||||
|
|
||||||
se::DeviceMemory<T> output_tensor;
|
se::DeviceMemory<T> output_tensor;
|
||||||
|
|
||||||
|
if (!RedzoneCheckDisabled()) {
|
||||||
auto output_rz_or = rz_allocator.AllocateBytes(stream, output_ptr.size());
|
auto output_rz_or = rz_allocator.AllocateBytes(stream, output_ptr.size());
|
||||||
if (!output_rz_or.ok()) {
|
if (!output_rz_or.ok()) {
|
||||||
static std::once_flag rz_allocation_failure_logged;
|
static std::once_flag rz_allocation_failure_logged;
|
||||||
@ -1011,6 +1019,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
|||||||
} else {
|
} else {
|
||||||
output_tensor = se::DeviceMemory<T>(output_rz_or.ValueOrDie());
|
output_tensor = se::DeviceMemory<T>(output_rz_or.ValueOrDie());
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
output_tensor = output_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<tensorflow::AutotuneResult> results;
|
std::vector<tensorflow::AutotuneResult> results;
|
||||||
for (auto profile_algorithm : algorithms) {
|
for (auto profile_algorithm : algorithms) {
|
||||||
@ -1019,13 +1030,18 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
|||||||
se::cuda::RedzoneAllocator rz_scratch_allocator(
|
se::cuda::RedzoneAllocator rz_scratch_allocator(
|
||||||
stream->parent()->device_ordinal(), &tf_allocator_adapter,
|
stream->parent()->device_ordinal(), &tf_allocator_adapter,
|
||||||
se::cuda::PtxCompilationOptions());
|
se::cuda::PtxCompilationOptions());
|
||||||
|
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
|
||||||
|
se::ScratchAllocator* allocator_used =
|
||||||
|
!RedzoneCheckDisabled()
|
||||||
|
? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
|
||||||
|
: static_cast<se::ScratchAllocator*>(&scratch_allocator);
|
||||||
|
|
||||||
ProfileResult profile_result;
|
ProfileResult profile_result;
|
||||||
bool cudnn_launch_status =
|
bool cudnn_launch_status =
|
||||||
stream
|
stream
|
||||||
->ThenConvolveWithAlgorithm(
|
->ThenConvolveWithAlgorithm(
|
||||||
input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
|
input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
|
||||||
output_desc, &output_tensor, &rz_scratch_allocator,
|
output_desc, &output_tensor, allocator_used,
|
||||||
AlgorithmConfig(profile_algorithm), &profile_result)
|
AlgorithmConfig(profile_algorithm), &profile_result)
|
||||||
.ok();
|
.ok();
|
||||||
if (cudnn_launch_status && profile_result.is_valid()) {
|
if (cudnn_launch_status && profile_result.is_valid()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user