Add multi-algorithm deterministic cuDNN convolutions

Duncan Riach 2019-12-06 18:18:29 -08:00
parent 2beb2d53ba
commit 5341e3d299
9 changed files with 191 additions and 125 deletions

tensorflow/compiler/xla/service/gpu/BUILD

@@ -669,6 +669,7 @@ cc_library(
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core/util/proto:proto_utils",
         "//tensorflow/stream_executor:device_memory_allocator",
+        "//tensorflow/stream_executor/cuda:cuda_helpers",
         "//tensorflow/stream_executor/gpu:redzone_allocator",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/strings",

tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc

@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/platform/logger.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
+#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"

 namespace xla {
@@ -536,43 +537,41 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
     }
   }

-  // For now, we ignore WRONG_RESULT failures because false-positives are
-  // possible (e.g. perhaps the reference algorithm is the one that's
-  // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
-  // quite severe and can be detected with high accuracy.
-  auto has_failure = [](const AutotuneResult& r) {
-    return r.has_failure() &&
-           r.failure().kind() != AutotuneResult::WRONG_RESULT;
-  };
-
   // Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
   // error.
   //
   // TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
   // in the output of the conv and skip those.
   //
-  // The successful one should have a smaller key, since we are doing
-  // min_element. If they are both unsuccessful, keep the earlier one in
-  // the vector by comparing pointers.
-  auto result_comparison_key = [&has_failure](const AutotuneResult& r) {
-    return std::make_tuple(
-        has_failure(r),
-        tensorflow::proto_utils::FromDurationProto(r.run_time()));
-  };
-  const auto& best_result = absl::c_min_element(
-      profile_results,
-      [&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return result_comparison_key(lhs) < result_comparison_key(rhs);
-      });
-
-  if (best_result != profile_results.end() && !has_failure(*best_result)) {
-    return *best_result;
-  }
-
-  return InternalError(
-      "All algorithms tried for convolution %s failed. Falling back to "
-      "default algorithm.",
-      instr->ToString());
+  // For now, we ignore WRONG_RESULT failures because false-positives are
+  // possible (e.g. perhaps the reference algorithm is the one that's
+  // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
+  // quite severe and can be detected with high accuracy.
+  std::vector<AutotuneResult> filtered_results;
+  absl::c_copy_if(
+      profile_results, std::back_inserter(filtered_results),
+      [](const AutotuneResult& r) {
+        return !(r.has_failure() &&
+                 r.failure().kind() != AutotuneResult::WRONG_RESULT);
+      });
+
+  if (filtered_results.empty()) {
+    return InternalError(
+        "All algorithms tried for convolution %s failed. Falling back to "
+        "default algorithm.",
+        instr->ToString());
+  }
+
+  auto selected_result = filtered_results.begin();
+  if (!se::cuda::RequireCuDNNDeterminism()) {
+    selected_result = absl::c_min_element(
+        filtered_results,
+        [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
+          return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
+                 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
+        });
+  }
+
+  return *selected_result;
 }

 StatusOr<tensorflow::AutotuneResult>
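Note on the selection change above: a minimal standalone sketch of the new policy, using a stand-in Result struct rather than TensorFlow's AutotuneResult proto (all names here are illustrative, not the commit's). Failing candidates are filtered out first; when determinism is required, the first surviving candidate is kept (relying on the deterministic ordering of the candidate list), otherwise the fastest one wins.

    #include <iterator>
    #include <vector>
    #include "absl/algorithm/container.h"

    struct Result {           // stand-in for tensorflow::AutotuneResult
      bool failed;            // e.g. REDZONE_MODIFIED, but not WRONG_RESULT
      double run_time_ms;
    };

    // Precondition: at least one candidate survives filtering; the real code
    // returns InternalError when filtered is empty.
    Result SelectConv(const std::vector<Result>& profile_results,
                      bool require_determinism) {
      std::vector<Result> filtered;
      absl::c_copy_if(profile_results, std::back_inserter(filtered),
                      [](const Result& r) { return !r.failed; });
      auto selected = filtered.begin();
      if (!require_determinism) {
        selected = absl::c_min_element(
            filtered, [](const Result& a, const Result& b) {
              return a.run_time_ms < b.run_time_ms;
            });
      }
      return *selected;
    }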

tensorflow/core/kernels/BUILD

@@ -531,6 +531,7 @@ tf_cuda_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor",
         "//tensorflow/core/util/proto:proto_utils",
+        "//tensorflow/stream_executor/cuda:cuda_helpers",
         "//tensorflow/stream_executor/gpu:asm_compiler",
         "//tensorflow/stream_executor/gpu:redzone_allocator",
         "@com_google_absl//absl/algorithm:container",

tensorflow/core/kernels/gpu_utils.cc

@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/protobuf/autotuning.pb.h"
 #include "tensorflow/core/protobuf/conv_autotuning.pb.h"
 #include "tensorflow/core/util/proto/proto_utils.h"
+#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
@@ -220,31 +221,32 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
   if (filtered_results.empty()) {
     return errors::NotFound("No algorithm worked!");
   }
+  std::vector<AutotuneResult> filtered_results_no_scratch;
+  absl::c_copy_if(
+      filtered_results, std::back_inserter(filtered_results_no_scratch),
+      [](const AutotuneResult& result) { return result.scratch_bytes() == 0; });

-  const auto best_result = absl::c_min_element(
-      filtered_results,
-      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return proto_utils::FromDurationProto(lhs.run_time()) <
-               proto_utils::FromDurationProto(rhs.run_time());
-      });
-
-  const auto best_result_no_scratch = absl::c_min_element(
-      filtered_results,
-      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
-        return std::make_tuple(lhs.scratch_bytes(),
-                               proto_utils::FromDurationProto(lhs.run_time())) <
-               std::make_tuple(rhs.scratch_bytes(),
-                               proto_utils::FromDurationProto(rhs.run_time()));
-      });
-
-  algo->set_algorithm({best_result->conv().algorithm(),
-                       best_result->conv().tensor_ops_enabled()});
-  if (best_result_no_scratch != filtered_results.end() &&
-      best_result_no_scratch->scratch_bytes() == 0) {
-    algo->set_algorithm_no_scratch(
-        {best_result_no_scratch->conv().algorithm(),
-         best_result_no_scratch->conv().tensor_ops_enabled()});
-  }
+  auto selected_result = filtered_results.begin();
+  auto selected_result_no_scratch = filtered_results_no_scratch.begin();
+  if (!se::cuda::RequireCuDNNDeterminism()) {
+    auto compare_run_times = [](const AutotuneResult& lhs,
+                                const AutotuneResult& rhs) {
+      return proto_utils::FromDurationProto(lhs.run_time()) <
+             proto_utils::FromDurationProto(rhs.run_time());
+    };
+    selected_result = absl::c_min_element(filtered_results, compare_run_times);
+    selected_result_no_scratch = absl::c_min_element(
+        filtered_results_no_scratch, compare_run_times);
+  }
+
+  algo->set_algorithm({selected_result->conv().algorithm(),
+                       selected_result->conv().tensor_ops_enabled()});
+  if (selected_result_no_scratch != filtered_results_no_scratch.end()) {
+    algo->set_algorithm_no_scratch(
+        {selected_result_no_scratch->conv().algorithm(),
+         selected_result_no_scratch->conv().tensor_ops_enabled()});
+  }
   return Status::OK();
 }
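Note on BestCudnnConvAlgorithm above: the same pattern, sketched standalone with stand-in types (illustrative names, not TF's). A second list keeps only zero-scratch candidates; both the primary and the no-scratch selection default to the first entry of their (deterministically ordered) lists, and only fall through to a min_element on run time when determinism is not required.

    #include <iterator>
    #include <vector>
    #include "absl/algorithm/container.h"

    struct Candidate {
      int algorithm;
      long long scratch_bytes;
      double run_time_ms;
    };

    // Precondition: filtered is non-empty (the real code returns NotFound).
    void PickAlgorithms(const std::vector<Candidate>& filtered,
                        bool require_determinism,
                        int* algo, int* algo_no_scratch) {
      std::vector<Candidate> no_scratch;
      absl::c_copy_if(filtered, std::back_inserter(no_scratch),
                      [](const Candidate& c) { return c.scratch_bytes == 0; });
      auto best = filtered.begin();
      auto best_no_scratch = no_scratch.begin();
      if (!require_determinism) {
        auto by_time = [](const Candidate& a, const Candidate& b) {
          return a.run_time_ms < b.run_time_ms;
        };
        best = absl::c_min_element(filtered, by_time);
        best_no_scratch = absl::c_min_element(no_scratch, by_time);
      }
      *algo = best->algorithm;
      if (best_no_scratch != no_scratch.end()) {
        *algo_no_scratch = best_no_scratch->algorithm;
      }
    }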

tensorflow/python/kernel_tests/cudnn_deterministic_base.py

@@ -27,22 +27,39 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test

-# Setting either of the two environment variables TF_CUDNN_DETERMINISTIC or
-# TF_DETERMINISTIC_OPS to "true" or "1" will disable autotuning of cuDNN
-# algorithms and cause deterministic cuDNN algorithms to be selected when both
-# deterministic and non-deterministic algorithms are available. These tests are
-# intended to confirm that deterministic algorithms are chosen when either
-# environment variable is set to "true" or "1". The tested configurations were
-# first confirmed to produce non-deterministic results when the environment
-# variables are not set.
+# Notes:
+#
+# Deterministic cuDNN operation is selected by setting either of the two
+# environment variables TF_CUDNN_DETERMINISTIC or TF_DETERMINISTIC_OPS to
+# 'true' or '1' while also not setting the environment variable
+# TF_CUDNN_USE_AUTOTUNE to 'false' or '0'.
+#
+# Where both deterministic and non-deterministic cuDNN algorithms are
+# available, selecting deterministic operation will lead to only the
+# deterministic algorithms being chosen. Additionally, selecting deterministic
+# operation will result in a deterministic, or reproducible, selection of
+# algorithms (for any given layer configuration) for each of the forward and
+# the two backward paths.
+#
+# These tests intend to confirm that deterministic algorithms are chosen (for
+# the back-prop paths) when deterministic operation is selected. The tested
+# configurations were first confirmed to produce non-deterministic results
+# when the above-mentioned environment variables are not set.
+#
+# Even though selecting deterministic operation should ensure that the same
+# algorithms, for a given layer configuration, are always used (i.e. that
+# algorithm selection is deterministic / reproducible), this is not tested.

-_PADDING = 'SAME'
-_STRIDES = [1, 1, 1, 1]
+# TODO(duncanriach): Add test for deterministic cuDNN max-pooling

-LayerShape = collections.namedtuple('LayerShape',
-                                    'batch, height, width, channels')
-FilterShape = collections.namedtuple(
-    'FilterShape', 'height, width, in_channels, out_channels')
+LayerShapeNHWC = collections.namedtuple('LayerShapeNHWC',
+                                        'batch, height, width, channels')
+FilterShape2D = collections.namedtuple(
+    'FilterShape2D', 'height, width, in_channels, out_channels')
+LayerShapeNCDHW = collections.namedtuple(
+    'LayerShapeNCDHW', 'batch, channels, depth, height, width')
+FilterShape3D = collections.namedtuple(
+    'FilterShape3D', 'depth, height, width, in_channels, out_channels')

 class ConvolutionTest(test.TestCase):
@@ -53,14 +70,14 @@ class ConvolutionTest(test.TestCase):
     return constant_op.constant(
         2 * np.random.random_sample(shape) - 1, dtype=dtypes.float32)

-  def _random_out_op(self, in_shape, filter_shape):
+  def _random_out_op(self, in_shape, filter_shape, strides, padding):
     # Choosing not to use array_op.zeros() to prevent possible removal by
     # optimization
     in_op = self._random_data_op(in_shape)
     filter_op = self._random_data_op(filter_shape)
     # Use the forward op's shape-inference
     conv_op = nn_ops.conv2d(
-        in_op, filter_op, strides=_STRIDES, padding=_PADDING)
+        in_op, filter_op, strides=strides, padding=padding)
     out_shape = conv_op.get_shape()
     out_op = self._random_data_op(out_shape)
     return out_op
@@ -71,29 +88,49 @@ class ConvolutionTest(test.TestCase):
     result_2 = self.evaluate(operation)
     self.assertAllEqual(result_1, result_2)

+  # The default forward algorithm choice, when using cuDNN 7, does not support
+  # the following layer configuration. This test case intends to confirm that
+  # an alternative algorithm is selected. Note that, in cuDNN 7, all forward
+  # algorithms are deterministic.
+  @test_util.run_cuda_only
+  def testForward(self):
+    np.random.seed(3)
+    in_shape = LayerShapeNCDHW(batch=2, channels=3, depth=5, height=7, width=6)
+    filter_shape = FilterShape3D(depth=3, height=3, width=3, in_channels=3,
+                                 out_channels=2)
+    in_op = self._random_data_op(in_shape)
+    filter_op = self._random_data_op(filter_shape)
+    strides = [1, 1, 1, 1, 1]
+    padding = 'VALID'
+    dilations = [1, 1, 2, 2, 2]
+    out_op = nn_ops.conv3d(in_op, filter_op, strides=strides, padding=padding,
+                           data_format='NCDHW', dilations=dilations)
+    self._assert_reproducible(out_op)
+
   @test_util.run_cuda_only
   def testBackwardFilterGradient(self):
     np.random.seed(1)
-    in_shape = LayerShape(batch=8, height=128, width=128, channels=8)
-    filter_shape = FilterShape(height=3, width=3, in_channels=8, out_channels=8)
+    in_shape = LayerShapeNHWC(batch=8, height=128, width=128, channels=8)
+    filter_shape = FilterShape2D(height=3, width=3, in_channels=8,
+                                 out_channels=8)
     in_op = self._random_data_op(in_shape)
-    out_op = self._random_out_op(in_shape, filter_shape)
+    strides = [1, 1, 1, 1]
+    padding = 'SAME'
+    out_op = self._random_out_op(in_shape, filter_shape, strides, padding)
     filter_gradient_op = nn_ops.conv2d_backprop_filter(
-        in_op, filter_shape, out_op, strides=_STRIDES, padding=_PADDING)
+        in_op, filter_shape, out_op, strides=strides, padding=padding)
     self._assert_reproducible(filter_gradient_op)

   @test_util.run_cuda_only
   def testBackwardInputGradient(self):
     np.random.seed(2)
-    in_shape = LayerShape(batch=8, height=32, width=32, channels=8)
-    filter_shape = FilterShape(
-        height=7, width=7, in_channels=8, out_channels=128)
+    in_shape = LayerShapeNHWC(batch=8, height=32, width=32, channels=8)
+    filter_shape = FilterShape2D(height=7, width=7, in_channels=8,
+                                 out_channels=128)
     filter_op = self._random_data_op(filter_shape)
-    out_op = self._random_out_op(in_shape, filter_shape)
+    strides = [1, 1, 1, 1]
+    padding = 'SAME'
+    out_op = self._random_out_op(in_shape, filter_shape, strides, padding)
     input_gradient_op = nn_ops.conv2d_backprop_input(
-        in_shape, filter_op, out_op, strides=_STRIDES, padding=_PADDING)
+        in_shape, filter_op, out_op, strides=strides, padding=padding)
     self._assert_reproducible(input_gradient_op)
+
+# TODO(duncanriach): (1) add test to confirm that forward autotuning is
+# disabled for cuDNN convolution; (2) add test for deterministic cuDNN
+# max-pooling

tensorflow/stream_executor/cuda/BUILD

@@ -343,6 +343,7 @@ cc_library(
         ":cuda_platform_id",
         ":cuda_stream",
         ":cuda_timer",
+        ":cuda_helpers",
        ":cudnn_version",
        ":cudnn_lib",
        "@com_google_absl//absl/strings",
@@ -480,7 +481,8 @@ cc_library(
 # TODO(leary) we likely need to canonicalize/eliminate this.
 cc_library(
     name = "cuda_helpers",
-    textual_hdrs = if_cuda_is_configured(["cuda_helpers.h"]),
+    srcs = if_cuda_is_configured(["cuda_helpers.cc"]),
+    hdrs = if_cuda_is_configured(["cuda_helpers.h"]),
     deps = if_cuda_is_configured([
         "//tensorflow/stream_executor/gpu:gpu_helpers_header",
     ]),

tensorflow/stream_executor/cuda/cudnn.cc

@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
@@ -630,22 +631,6 @@ bool BatchnormSpatialPersistentEnabled() {
   return is_enabled;
 }

-// A helper function to decide whether to enable deterministic functionality.
-bool RequireDeterminism() {
-  static bool require_determinism = [] {
-    bool deterministic_ops = false;
-    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
-                                               /*default_val=*/false,
-                                               &deterministic_ops));
-    bool cudnn_deterministic = false;
-    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
-                                               /*default_val=*/false,
-                                               &cudnn_deterministic));
-    return deterministic_ops || cudnn_deterministic;
-  }();
-  return require_determinism;
-}
-
 std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
   int cc_major, cc_minor;
   stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
@@ -744,9 +729,10 @@ class CudnnPoolingDescriptor {
     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
                    &CheckedNarrowing<int64, int>);
     bool propagate_nans = pooling_descriptor.propagate_nans();
-    const auto cudnn_max_pooling_mode = RequireDeterminism()
-                                            ? CUDNN_POOLING_MAX_DETERMINISTIC
-                                            : CUDNN_POOLING_MAX;
+    const auto cudnn_max_pooling_mode =
+        stream_executor::cuda::RequireCuDNNDeterminism()
+            ? CUDNN_POOLING_MAX_DETERMINISTIC
+            : CUDNN_POOLING_MAX;
     CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
         handle_.get(),
         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
@@ -3247,16 +3233,10 @@ bool CudnnSupport::GetConvolveAlgorithms(
   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
   out_algorithms->clear();

-  if (RequireDeterminism()) {
-    out_algorithms->push_back({CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
-                               tensor_op_math_available});
-    return true;
-  }
-
   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
       // clang-format off
-      CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
       CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
+      CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
       CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
       CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
       CUDNN_CONVOLUTION_FWD_ALGO_FFT,
@@ -3270,11 +3250,12 @@ bool CudnnSupport::GetConvolveAlgorithms(
     algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
   }

+  // The algorithms are intentionally ordered for deterministic operation
   for (auto i : algo_types) {
-    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
     if (tensor_op_math_available) {
       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
     }
+    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   }

   return true;
@@ -3308,15 +3289,8 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
   out_algorithms->clear();

-  if (RequireDeterminism()) {
-    out_algorithms->push_back(
-        {CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, tensor_op_math_available});
-    return true;
-  }
-
   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
       // clang-format off
-      CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
       CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
       CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
       CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
@@ -3326,12 +3300,16 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
   }
+  if (!stream_executor::cuda::RequireCuDNNDeterminism()) {
+    algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
+  }

+  // The algorithms are intentionally ordered for deterministic operation
   for (auto i : algo_types) {
-    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
     if (tensor_op_math_available) {
       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
     }
+    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   }

   return true;
@@ -3343,18 +3321,10 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
   out_algorithms->clear();

-  if (RequireDeterminism()) {
-    out_algorithms->push_back(
-        {CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, tensor_op_math_available});
-    return true;
-  }
-
   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
       // clang-format off
-      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
-      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
       // Based on cudnn.h, the following is not implemented.
       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
@@ -3366,12 +3336,17 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
   }
+  if (!stream_executor::cuda::RequireCuDNNDeterminism()) {
+    algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
+    algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
+  }

+  // The algorithms are intentionally ordered for deterministic operation
   for (auto i : algo_types) {
-    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
     if (tensor_op_math_available) {
       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
     }
+    out_algorithms->push_back({i, /*use_tensor_ops=*/false});
   }

   return true;
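Note on the three Get*Algorithms hunks above: they share one pattern, sketched standalone below with illustrative names (AlgoDesc and the index constants are stand-ins, not cuDNN's). Nondeterministic algorithm indices are appended only when determinism is not required, and each index is emitted tensor-op variant first, so out_algorithms->front() is a stable, deterministic preferred choice for the pickers above to fall back on.

    #include <vector>

    struct AlgoDesc {
      int index;
      bool tensor_ops;
    };

    std::vector<AlgoDesc> BuildAlgoList(bool require_determinism,
                                        bool tensor_ops_available) {
      // Deterministic algorithms first, in a fixed preference order.
      std::vector<int> algo_types = {/*kAlgo1=*/1, /*kAlgoFft=*/4};
      if (!require_determinism) {
        // Nondeterministic (atomic-accumulation) algorithms: autotune only.
        algo_types.push_back(/*kAlgo0=*/0);
      }
      std::vector<AlgoDesc> out;
      for (int i : algo_types) {
        if (tensor_ops_available) out.push_back({i, /*tensor_ops=*/true});
        out.push_back({i, /*tensor_ops=*/false});
      }
      return out;
    }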

tensorflow/stream_executor/cuda/cuda_helpers.cc

@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
+
+#include "tensorflow/core/util/env_var.h"
+
+namespace stream_executor {
+namespace cuda {
+
+bool RequireCuDNNDeterminism() {
+  static bool require_cudnn_determinism = [] {
+    bool deterministic_ops = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
+                                               /*default_val=*/false,
+                                               &deterministic_ops));
+    bool cudnn_deterministic = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
+                                               /*default_val=*/false,
+                                               &cudnn_deterministic));
+    return deterministic_ops || cudnn_deterministic;
+  }();
+  return require_cudnn_determinism;
+}
+
+}  // namespace cuda
+}  // namespace stream_executor
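Note on RequireCuDNNDeterminism above: because the result is computed in a static initializer, the environment is read once per process and then cached. A hypothetical probe (not part of the commit) that demonstrates this:

    #include <cstdio>
    #include <cstdlib>
    #include "tensorflow/stream_executor/cuda/cuda_helpers.h"

    int main() {
      // Must be set before the first call; later changes to the environment
      // have no effect on the cached value.
      setenv("TF_DETERMINISTIC_OPS", "1", /*overwrite=*/1);
      std::printf("determinism required: %d\n",
                  stream_executor::cuda::RequireCuDNNDeterminism());
      return 0;
    }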

tensorflow/stream_executor/cuda/cuda_helpers.h

@@ -22,4 +22,14 @@ limitations under the License.

 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"

+namespace stream_executor {
+namespace cuda {
+
+// A helper function to decide whether to enable deterministic cuDNN
+// functionality.
+bool RequireCuDNNDeterminism();
+
+}  // namespace cuda
+}  // namespace stream_executor
+
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_