diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4c57d4f28f0..a576ecebe53 100755 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -672,7 +672,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:autotuning_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", "//tensorflow/stream_executor:device_memory_allocator", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index e2327686223..4562996a65f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/proto/proto_utils.h" #include "tensorflow/stream_executor/gpu/redzone_allocator.h" @@ -333,35 +332,6 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithm( return result_or; } -// The following function allows deterministic ops to be implemented relatively -// quickly using environment variables. It is intended to be temporary. The -// longer-term intention is to enable deterministic ops via tf.config and -// appropriate plumbing. See the discussion on PR 34951 for more information: -// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316 -// This function and associated comment are replicated in the following three -// places: -// 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc -// 2. tensorflow/core/kernels/gpu_utils.cc -// 3. tensorflow/stream_executor/cuda/cuda_dnn.cc -// When implementing the plumbing, you should also search for the use of -// TF_DETERMINISTIC_OPS on its own. -// TODO(duncanriach): move to an API that uses tf.config and implement the first -// phase of plumbing. -static bool RequireCudnnDeterminism() { - static bool require_cudnn_determinism = [] { - bool deterministic_ops = false; - TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS", - /*default_val=*/false, - &deterministic_ops)); - bool cudnn_deterministic = false; - TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC", - /*default_val=*/false, - &cudnn_deterministic)); - return deterministic_ops || cudnn_deterministic; - }(); - return require_cudnn_determinism; -} - StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, @@ -598,41 +568,43 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( } } + // For now, we ignore WRONG_RESULT failures because false-positives are + // possible (e.g. perhaps the reference algorithm is the one that's + // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're + // quite severe and can be detected with high accuracy. + auto has_failure = [](const AutotuneResult& r) { + return r.has_failure() && + r.failure().kind() != AutotuneResult::WRONG_RESULT; + }; + // Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED // error. // // TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs // in the output of the conv and skip those. // - // For now, we ignore WRONG_RESULT failures because false-positives are - // possible (e.g. perhaps the reference algorithm is the one that's - // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're - // quite severe and can be detected with high accuracy. - std::vector filtered_results; - absl::c_copy_if( - profile_results, std::back_inserter(filtered_results), - [](const AutotuneResult& r) { - return !(r.has_failure() && - r.failure().kind() != AutotuneResult::WRONG_RESULT); + // The successful one should have a smaller key, since we are doing + // min_element. If they are both unsuccessful, keep the earlier one in + // the vector by comparing pointers. + auto result_comparison_key = [&has_failure](const AutotuneResult& r) { + return std::make_tuple( + has_failure(r), + tensorflow::proto_utils::FromDurationProto(r.run_time())); + }; + const auto& best_result = absl::c_min_element( + profile_results, + [&](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return result_comparison_key(lhs) < result_comparison_key(rhs); }); - if (filtered_results.empty()) { - return InternalError( - "All algorithms tried for convolution %s failed. Falling back to " - "default algorithm. ", - instr->ToString()); + + if (best_result != profile_results.end() && !has_failure(*best_result)) { + return *best_result; } - auto selected_result = filtered_results.begin(); - if (!RequireCudnnDeterminism()) { - selected_result = absl::c_min_element( - filtered_results, - [](const AutotuneResult& lhs, const AutotuneResult& rhs) { - return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) < - tensorflow::proto_utils::FromDurationProto(rhs.run_time()); - }); - } - - return *selected_result; + return InternalError( + "All algorithms tried for convolution %s failed. Falling back to " + "default algorithm.", + instr->ToString()); } StatusOr diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 7577d820beb..0109597e9cc 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -532,7 +532,6 @@ tf_cuda_library( "//tensorflow/core:conv_autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:stream_executor", - "//tensorflow/core/util:env_var", "//tensorflow/core/util/proto:proto_utils", "//tensorflow/stream_executor/gpu:asm_compiler", "//tensorflow/stream_executor/gpu:redzone_allocator", diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc index 144d03814c8..5bf211dcdf2 100644 --- a/tensorflow/core/kernels/gpu_utils.cc +++ b/tensorflow/core/kernels/gpu_utils.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/core/platform/logger.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/protobuf/conv_autotuning.pb.h" -#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/proto/proto_utils.h" #include "tensorflow/stream_executor/gpu/asm_compiler.h" #include "tensorflow/stream_executor/gpu/redzone_allocator.h" @@ -211,35 +210,6 @@ void LogFusedConvForwardAutotuneResults( Logger::GetSingleton()->LogProto(log); } -// The following function allows deterministic ops to be implemented relatively -// quickly using environment variables. It is intended to be temporary. The -// longer-term intention is to enable deterministic ops via tf.config and -// appropriate plumbing. See the discussion on PR 34951 for more information: -// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316 -// This function and associated comment are replicated in the following three -// places: -// 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc -// 2. tensorflow/core/kernels/gpu_utils.cc -// 3. tensorflow/stream_executor/cuda/cuda_dnn.cc -// When implementing the plumbing, you should also search for the use of -// TF_DETERMINISTIC_OPS on its own. -// TODO(duncanriach): move to an API that uses tf.config and implement the first -// phase of plumbing. -bool RequireCudnnDeterminism() { - static bool require_cudnn_determinism = [] { - bool deterministic_ops = false; - TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS", - /*default_val=*/false, - &deterministic_ops)); - bool cudnn_deterministic = false; - TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC", - /*default_val=*/false, - &cudnn_deterministic)); - return deterministic_ops || cudnn_deterministic; - }(); - return require_cudnn_determinism; -} - Status BestCudnnConvAlgorithm(absl::Span results, se::dnn::AlgorithmConfig* algo) { std::vector filtered_results; @@ -249,32 +219,31 @@ Status BestCudnnConvAlgorithm(absl::Span results, if (filtered_results.empty()) { return errors::NotFound("No algorithm worked!"); } - std::vector filtered_results_no_scratch; - absl::c_copy_if( - filtered_results, std::back_inserter(filtered_results_no_scratch), - [](const AutotuneResult& result) { return result.scratch_bytes() == 0; }); - auto selected_result = filtered_results.begin(); - auto selected_result_no_scratch = filtered_results_no_scratch.begin(); - if (!RequireCudnnDeterminism()) { - auto compare_run_times = [](const AutotuneResult& lhs, - const AutotuneResult& rhs) { - return proto_utils::FromDurationProto(lhs.run_time()) < - proto_utils::FromDurationProto(rhs.run_time()); - }; - selected_result = absl::c_min_element(filtered_results, compare_run_times); - selected_result_no_scratch = - absl::c_min_element(filtered_results_no_scratch, compare_run_times); - } + const auto best_result = absl::c_min_element( + filtered_results, + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return proto_utils::FromDurationProto(lhs.run_time()) < + proto_utils::FromDurationProto(rhs.run_time()); + }); - algo->set_algorithm({selected_result->conv().algorithm(), - selected_result->conv().tensor_ops_enabled()}); - if (selected_result_no_scratch != filtered_results_no_scratch.end()) { + const auto best_result_no_scratch = absl::c_min_element( + filtered_results, + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + return std::make_tuple(lhs.scratch_bytes(), + proto_utils::FromDurationProto(lhs.run_time())) < + std::make_tuple(rhs.scratch_bytes(), + proto_utils::FromDurationProto(rhs.run_time())); + }); + + algo->set_algorithm({best_result->conv().algorithm(), + best_result->conv().tensor_ops_enabled()}); + if (best_result_no_scratch != filtered_results.end() && + best_result_no_scratch->scratch_bytes() == 0) { algo->set_algorithm_no_scratch( - {selected_result_no_scratch->conv().algorithm(), - selected_result_no_scratch->conv().tensor_ops_enabled()}); + {best_result_no_scratch->conv().algorithm(), + best_result_no_scratch->conv().tensor_ops_enabled()}); } - return Status::OK(); } diff --git a/tensorflow/python/kernel_tests/cudnn_deterministic_base.py b/tensorflow/python/kernel_tests/cudnn_deterministic_base.py index 9886913a775..289cc393042 100644 --- a/tensorflow/python/kernel_tests/cudnn_deterministic_base.py +++ b/tensorflow/python/kernel_tests/cudnn_deterministic_base.py @@ -28,39 +28,22 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test -# Notes: -# -# Deterministic cuDNN operation is selected by setting either of the two -# environment variables TF_CUDNN_DETERMINISTIC or TF_DETERMINISTIC_OPS to 'true' -# or '1' while also not setting the environment variable TF_CUDNN_USE_AUTOTUNE -# to 'false' or '0'. -# -# Where both deterministic and non-deterministic cuDNN algorithms are available, -# selecting determinitic operation will lead to only the deterministic -# algorithms being chosen. Additionally, selecting deterministic operation will -# result in a deterministic, or reproducible, selection of algorithms (for any -# given layer configuration) for each of the forward and the two backward paths. -# -# These tests intend to confirm that deterministic algorithms are chosen (for -# the back-prop paths) when desterministic operation is selected. The tested -# configurations were first confirmed to produce non-deterministic results when -# the above-mentioned environment variables are not set. -# -# Even though selecting determinitic operation should ensure that the same -# algorithms, for a given layer configuration, are always used (i.e. that -# algorithm selection is deterministic / reproducible), this is not tested. +# Setting either of the two environment variables TF_CUDNN_DETERMINISTIC or +# TF_DETERMINISTIC_OPS to "true" or "1" will disable autotuning of cuDNN +# algorithms and cause deterministic cuDNN algorithms to be selected when both +# deterministic and non-deterministic algorithms are available. These tests are +# intended to confirm that deterministic algorithms are chosen when either +# environment variable is set to "true" or "1". The tested configurations were +# first confirmed to produce non-deterministic results when the environment +# variables are not set. -# TODO(duncanriach): Add test for deterministic cuDNN max-pooling +_PADDING = 'SAME' +_STRIDES = [1, 1, 1, 1] -LayerShapeNHWC = collections.namedtuple('LayerShapeNHWC', - 'batch, height, width, channels') -FilterShape2D = collections.namedtuple( - 'FilterShape2D', 'height, width, in_channels, out_channels') - -LayerShapeNCDHW = collections.namedtuple( - 'LayerShapeNCDHW', 'batch, channels, depth, height, width') -FilterShape3D = collections.namedtuple( - 'FilterShape3D', 'depth, height, width, in_channels, out_channels') +LayerShape = collections.namedtuple('LayerShape', + 'batch, height, width, channels') +FilterShape = collections.namedtuple( + 'FilterShape', 'height, width, in_channels, out_channels') class ConvolutionTest(test.TestCase): @@ -71,13 +54,14 @@ class ConvolutionTest(test.TestCase): return constant_op.constant( 2 * np.random.random_sample(shape) - 1, dtype=dtypes.float32) - def _random_out_op(self, in_shape, filter_shape, strides, padding): + def _random_out_op(self, in_shape, filter_shape): # Choosing not to use array_op.zeros() to prevent possible removal by # optimization in_op = self._random_data_op(in_shape) filter_op = self._random_data_op(filter_shape) # Use the forward op's shape-inference - conv_op = nn_ops.conv2d(in_op, filter_op, strides=strides, padding=padding) + conv_op = nn_ops.conv2d( + in_op, filter_op, strides=_STRIDES, padding=_PADDING) out_shape = conv_op.get_shape() out_op = self._random_data_op(out_shape) return out_op @@ -88,54 +72,29 @@ class ConvolutionTest(test.TestCase): result_2 = self.evaluate(operation) self.assertAllEqual(result_1, result_2) - # The default forward algorithm choice, when using cuDNN 7, does not support - # the following layer configuration. This test case intends to confirm that - # an alternative algorithm is selected. Note that, in cuDNN 7, all forward - # algorithms are determnistic. - @test_util.run_cuda_only - def testForward(self): - np.random.seed(3) - in_shape = LayerShapeNCDHW(batch=2, channels=3, depth=5, height=7, width=6) - filter_shape = FilterShape3D( - depth=3, height=3, width=3, in_channels=3, out_channels=2) - in_op = self._random_data_op(in_shape) - filter_op = self._random_data_op(filter_shape) - strides = [1, 1, 1, 1, 1] - padding = 'VALID' - dilations = [1, 1, 2, 2, 2] - out_op = nn_ops.conv3d( - in_op, - filter_op, - strides=strides, - padding=padding, - data_format='NCDHW', - dilations=dilations) - self._assert_reproducible(out_op) - @test_util.run_cuda_only def testBackwardFilterGradient(self): np.random.seed(1) - in_shape = LayerShapeNHWC(batch=8, height=128, width=128, channels=8) - filter_shape = FilterShape2D( - height=3, width=3, in_channels=8, out_channels=8) + in_shape = LayerShape(batch=8, height=128, width=128, channels=8) + filter_shape = FilterShape(height=3, width=3, in_channels=8, out_channels=8) in_op = self._random_data_op(in_shape) - strides = [1, 1, 1, 1] - padding = 'SAME' - out_op = self._random_out_op(in_shape, filter_shape, strides, padding) + out_op = self._random_out_op(in_shape, filter_shape) filter_gradient_op = nn_ops.conv2d_backprop_filter( - in_op, filter_shape, out_op, strides=strides, padding=padding) + in_op, filter_shape, out_op, strides=_STRIDES, padding=_PADDING) self._assert_reproducible(filter_gradient_op) @test_util.run_cuda_only def testBackwardInputGradient(self): np.random.seed(2) - in_shape = LayerShapeNHWC(batch=8, height=32, width=32, channels=8) - filter_shape = FilterShape2D( + in_shape = LayerShape(batch=8, height=32, width=32, channels=8) + filter_shape = FilterShape( height=7, width=7, in_channels=8, out_channels=128) filter_op = self._random_data_op(filter_shape) - strides = [1, 1, 1, 1] - padding = 'SAME' - out_op = self._random_out_op(in_shape, filter_shape, strides, padding) + out_op = self._random_out_op(in_shape, filter_shape) input_gradient_op = nn_ops.conv2d_backprop_input( - in_shape, filter_op, out_op, strides=strides, padding=padding) + in_shape, filter_op, out_op, strides=_STRIDES, padding=_PADDING) self._assert_reproducible(input_gradient_op) + + # TODO(duncanriach): (1) add test to confirm that forward autotuning is + # disabled for cuDNN convolution; (2) add test for deterministic cuDNN + # max-pooling diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 1bb20d7a6ce..f87c0496733 100755 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -648,22 +648,9 @@ bool BatchnormSpatialPersistentEnabled() { return is_enabled; } -// The following function allows deterministic ops to be implemented relatively -// quickly using environment variables. It is intended to be temporary. The -// longer-term intention is to enable deterministic ops via tf.config and -// appropriate plumbing. See the discussion on PR 34951 for more information: -// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316 -// This function and associated comment are replicated in the following three -// places: -// 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc -// 2. tensorflow/core/kernels/gpu_utils.cc -// 3. tensorflow/stream_executor/cuda/cuda_dnn.cc -// When implementing the plumbing, you should also search for the use of -// TF_DETERMINISTIC_OPS on its own. -// TODO(duncanriach): move to an API that uses tf.config and implement the first -// phase of plumbing. -bool RequireCudnnDeterminism() { - static bool require_cudnn_determinism = [] { +// A helper function to decide whether to enable deterministic functionality. +bool RequireDeterminism() { + static bool require_determinism = [] { bool deterministic_ops = false; TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS", /*default_val=*/false, @@ -674,7 +661,7 @@ bool RequireCudnnDeterminism() { &cudnn_deterministic)); return deterministic_ops || cudnn_deterministic; }(); - return require_cudnn_determinism; + return require_determinism; } std::tuple GetCcMajorMinor(Stream* stream) { @@ -775,7 +762,7 @@ class CudnnPoolingDescriptor { std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), &CheckedNarrowing); bool propagate_nans = pooling_descriptor.propagate_nans(); - const auto cudnn_max_pooling_mode = RequireCudnnDeterminism() + const auto cudnn_max_pooling_mode = RequireDeterminism() ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX; CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor( @@ -3343,15 +3330,21 @@ bool CudnnSupport::GetConvolveAlgorithms( bool tensor_op_math_available = TensorOpMathAvailable(cc_major); out_algorithms->clear(); + if (RequireDeterminism()) { + out_algorithms->push_back({CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + tensor_op_math_available}); + return true; + } + std::vector algo_types = { - // clang-format off - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + // clang-format off CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, CUDNN_CONVOLUTION_FWD_ALGO_GEMM, CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, CUDNN_CONVOLUTION_FWD_ALGO_FFT, CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, - // clang-format on + // clang-format on }; if (CudnnEnvVar::IsEnabled()) { algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING); @@ -3360,12 +3353,11 @@ bool CudnnSupport::GetConvolveAlgorithms( algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED); } - // The algorithms are intentionally ordered for deterministic operation for (auto i : algo_types) { + out_algorithms->push_back({i, /*use_tensor_ops=*/false}); if (tensor_op_math_available) { out_algorithms->push_back({i, /*use_tensor_ops=*/true}); } - out_algorithms->push_back({i, /*use_tensor_ops=*/false}); } return true; @@ -3399,8 +3391,15 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( bool tensor_op_math_available = TensorOpMathAvailable(cc_major); out_algorithms->clear(); + if (RequireDeterminism()) { + out_algorithms->push_back( + {CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, tensor_op_math_available}); + return true; + } + std::vector algo_types = { // clang-format off + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, @@ -3410,16 +3409,12 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( if (CudnnEnvVar::IsEnabled() && with_winograd_nonfused) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED); } - if (!RequireCudnnDeterminism()) { - algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0); - } - // The algorithms are intentionally ordered for deterministic operation for (auto i : algo_types) { + out_algorithms->push_back({i, /*use_tensor_ops=*/false}); if (tensor_op_math_available) { out_algorithms->push_back({i, /*use_tensor_ops=*/true}); } - out_algorithms->push_back({i, /*use_tensor_ops=*/false}); } return true; @@ -3431,10 +3426,18 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( bool tensor_op_math_available = TensorOpMathAvailable(cc_major); out_algorithms->clear(); + if (RequireDeterminism()) { + out_algorithms->push_back( + {CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, tensor_op_math_available}); + return true; + } + std::vector algo_types = { // clang-format off + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, // Based on cudnn.h, the following is not implemented. // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD, @@ -3446,17 +3449,12 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( if (CudnnEnvVar::IsEnabled() && with_winograd_nonfused) { algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED); } - if (!RequireCudnnDeterminism()) { - algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0); - algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3); - } - // The algorithms are intentionally ordered for deterministic operation for (auto i : algo_types) { + out_algorithms->push_back({i, /*use_tensor_ops=*/false}); if (tensor_op_math_available) { out_algorithms->push_back({i, /*use_tensor_ops=*/true}); } - out_algorithms->push_back({i, /*use_tensor_ops=*/false}); } return true;