PR #37679: Support two CUDNN CTC Loss algorithms

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/37679

This PR enables CUDNN CTC Loss to support both deterministic and non-deterministic algos.

If determinism is required, we will stick to deterministic algo.
Otherwise, the faster non-deterministic algo will be tried in cudnnGetCtcLossWorkspace(). If it fails, we fall-back to the deterministic algo.

fyi @nluehr @sanjoy
Copybara import of the project:

--
e338f0bad8a10c6b0e6284e609dfe2b6ebeab9f0 by Kaixi Hou <kaixih@nvidia.com>:

CTC loss support two algos

--
c90ebbd533202c791e342b5716c3410c4557ea1a by Kaixi Hou <kaixih@nvidia.com>:

Reuse macros of cudnn launch

--
c9fd60fa4abc37674923acb802e46a7ac8055796 by Kaixi Hou <kaixih@nvidia.com>:

Add unittest for ctc_loss fallback algo

--
e29100d0202e2a519845b818d6f9603505d2c83e by Kaixi Hou <kaixih@nvidia.com>:

Shorten the comment sentence

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/37679 from houtoms:ctc_loss_workspace_check e29100d0202e2a519845b818d6f9603505d2c83e
PiperOrigin-RevId: 305942210
Change-Id: I57062bcb5f04097a3280eb6eeb5de51bae6ef3ca
This commit is contained in:
Kaixi Hou 2020-04-10 14:11:55 -07:00 committed by TensorFlower Gardener
parent b43ff5b8dd
commit d779e8431a
8 changed files with 144 additions and 102 deletions

View File

@ -941,6 +941,21 @@ class CTCLossTestV2(test.TestCase):
[[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu):
with test_util.device(use_gpu=use_gpu):
sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
with backprop.GradientTape() as t:
t.watch(logits)
ref_loss = ctc_ops.ctc_loss_v3(
labels=sparse_labels,
logits=logits,
label_length=label_length,
logit_length=logit_length,
blank_index=0)
ref_grad = t.gradient(ref_loss, [logits])
return ref_loss, ref_grad
@test_util.run_all_in_graph_and_eager_modes
class CTCLossTestV3(test.TestCase, parameterized.TestCase):
@ -978,42 +993,47 @@ class CTCLossTestV3(test.TestCase, parameterized.TestCase):
labels *= label_mask
logit_length = [num_frames] * batch_size
def ctc_loss_cpu(labels, logits, label_length, logit_length):
with test_util.device(use_gpu=False):
sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
with backprop.GradientTape() as t:
t.watch(logits)
ref_loss = ctc_ops.ctc_loss_v3(
labels=sparse_labels,
logits=logits,
label_length=label_length,
logit_length=logit_length,
blank_index=0)
ref_grad = t.gradient(ref_loss, [logits])
return ref_loss, ref_grad
def ctc_loss_gpu(labels, logits, label_length, logit_length):
with test_util.device(use_gpu=True):
sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
with backprop.GradientTape() as t:
t.watch(logits)
loss = ctc_ops.ctc_loss_v3(
labels=sparse_labels,
logits=logits,
label_length=label_length,
logit_length=logit_length,
blank_index=0)
grad = t.gradient(loss, [logits])
return loss, grad
if run_tf_func:
ctc_loss_cpu = def_function.function(ctc_loss_cpu)
ctc_loss_gpu = def_function.function(ctc_loss_gpu)
ctc_loss = def_function.function(_ctc_loss_v3)
else:
ctc_loss = _ctc_loss_v3
ref_loss, ref_grad = ctc_loss_cpu(labels, logits, label_length,
logit_length)
loss, grad = ctc_loss_gpu(labels, logits, label_length, logit_length)
ref_loss, ref_grad = ctc_loss(labels, logits, label_length, logit_length,
False)
loss, grad = ctc_loss(labels, logits, label_length, logit_length, True)
self.assertAllClose(loss, ref_loss, atol=1e-6)
self.assertAllClose(grad, ref_grad, atol=2e-6)
@test_util.run_v2_only
def testCtcLossAlgorithmFallback(self):
"""Test if GPU CTC loss can fallback to the correct algorithm."""
if not test.is_gpu_available():
self.skipTest("Need GPU for testing.")
if not context.executing_eagerly():
self.skipTest("Need eager execution for testing.")
random_seed.set_random_seed(5)
batch_size = 1
num_labels = 11777
max_label_length = 2
num_frames = 1
labels = random_ops.random_uniform([batch_size, max_label_length],
minval=1,
maxval=num_labels,
dtype=dtypes.int64)
logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
label_length = random_ops.random_uniform([batch_size],
minval=1,
maxval=max_label_length,
dtype=dtypes.int64)
logit_length = [num_frames] * batch_size
loss, grad = _ctc_loss_v3(labels, logits, label_length, logit_length, True)
ref_loss, ref_grad = _ctc_loss_v3(labels, logits, label_length,
logit_length, False)
self.assertAllClose(loss, ref_loss, atol=1e-6)
self.assertAllClose(grad, ref_grad, atol=2e-6)

View File

@ -2045,7 +2045,7 @@ port::Status CudnnSupport::DoCtcLossImpl(
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const CudnnRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
DeviceMemory<uint8> scratch_memory) {
DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
auto cudnn = cudnn_->GetHandle(parent_, stream);
int kNumTimestamps = probs_desc.num_layers();
@ -2055,6 +2055,8 @@ port::Status CudnnSupport::DoCtcLossImpl(
(void)total_size;
#if CUDNN_VERSION >= 7603
cudnnCTCLossAlgo_t ctc_loss_algo =
static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
/*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
/*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
@ -2062,9 +2064,7 @@ port::Status CudnnSupport::DoCtcLossImpl(
/*inputLengths=*/input_lengths_data.data(),
/*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
/*gradients=*/grads_data.opaque(),
/*algo=*/
RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
: CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
/*algo=*/ctc_loss_algo,
/*ctcLossDesc=*/ctc_loss_desc.handle(),
/*workspace=*/scratch_memory.opaque(),
/*workSpaceSizeInBytes=*/scratch_memory.size()));
@ -3935,7 +3935,8 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
int* ctc_loss_algo_id) {
auto cudnn = cudnn_->GetHandle(parent_, stream);
// Query the workspace size.
size_t workspace_size_in_bytes = 0;
@ -3945,17 +3946,38 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
// Try running with `algo`, if successful then pick it. The non-deterministic
// algorithm is first and thus preferentially picked when determinism is not
// required.
auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
: CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
/*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
/*gradientsDesc=*/cudnn_grads_desc.handle(),
/*labels=*/labels_data.data(),
/*labelLengths=*/labels_lengths_data.data(),
/*inputLengths=*/input_lengths_data.data(),
/*algo=*/algo,
/*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
/*sizeInBytes=*/&workspace_size_in_bytes);
if (RequireCudnnDeterminism()) {
RETURN_IF_CUDNN_ERROR(status);
}
if (status != CUDNN_STATUS_SUCCESS) {
algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
/*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
/*gradientsDesc=*/cudnn_grads_desc.handle(),
/*labels=*/labels_data.data(),
/*labelLengths=*/labels_lengths_data.data(),
/*inputLengths=*/input_lengths_data.data(),
/*algo=*/
RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
: CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
/*algo=*/algo,
/*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
/*sizeInBytes=*/&workspace_size_in_bytes));
}
*ctc_loss_algo_id = algo;
#else
return port::Status(port::error::INVALID_ARGUMENT,
"No supported cudnnGetCTCLossWorkspaceSize when "
@ -3979,13 +4001,12 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
port::Status CudnnSupport::DoCtcLoss(
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
absl::Span<const int> labels_data,
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const dnn::RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
int ctc_loss_algo_id) {
// Current cuDNN CTC Loss only supports the float datatype
if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
return port::Status(port::error::INVALID_ARGUMENT,
@ -4000,7 +4021,7 @@ port::Status CudnnSupport::DoCtcLoss(
return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data, costs_data,
cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
scratch_memory);
scratch_memory, ctc_loss_algo_id);
}
bool CudnnSupport::DoTransformTensor(Stream* stream,

View File

@ -570,7 +570,8 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemoryBase costs_data,
const dnn::RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory) override;
DeviceMemory<uint8> scratch_memory,
int ctc_loss_algo_id) override;
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
dnn::DataType input_type,
@ -689,7 +690,7 @@ class CudnnSupport : public dnn::DnnSupport {
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const CudnnRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
DeviceMemory<uint8> scratch_memory);
DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
private:
port::Status DoPrepareForConvolution(
@ -711,8 +712,8 @@ class CudnnSupport : public dnn::DnnSupport {
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch_memory) override;
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
int* ctc_loss_algo_id) override;
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};

View File

@ -618,16 +618,14 @@ bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
return false;
}
port::Status DnnSupport::DoCtcLoss(Stream* stream, dnn::DataType element_type,
port::Status DnnSupport::DoCtcLoss(
Stream* stream, dnn::DataType element_type,
const RnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
absl::Span<const int> labels_data,
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory) {
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
return port::UnimplementedError("CtcLoss not implemented");
}

View File

@ -2396,11 +2396,12 @@ class DnnSupport {
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* workspace_allocator,
DeviceMemory<uint8>* scratch_memory) {
return DoPrepareForCtcLoss(stream, ToDataType<ElementType>::value,
probs_desc, grads_desc, labels_data,
labels_lengths_data, input_lengths_data,
workspace_allocator, scratch_memory);
DeviceMemory<uint8>* scratch_memory,
int* ctc_loss_algo_id) {
return DoPrepareForCtcLoss(
stream, ToDataType<ElementType>::value, probs_desc, grads_desc,
labels_data, labels_lengths_data, input_lengths_data,
workspace_allocator, scratch_memory, ctc_loss_algo_id);
}
// Enqueue a CTC Loss operation onto the stream.
@ -2424,16 +2425,14 @@ class DnnSupport {
// workspace memory used by this operation. The caller is responsible for
// keeping the memory alive long enough for this operation, and recylces
// afterwards.
virtual port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
virtual port::Status DoCtcLoss(
Stream* stream, dnn::DataType element_type,
const RnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
absl::Span<const int> labels_data,
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory);
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
template <typename ElementType>
bool DoCtcLoss(Stream* stream,
@ -2445,12 +2444,12 @@ class DnnSupport {
DeviceMemory<ElementType>* costs_data,
const dnn::RnnStateTensorDescriptor& grads_desc,
DeviceMemory<ElementType>* grads_data,
DeviceMemory<uint8>* scratch_memory) {
DeviceMemory<uint8>* scratch_memory, int ctc_loss_algo_id) {
return IsStatusOk(
DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
probs_data, labels_data, labels_lengths_data,
input_lengths_data, *costs_data, grads_desc, *grads_data,
*scratch_memory),
*scratch_memory, ctc_loss_algo_id),
false);
}
@ -2716,8 +2715,8 @@ class DnnSupport {
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch_memory) {
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
int* ctc_loss_algo_id) {
*scratch_memory = {};
return port::Status::OK();
}

View File

@ -2393,7 +2393,8 @@ port::Status MIOpenSupport::DoPrepareForCtcLoss(
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
int* ctc_loss_algo_id) {
auto miopen = miopen_->GetHandle(parent_, stream);
MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
@ -2456,7 +2457,7 @@ port::Status MIOpenSupport::DoCtcLossImpl(
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const MIOpenRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
DeviceMemory<uint8> scratch_memory) {
DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
auto miopen = miopen_->GetHandle(parent_, stream);
int kNumTimestamps = probs_desc.num_layers();
@ -2482,13 +2483,12 @@ port::Status MIOpenSupport::DoCtcLossImpl(
port::Status MIOpenSupport::DoCtcLoss(
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
absl::Span<const int> labels_data,
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const dnn::RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
int ctc_loss_algo_id) {
// Current MIOPen CTC Loss only supports the float datatype
if (element_type != dnn::DataType::kFloat) {
return port::Status(port::error::INVALID_ARGUMENT,
@ -2507,7 +2507,7 @@ port::Status MIOpenSupport::DoCtcLoss(
return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data, costs_data,
miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
scratch_memory);
scratch_memory, ctc_loss_algo_id);
}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>

View File

@ -648,7 +648,8 @@ class MIOpenSupport : public dnn::DnnSupport {
DeviceMemoryBase costs_data,
const dnn::RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory) override;
DeviceMemory<uint8> scratch_memory,
int ctc_loss_algo_id) override;
private:
GpuExecutor* parent_; // Parent executor object. Not owned.
@ -812,7 +813,7 @@ class MIOpenSupport : public dnn::DnnSupport {
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const MIOpenRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
DeviceMemory<uint8> scratch_memory);
DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
port::Status DoPrepareForCtcLoss(
Stream* stream, dnn::DataType element_type,
@ -821,8 +822,8 @@ class MIOpenSupport : public dnn::DnnSupport {
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch_memory) override;
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
int* ctc_loss_algo_id) override;
bool GetMIOpenConvolveAlgorithmsImmediateMode(
dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,

View File

@ -5251,16 +5251,18 @@ Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
DeviceMemory<uint8> scratch_memory;
auto status = dnn->PrepareForCtcLoss(
this, probs_desc, probs_data, grads_desc,
labels_data, labels_lengths_data, input_lengths_data,
workspace_allocator, &scratch_memory)
int ctc_loss_algo_id;
auto status =
dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
labels_data, labels_lengths_data,
input_lengths_data, workspace_allocator,
&scratch_memory, &ctc_loss_algo_id)
.ok();
if (status) {
status =
dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data, costs_data,
grads_desc, grads_data, &scratch_memory);
status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data,
costs_data, grads_desc, grads_data,
&scratch_memory, ctc_loss_algo_id);
}
if (!status) {
SetError();