Add multi-algorithm deterministic cuDNN convolutions
This commit is contained in:
parent
2beb2d53ba
commit
5341e3d299
@ -669,6 +669,7 @@ cc_library(
|
|||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/core/util/proto:proto_utils",
|
"//tensorflow/core/util/proto:proto_utils",
|
||||||
"//tensorflow/stream_executor:device_memory_allocator",
|
"//tensorflow/stream_executor:device_memory_allocator",
|
||||||
|
"//tensorflow/stream_executor/cuda:cuda_helpers",
|
||||||
"//tensorflow/stream_executor/gpu:redzone_allocator",
|
"//tensorflow/stream_executor/gpu:redzone_allocator",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/logger.h"
|
#include "tensorflow/core/platform/logger.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||||
|
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
|
||||||
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
|
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -536,43 +537,41 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For now, we ignore WRONG_RESULT failures because false-positives are
|
|
||||||
// possible (e.g. perhaps the reference algorithm is the one that's
|
|
||||||
// incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
|
|
||||||
// quite severe and can be detected with high accuracy.
|
|
||||||
auto has_failure = [](const AutotuneResult& r) {
|
|
||||||
return r.has_failure() &&
|
|
||||||
r.failure().kind() != AutotuneResult::WRONG_RESULT;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
|
// Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
|
||||||
// error.
|
// error.
|
||||||
//
|
//
|
||||||
// TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
|
// TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
|
||||||
// in the output of the conv and skip those.
|
// in the output of the conv and skip those.
|
||||||
//
|
//
|
||||||
// The successful one should have a smaller key, since we are doing
|
// For now, we ignore WRONG_RESULT failures because false-positives are
|
||||||
// min_element. If they are both unsuccessful, keep the earlier one in
|
// possible (e.g. perhaps the reference algorithm is the one that's
|
||||||
// the vector by comparing pointers.
|
// incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
|
||||||
auto result_comparison_key = [&has_failure](const AutotuneResult& r) {
|
// quite severe and can be detected with high accuracy.
|
||||||
return std::make_tuple(
|
std::vector<AutotuneResult> filtered_results;
|
||||||
has_failure(r),
|
absl::c_copy_if(
|
||||||
tensorflow::proto_utils::FromDurationProto(r.run_time()));
|
profile_results, std::back_inserter(filtered_results),
|
||||||
};
|
[](const AutotuneResult& r) {
|
||||||
const auto& best_result = absl::c_min_element(
|
return !(r.has_failure() &&
|
||||||
profile_results,
|
r.failure().kind() != AutotuneResult::WRONG_RESULT);
|
||||||
[&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
|
|
||||||
return result_comparison_key(lhs) < result_comparison_key(rhs);
|
|
||||||
});
|
});
|
||||||
|
if (filtered_results.empty()) {
|
||||||
if (best_result != profile_results.end() && !has_failure(*best_result)) {
|
return InternalError(
|
||||||
return *best_result;
|
"All algorithms tried for convolution %s failed. Falling back to "
|
||||||
|
"default algorithm. ",
|
||||||
|
instr->ToString());
|
||||||
}
|
}
|
||||||
|
|
||||||
return InternalError(
|
auto selected_result = filtered_results.begin();
|
||||||
"All algorithms tried for convolution %s failed. Falling back to "
|
if (!se::cuda::RequireCuDNNDeterminism()) {
|
||||||
"default algorithm.",
|
selected_result = absl::c_min_element(
|
||||||
instr->ToString());
|
filtered_results,
|
||||||
|
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
|
||||||
|
return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
|
||||||
|
tensorflow::proto_utils::FromDurationProto(rhs.run_time());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return *selected_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<tensorflow::AutotuneResult>
|
StatusOr<tensorflow::AutotuneResult>
|
||||||
|
@ -531,6 +531,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor",
|
"//tensorflow/core:stream_executor",
|
||||||
"//tensorflow/core/util/proto:proto_utils",
|
"//tensorflow/core/util/proto:proto_utils",
|
||||||
|
"//tensorflow/stream_executor/cuda:cuda_helpers",
|
||||||
"//tensorflow/stream_executor/gpu:asm_compiler",
|
"//tensorflow/stream_executor/gpu:asm_compiler",
|
||||||
"//tensorflow/stream_executor/gpu:redzone_allocator",
|
"//tensorflow/stream_executor/gpu:redzone_allocator",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||||
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
|
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
|
||||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||||
|
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
|
||||||
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
|
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
|
||||||
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
|
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
|
||||||
|
|
||||||
@ -220,31 +221,32 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
|
|||||||
if (filtered_results.empty()) {
|
if (filtered_results.empty()) {
|
||||||
return errors::NotFound("No algorithm worked!");
|
return errors::NotFound("No algorithm worked!");
|
||||||
}
|
}
|
||||||
|
std::vector<AutotuneResult> filtered_results_no_scratch;
|
||||||
|
absl::c_copy_if(
|
||||||
|
filtered_results, std::back_inserter(filtered_results_no_scratch),
|
||||||
|
[](const AutotuneResult& result) { return result.scratch_bytes() == 0; });
|
||||||
|
|
||||||
const auto best_result = absl::c_min_element(
|
auto selected_result = filtered_results.begin();
|
||||||
filtered_results,
|
auto selected_result_no_scratch = filtered_results_no_scratch.begin();
|
||||||
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
|
if (!se::cuda::RequireCuDNNDeterminism()) {
|
||||||
return proto_utils::FromDurationProto(lhs.run_time()) <
|
auto compare_run_times = [](const AutotuneResult& lhs,
|
||||||
proto_utils::FromDurationProto(rhs.run_time());
|
const AutotuneResult& rhs) {
|
||||||
});
|
return proto_utils::FromDurationProto(lhs.run_time()) <
|
||||||
|
proto_utils::FromDurationProto(rhs.run_time());
|
||||||
const auto best_result_no_scratch = absl::c_min_element(
|
};
|
||||||
filtered_results,
|
selected_result = absl::c_min_element(filtered_results, compare_run_times);
|
||||||
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
|
selected_result_no_scratch = absl::c_min_element(
|
||||||
return std::make_tuple(lhs.scratch_bytes(),
|
filtered_results_no_scratch, compare_run_times);
|
||||||
proto_utils::FromDurationProto(lhs.run_time())) <
|
|
||||||
std::make_tuple(rhs.scratch_bytes(),
|
|
||||||
proto_utils::FromDurationProto(rhs.run_time()));
|
|
||||||
});
|
|
||||||
|
|
||||||
algo->set_algorithm({best_result->conv().algorithm(),
|
|
||||||
best_result->conv().tensor_ops_enabled()});
|
|
||||||
if (best_result_no_scratch != filtered_results.end() &&
|
|
||||||
best_result_no_scratch->scratch_bytes() == 0) {
|
|
||||||
algo->set_algorithm_no_scratch(
|
|
||||||
{best_result_no_scratch->conv().algorithm(),
|
|
||||||
best_result_no_scratch->conv().tensor_ops_enabled()});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
algo->set_algorithm({selected_result->conv().algorithm(),
|
||||||
|
selected_result->conv().tensor_ops_enabled()});
|
||||||
|
if (selected_result_no_scratch != filtered_results_no_scratch.end()) {
|
||||||
|
algo->set_algorithm_no_scratch(
|
||||||
|
{selected_result_no_scratch->conv().algorithm(),
|
||||||
|
selected_result_no_scratch->conv().tensor_ops_enabled()});
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,22 +27,39 @@ from tensorflow.python.framework import test_util
|
|||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
# Setting either of the two environment variables TF_CUDNN_DETERMINISTIC or
|
# Notes:
|
||||||
# TF_DETERMINISTIC_OPS to "true" or "1" will disable autotuning of cuDNN
|
#
|
||||||
# algorithms and cause deterministic cuDNN algorithms to be selected when both
|
# Deterministic cuDNN operation is selected by setting either of the two
|
||||||
# deterministic and non-deterministic algorithms are available. These tests are
|
# environment variables TF_CUDNN_DETERMINISTIC or TF_DETERMINISTIC_OPS to 'true'
|
||||||
# intended to confirm that deterministic algorithms are chosen when either
|
# or '1' while also not setting the environment variable TF_CUDNN_USE_AUTOTUNE
|
||||||
# environment variable is set to "true" or "1". The tested configurations were
|
# to 'false' or '0'.
|
||||||
# first confirmed to produce non-deterministic results when the environment
|
#
|
||||||
# variables are not set.
|
# Where both deterministic and non-deterministic cuDNN algorithms are available,
|
||||||
|
# selecting deterministic operation will lead to only the deterministic
|
||||||
|
# algorithms being chosen. Additionally, selecting deterministic operation will
|
||||||
|
# result in a deterministic, or reproducible, selection of algorithms (for any
|
||||||
|
# given layer configuration) for each of the forward and the two backward paths.
|
||||||
|
#
|
||||||
|
# These tests intend to confirm that deterministic algorithms are chosen (for
|
||||||
|
# the back-prop paths) when deterministic operation is selected. The tested
|
||||||
|
# configurations were first confirmed to produce non-deterministic results when
|
||||||
|
# the above-mentioned environment variables are not set.
|
||||||
|
#
|
||||||
|
# Even though selecting deterministic operation should ensure that the same
|
||||||
|
# algorithms, for a given layer configuration, are always used (i.e. that
|
||||||
|
# algorithm selection is deterministic / reproducible), this is not tested.
|
||||||
|
|
||||||
_PADDING = 'SAME'
|
# TODO(duncanriach): Add test for deterministic cuDNN max-pooling
|
||||||
_STRIDES = [1, 1, 1, 1]
|
|
||||||
|
|
||||||
LayerShape = collections.namedtuple('LayerShape',
|
LayerShapeNHWC = collections.namedtuple('LayerShapeNHWC',
|
||||||
'batch, height, width, channels')
|
'batch, height, width, channels')
|
||||||
FilterShape = collections.namedtuple(
|
FilterShape2D = collections.namedtuple(
|
||||||
'FilterShape', 'height, width, in_channels, out_channels')
|
'FilterShape2D', 'height, width, in_channels, out_channels')
|
||||||
|
|
||||||
|
LayerShapeNCDHW = collections.namedtuple(
|
||||||
|
'LayerShapeNCDHW', 'batch, channels, depth, height, width')
|
||||||
|
FilterShape3D = collections.namedtuple(
|
||||||
|
'FilterShape3D', 'depth, height, width, in_channels, out_channels')
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionTest(test.TestCase):
|
class ConvolutionTest(test.TestCase):
|
||||||
@ -53,14 +70,14 @@ class ConvolutionTest(test.TestCase):
|
|||||||
return constant_op.constant(
|
return constant_op.constant(
|
||||||
2 * np.random.random_sample(shape) - 1, dtype=dtypes.float32)
|
2 * np.random.random_sample(shape) - 1, dtype=dtypes.float32)
|
||||||
|
|
||||||
def _random_out_op(self, in_shape, filter_shape):
|
def _random_out_op(self, in_shape, filter_shape, strides, padding):
|
||||||
# Choosing not to use array_op.zeros() to prevent possible removal by
|
# Choosing not to use array_op.zeros() to prevent possible removal by
|
||||||
# optimization
|
# optimization
|
||||||
in_op = self._random_data_op(in_shape)
|
in_op = self._random_data_op(in_shape)
|
||||||
filter_op = self._random_data_op(filter_shape)
|
filter_op = self._random_data_op(filter_shape)
|
||||||
# Use the forward op's shape-inference
|
# Use the forward op's shape-inference
|
||||||
conv_op = nn_ops.conv2d(
|
conv_op = nn_ops.conv2d(
|
||||||
in_op, filter_op, strides=_STRIDES, padding=_PADDING)
|
in_op, filter_op, strides=strides, padding=padding)
|
||||||
out_shape = conv_op.get_shape()
|
out_shape = conv_op.get_shape()
|
||||||
out_op = self._random_data_op(out_shape)
|
out_op = self._random_data_op(out_shape)
|
||||||
return out_op
|
return out_op
|
||||||
@ -71,29 +88,49 @@ class ConvolutionTest(test.TestCase):
|
|||||||
result_2 = self.evaluate(operation)
|
result_2 = self.evaluate(operation)
|
||||||
self.assertAllEqual(result_1, result_2)
|
self.assertAllEqual(result_1, result_2)
|
||||||
|
|
||||||
|
# The default forward algorithm choice, when using cuDNN 7, does not support
|
||||||
|
# the following layer configuration. This test case intends to confirm that
|
||||||
|
# an alternative algorithm is selected. Note that, in cuDNN 7, all forward
|
||||||
|
# algorithms are deterministic.
|
||||||
|
@test_util.run_cuda_only
|
||||||
|
def testForward(self):
|
||||||
|
np.random.seed(3)
|
||||||
|
in_shape = LayerShapeNCDHW(batch=2, channels=3, depth=5, height=7, width=6)
|
||||||
|
filter_shape = FilterShape3D(depth=3, height=3, width=3, in_channels=3,
|
||||||
|
out_channels=2)
|
||||||
|
in_op = self._random_data_op(in_shape)
|
||||||
|
filter_op = self._random_data_op(filter_shape)
|
||||||
|
strides = [1, 1, 1, 1, 1]
|
||||||
|
padding = 'VALID'
|
||||||
|
dilations = [1, 1, 2, 2, 2]
|
||||||
|
out_op = nn_ops.conv3d(in_op, filter_op, strides=strides, padding=padding,
|
||||||
|
data_format='NCDHW', dilations=dilations)
|
||||||
|
self._assert_reproducible(out_op)
|
||||||
|
|
||||||
@test_util.run_cuda_only
|
@test_util.run_cuda_only
|
||||||
def testBackwardFilterGradient(self):
|
def testBackwardFilterGradient(self):
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
in_shape = LayerShape(batch=8, height=128, width=128, channels=8)
|
in_shape = LayerShapeNHWC(batch=8, height=128, width=128, channels=8)
|
||||||
filter_shape = FilterShape(height=3, width=3, in_channels=8, out_channels=8)
|
filter_shape = FilterShape2D(height=3, width=3, in_channels=8,
|
||||||
|
out_channels=8)
|
||||||
in_op = self._random_data_op(in_shape)
|
in_op = self._random_data_op(in_shape)
|
||||||
out_op = self._random_out_op(in_shape, filter_shape)
|
strides = [1, 1, 1, 1]
|
||||||
|
padding = 'SAME'
|
||||||
|
out_op = self._random_out_op(in_shape, filter_shape, strides, padding)
|
||||||
filter_gradient_op = nn_ops.conv2d_backprop_filter(
|
filter_gradient_op = nn_ops.conv2d_backprop_filter(
|
||||||
in_op, filter_shape, out_op, strides=_STRIDES, padding=_PADDING)
|
in_op, filter_shape, out_op, strides=strides, padding=padding)
|
||||||
self._assert_reproducible(filter_gradient_op)
|
self._assert_reproducible(filter_gradient_op)
|
||||||
|
|
||||||
@test_util.run_cuda_only
|
@test_util.run_cuda_only
|
||||||
def testBackwardInputGradient(self):
|
def testBackwardInputGradient(self):
|
||||||
np.random.seed(2)
|
np.random.seed(2)
|
||||||
in_shape = LayerShape(batch=8, height=32, width=32, channels=8)
|
in_shape = LayerShapeNHWC(batch=8, height=32, width=32, channels=8)
|
||||||
filter_shape = FilterShape(
|
filter_shape = FilterShape2D(height=7, width=7, in_channels=8,
|
||||||
height=7, width=7, in_channels=8, out_channels=128)
|
out_channels=128)
|
||||||
filter_op = self._random_data_op(filter_shape)
|
filter_op = self._random_data_op(filter_shape)
|
||||||
out_op = self._random_out_op(in_shape, filter_shape)
|
strides = [1, 1, 1, 1]
|
||||||
|
padding = 'SAME'
|
||||||
|
out_op = self._random_out_op(in_shape, filter_shape, strides, padding)
|
||||||
input_gradient_op = nn_ops.conv2d_backprop_input(
|
input_gradient_op = nn_ops.conv2d_backprop_input(
|
||||||
in_shape, filter_op, out_op, strides=_STRIDES, padding=_PADDING)
|
in_shape, filter_op, out_op, strides=strides, padding=padding)
|
||||||
self._assert_reproducible(input_gradient_op)
|
self._assert_reproducible(input_gradient_op)
|
||||||
|
|
||||||
# TODO(duncanriach): (1) add test to confirm that forward autotuning is
|
|
||||||
# disabled for cuDNN convolution; (2) add test for deterministic cuDNN
|
|
||||||
# max-pooling
|
|
||||||
|
@ -343,6 +343,7 @@ cc_library(
|
|||||||
":cuda_platform_id",
|
":cuda_platform_id",
|
||||||
":cuda_stream",
|
":cuda_stream",
|
||||||
":cuda_timer",
|
":cuda_timer",
|
||||||
|
":cuda_helpers",
|
||||||
":cudnn_version",
|
":cudnn_version",
|
||||||
":cudnn_lib",
|
":cudnn_lib",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -480,7 +481,8 @@ cc_library(
|
|||||||
# TODO(leary) we likely need to canonicalize/eliminate this.
|
# TODO(leary) we likely need to canonicalize/eliminate this.
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cuda_helpers",
|
name = "cuda_helpers",
|
||||||
textual_hdrs = if_cuda_is_configured(["cuda_helpers.h"]),
|
srcs = if_cuda_is_configured(["cuda_helpers.cc"]),
|
||||||
|
hdrs = if_cuda_is_configured(["cuda_helpers.h"]),
|
||||||
deps = if_cuda_is_configured([
|
deps = if_cuda_is_configured([
|
||||||
"//tensorflow/stream_executor/gpu:gpu_helpers_header",
|
"//tensorflow/stream_executor/gpu:gpu_helpers_header",
|
||||||
]),
|
]),
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
|
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
|
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
|
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
|
||||||
|
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
|
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
|
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
|
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
|
||||||
@ -630,22 +631,6 @@ bool BatchnormSpatialPersistentEnabled() {
|
|||||||
return is_enabled;
|
return is_enabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
// A helper function to decide whether to enable deterministic functionality.
|
|
||||||
bool RequireDeterminism() {
|
|
||||||
static bool require_determinism = [] {
|
|
||||||
bool deterministic_ops = false;
|
|
||||||
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
|
|
||||||
/*default_val=*/false,
|
|
||||||
&deterministic_ops));
|
|
||||||
bool cudnn_deterministic = false;
|
|
||||||
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
|
|
||||||
/*default_val=*/false,
|
|
||||||
&cudnn_deterministic));
|
|
||||||
return deterministic_ops || cudnn_deterministic;
|
|
||||||
}();
|
|
||||||
return require_determinism;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
|
std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
|
||||||
int cc_major, cc_minor;
|
int cc_major, cc_minor;
|
||||||
stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
|
stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
|
||||||
@ -744,9 +729,10 @@ class CudnnPoolingDescriptor {
|
|||||||
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
|
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
|
||||||
&CheckedNarrowing<int64, int>);
|
&CheckedNarrowing<int64, int>);
|
||||||
bool propagate_nans = pooling_descriptor.propagate_nans();
|
bool propagate_nans = pooling_descriptor.propagate_nans();
|
||||||
const auto cudnn_max_pooling_mode = RequireDeterminism()
|
const auto cudnn_max_pooling_mode =
|
||||||
? CUDNN_POOLING_MAX_DETERMINISTIC
|
stream_executor::cuda::RequireCuDNNDeterminism()
|
||||||
: CUDNN_POOLING_MAX;
|
? CUDNN_POOLING_MAX_DETERMINISTIC
|
||||||
|
: CUDNN_POOLING_MAX;
|
||||||
CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
|
CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
|
||||||
handle_.get(),
|
handle_.get(),
|
||||||
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
|
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
|
||||||
@ -3247,16 +3233,10 @@ bool CudnnSupport::GetConvolveAlgorithms(
|
|||||||
bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
|
bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
|
||||||
out_algorithms->clear();
|
out_algorithms->clear();
|
||||||
|
|
||||||
if (RequireDeterminism()) {
|
|
||||||
out_algorithms->push_back({CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
|
||||||
tensor_op_math_available});
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
|
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
|
||||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||||
|
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||||
@ -3270,11 +3250,12 @@ bool CudnnSupport::GetConvolveAlgorithms(
|
|||||||
algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
|
algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The algorithms are intentionally ordered for deterministic operation
|
||||||
for (auto i : algo_types) {
|
for (auto i : algo_types) {
|
||||||
out_algorithms->push_back({i, /*use_tensor_ops=*/false});
|
|
||||||
if (tensor_op_math_available) {
|
if (tensor_op_math_available) {
|
||||||
out_algorithms->push_back({i, /*use_tensor_ops=*/true});
|
out_algorithms->push_back({i, /*use_tensor_ops=*/true});
|
||||||
}
|
}
|
||||||
|
out_algorithms->push_back({i, /*use_tensor_ops=*/false});
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -3308,15 +3289,8 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
|
|||||||
bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
|
bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
|
||||||
out_algorithms->clear();
|
out_algorithms->clear();
|
||||||
|
|
||||||
if (RequireDeterminism()) {
|
|
||||||
out_algorithms->push_back(
|
|
||||||
{CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, tensor_op_math_available});
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
|
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
|
|
||||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
||||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
|
||||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
||||||
@ -3326,12 +3300,16 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
|
|||||||
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
|
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
|
||||||
algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
|
algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
|
||||||
}
|
}
|
||||||
|
if (!stream_executor::cuda::RequireCuDNNDeterminism()) {
|
||||||
|
algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The algorithms are intentionally ordered for deterministic operation
|
||||||
for (auto i : algo_types) {
|
for (auto i : algo_types) {
|
||||||
out_algorithms->push_back({i, /*use_tensor_ops=*/false});
|
|
||||||
if (tensor_op_math_available) {
|
if (tensor_op_math_available) {
|
||||||
out_algorithms->push_back({i, /*use_tensor_ops=*/true});
|
out_algorithms->push_back({i, /*use_tensor_ops=*/true});
|
||||||
}
|
}
|
||||||
|
out_algorithms->push_back({i, /*use_tensor_ops=*/false});
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -3343,18 +3321,10 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
|
|||||||
bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
|
bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
|
||||||
out_algorithms->clear();
|
out_algorithms->clear();
|
||||||
|
|
||||||
if (RequireDeterminism()) {
|
|
||||||
out_algorithms->push_back(
|
|
||||||
{CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, tensor_op_math_available});
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
|
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
|
|
||||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
|
||||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
|
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
|
||||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
|
|
||||||
// Based on cudnn.h, the following is not implemented.
|
// Based on cudnn.h, the following is not implemented.
|
||||||
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
|
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
|
||||||
|
|
||||||
@ -3366,12 +3336,17 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
|
|||||||
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
|
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
|
||||||
algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
|
algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
|
||||||
}
|
}
|
||||||
|
if (!stream_executor::cuda::RequireCuDNNDeterminism()) {
|
||||||
|
algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
|
||||||
|
algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The algorithms are intentionally ordered for deterministic operation
|
||||||
for (auto i : algo_types) {
|
for (auto i : algo_types) {
|
||||||
out_algorithms->push_back({i, /*use_tensor_ops=*/false});
|
|
||||||
if (tensor_op_math_available) {
|
if (tensor_op_math_available) {
|
||||||
out_algorithms->push_back({i, /*use_tensor_ops=*/true});
|
out_algorithms->push_back({i, /*use_tensor_ops=*/true});
|
||||||
}
|
}
|
||||||
|
out_algorithms->push_back({i, /*use_tensor_ops=*/false});
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
39
tensorflow/stream_executor/cuda/cuda_helpers.cc
Normal file
39
tensorflow/stream_executor/cuda/cuda_helpers.cc
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/util/env_var.h"
|
||||||
|
|
||||||
|
namespace stream_executor {
|
||||||
|
namespace cuda {
|
||||||
|
|
||||||
|
bool RequireCuDNNDeterminism() {
|
||||||
|
static bool require_cudnn_determinism = [] {
|
||||||
|
bool deterministic_ops = false;
|
||||||
|
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
|
||||||
|
/*default_val=*/false,
|
||||||
|
&deterministic_ops));
|
||||||
|
bool cudnn_deterministic = false;
|
||||||
|
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
|
||||||
|
/*default_val=*/false,
|
||||||
|
&cudnn_deterministic));
|
||||||
|
return deterministic_ops || cudnn_deterministic;
|
||||||
|
}();
|
||||||
|
return require_cudnn_determinism;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace stream_executor
|
@ -22,4 +22,14 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
||||||
|
|
||||||
|
namespace stream_executor {
|
||||||
|
namespace cuda {
|
||||||
|
|
||||||
|
// A helper function to decide whether to enable deterministic cuDNN
|
||||||
|
// functionality.
|
||||||
|
bool RequireCuDNNDeterminism();
|
||||||
|
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace stream_executor
|
||||||
|
|
||||||
#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_
|
#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_
|
||||||
|
Loading…
x
Reference in New Issue
Block a user