diff --git a/tensorflow/python/kernel_tests/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
index 058437cc04e..9b94536de0a 100644
--- a/tensorflow/python/kernel_tests/ctc_loss_op_test.py
+++ b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
@@ -941,6 +941,21 @@ class CTCLossTestV2(test.TestCase):
         [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
 
 
+def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu):
+  with test_util.device(use_gpu=use_gpu):
+    sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
+    with backprop.GradientTape() as t:
+      t.watch(logits)
+      ref_loss = ctc_ops.ctc_loss_v3(
+          labels=sparse_labels,
+          logits=logits,
+          label_length=label_length,
+          logit_length=logit_length,
+          blank_index=0)
+    ref_grad = t.gradient(ref_loss, [logits])
+    return ref_loss, ref_grad
+
+
 @test_util.run_all_in_graph_and_eager_modes
 class CTCLossTestV3(test.TestCase, parameterized.TestCase):
 
@@ -978,42 +993,47 @@ class CTCLossTestV3(test.TestCase, parameterized.TestCase):
     labels *= label_mask
     logit_length = [num_frames] * batch_size
 
-    def ctc_loss_cpu(labels, logits, label_length, logit_length):
-      with test_util.device(use_gpu=False):
-        sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
-        with backprop.GradientTape() as t:
-          t.watch(logits)
-          ref_loss = ctc_ops.ctc_loss_v3(
-              labels=sparse_labels,
-              logits=logits,
-              label_length=label_length,
-              logit_length=logit_length,
-              blank_index=0)
-        ref_grad = t.gradient(ref_loss, [logits])
-        return ref_loss, ref_grad
-
-    def ctc_loss_gpu(labels, logits, label_length, logit_length):
-      with test_util.device(use_gpu=True):
-        sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
-        with backprop.GradientTape() as t:
-          t.watch(logits)
-          loss = ctc_ops.ctc_loss_v3(
-              labels=sparse_labels,
-              logits=logits,
-              label_length=label_length,
-              logit_length=logit_length,
-              blank_index=0)
-        grad = t.gradient(loss, [logits])
-
-      return loss, grad
-
     if run_tf_func:
-      ctc_loss_cpu = def_function.function(ctc_loss_cpu)
-      ctc_loss_gpu = def_function.function(ctc_loss_gpu)
+      ctc_loss = def_function.function(_ctc_loss_v3)
+    else:
+      ctc_loss = _ctc_loss_v3
 
-    ref_loss, ref_grad = ctc_loss_cpu(labels, logits, label_length,
-                                      logit_length)
-    loss, grad = ctc_loss_gpu(labels, logits, label_length, logit_length)
+    ref_loss, ref_grad = ctc_loss(labels, logits, label_length, logit_length,
+                                  False)
+    loss, grad = ctc_loss(labels, logits, label_length, logit_length, True)
 
     self.assertAllClose(loss, ref_loss, atol=1e-6)
     self.assertAllClose(grad, ref_grad, atol=2e-6)
 
+  @test_util.run_v2_only
+  def testCtcLossAlgorithmFallback(self):
+    """Test if GPU CTC loss can fallback to the correct algorithm."""
+    if not test.is_gpu_available():
+      self.skipTest("Need GPU for testing.")
+    if not context.executing_eagerly():
+      self.skipTest("Need eager execution for testing.")
+    random_seed.set_random_seed(5)
+
+    batch_size = 1
+    num_labels = 11777
+    max_label_length = 2
+    num_frames = 1
+
+    labels = random_ops.random_uniform([batch_size, max_label_length],
+                                       minval=1,
+                                       maxval=num_labels,
+                                       dtype=dtypes.int64)
+    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+
+    label_length = random_ops.random_uniform([batch_size],
+                                             minval=1,
+                                             maxval=max_label_length,
+                                             dtype=dtypes.int64)
+    logit_length = [num_frames] * batch_size
+
+    loss, grad = _ctc_loss_v3(labels, logits, label_length, logit_length, True)
+    ref_loss, ref_grad = _ctc_loss_v3(labels, logits, label_length,
+                                      logit_length, False)
+
+    self.assertAllClose(loss, ref_loss, atol=1e-6)
+    self.assertAllClose(grad, ref_grad, atol=2e-6)
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 5fec1b5990e..6122877f91f 100755
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -2045,7 +2045,7 @@ port::Status CudnnSupport::DoCtcLossImpl(
     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
     const CudnnRnnStateTensorDescriptor& grads_desc,
     DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
-    DeviceMemory<uint8> scratch_memory) {
+    DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
   auto cudnn = cudnn_->GetHandle(parent_, stream);
 
   int kNumTimestamps = probs_desc.num_layers();
@@ -2055,6 +2055,8 @@ port::Status CudnnSupport::DoCtcLossImpl(
   (void)total_size;
 
 #if CUDNN_VERSION >= 7603
+  cudnnCTCLossAlgo_t ctc_loss_algo =
+      static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
   RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
       /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
       /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
@@ -2062,9 +2064,7 @@ port::Status CudnnSupport::DoCtcLossImpl(
       /*inputLengths=*/input_lengths_data.data(),
       /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
       /*gradients=*/grads_data.opaque(),
-      /*algo=*/
-      RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
-                                : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
+      /*algo=*/ctc_loss_algo,
       /*ctcLossDesc=*/ctc_loss_desc.handle(),
       /*workspace=*/scratch_memory.opaque(),
       /*workSpaceSizeInBytes=*/scratch_memory.size()));
@@ -3935,7 +3935,8 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
     absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
     absl::Span<const int> input_lengths_data,
-    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
+    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+    int* ctc_loss_algo_id) {
   auto cudnn = cudnn_->GetHandle(parent_, stream);
   // Query the workspace size.
   size_t workspace_size_in_bytes = 0;
@@ -3945,17 +3946,38 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
           static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
-  RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
+
+  // Try running with `algo`, if successful then pick it. The non-deterministic
+  // algorithm is first and thus preferentially picked when determinism is not
+  // required.
+  auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
+                                        : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
+  cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
       /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
       /*gradientsDesc=*/cudnn_grads_desc.handle(),
       /*labels=*/labels_data.data(),
       /*labelLengths=*/labels_lengths_data.data(),
       /*inputLengths=*/input_lengths_data.data(),
-      /*algo=*/
-      RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
-                                : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
+      /*algo=*/algo,
       /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
-      /*sizeInBytes=*/&workspace_size_in_bytes));
+      /*sizeInBytes=*/&workspace_size_in_bytes);
+  if (RequireCudnnDeterminism()) {
+    RETURN_IF_CUDNN_ERROR(status);
+  }
+
+  if (status != CUDNN_STATUS_SUCCESS) {
+    algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
+    RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
+        /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
+        /*gradientsDesc=*/cudnn_grads_desc.handle(),
+        /*labels=*/labels_data.data(),
+        /*labelLengths=*/labels_lengths_data.data(),
+        /*inputLengths=*/input_lengths_data.data(),
+        /*algo=*/algo,
+        /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
+        /*sizeInBytes=*/&workspace_size_in_bytes));
+  }
+  *ctc_loss_algo_id = algo;
 #else
   return port::Status(port::error::INVALID_ARGUMENT,
                       "No supported cudnnGetCTCLossWorkspaceSize when "
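Note: the `DoPrepareForCtcLoss` hunk above is a query-then-fall-back policy: the preferred algorithm (non-deterministic, whenever `RequireCudnnDeterminism()` is false) is tried first, and only a failed workspace query triggers a retry with `CUDNN_CTC_LOSS_ALGO_DETERMINISTIC`; when determinism is required, the failure is surfaced immediately because there is nothing to fall back to. Below is a minimal standalone sketch of that policy, not the real cuDNN API: `PickCtcLossAlgo`, `QueryWorkspace`, and the label-count rejection rule are hypothetical stand-ins for `cudnnGetCTCLossWorkspaceSize` and whatever shape constraint actually makes cuDNN reject the non-deterministic algorithm (compare the 11777-label case in `testCtcLossAlgorithmFallback`).

```cpp
#include <cstddef>
#include <iostream>

// Mirrors cudnnCTCLossAlgo_t; the values are illustrative.
enum CtcLossAlgo { kNonDeterministic = 0, kDeterministic = 1 };

// Hypothetical workspace query: pretend the non-deterministic algorithm
// cannot handle very large label alphabets.
bool QueryWorkspace(CtcLossAlgo algo, int num_labels, std::size_t* bytes) {
  if (algo == kNonDeterministic && num_labels > 10000) return false;
  *bytes = 1024;  // placeholder size
  return true;
}

// The selection policy from the hunk above: prefer the non-deterministic
// algorithm, and fall back to the deterministic one only when determinism
// was not explicitly required.
bool PickCtcLossAlgo(bool require_determinism, int num_labels,
                     CtcLossAlgo* algo, std::size_t* workspace_bytes) {
  *algo = require_determinism ? kDeterministic : kNonDeterministic;
  if (QueryWorkspace(*algo, num_labels, workspace_bytes)) return true;
  if (require_determinism) return false;  // no fallback permitted
  *algo = kDeterministic;
  return QueryWorkspace(*algo, num_labels, workspace_bytes);
}

int main() {
  CtcLossAlgo algo;
  std::size_t bytes;
  // With 11777 labels the preferred algorithm fails the query, so the
  // deterministic algorithm is picked instead.
  if (PickCtcLossAlgo(/*require_determinism=*/false, /*num_labels=*/11777,
                      &algo, &bytes)) {
    std::cout << "algo=" << algo << " workspace=" << bytes << " bytes\n";
  }
  return 0;
}
```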
@@ -3979,13 +4001,12 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
 port::Status CudnnSupport::DoCtcLoss(
     Stream* stream, dnn::DataType element_type,
     const dnn::RnnStateTensorDescriptor& probs_desc,
-    const DeviceMemoryBase probs_data,
-
-    absl::Span<const int> labels_data,
+    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
     const dnn::RnnStateTensorDescriptor& grads_desc,
-    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
+    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
+    int ctc_loss_algo_id) {
   // Current cuDNN CTC Loss only supports the float datatype
   if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
     return port::Status(port::error::INVALID_ARGUMENT,
@@ -4000,7 +4021,7 @@ port::Status CudnnSupport::DoCtcLoss(
   return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
                        labels_lengths_data, input_lengths_data, costs_data,
                        cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
-                       scratch_memory);
+                       scratch_memory, ctc_loss_algo_id);
 }
 
 bool CudnnSupport::DoTransformTensor(Stream* stream,
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 89c61789b47..181502e03ee 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -570,7 +570,8 @@ class CudnnSupport : public dnn::DnnSupport {
                          DeviceMemoryBase costs_data,
                          const dnn::RnnStateTensorDescriptor& grads_desc,
                          DeviceMemoryBase grads_data,
-                         DeviceMemory<uint8> scratch_memory) override;
+                         DeviceMemory<uint8> scratch_memory,
+                         int ctc_loss_algo_id) override;
 
   bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                          dnn::DataType input_type,
@@ -689,7 +690,7 @@ class CudnnSupport : public dnn::DnnSupport {
       absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
       const CudnnRnnStateTensorDescriptor& grads_desc,
       DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
-      DeviceMemory<uint8> scratch_memory);
+      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
 
  private:
   port::Status DoPrepareForConvolution(
@@ -711,8 +712,8 @@ class CudnnSupport : public dnn::DnnSupport {
       absl::Span<const int> labels_data,
       absl::Span<const int> labels_lengths_data,
       absl::Span<const int> input_lengths_data,
-      ScratchAllocator* scratch_allocator,
-      DeviceMemory<uint8>* scratch_memory) override;
+      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+      int* ctc_loss_algo_id) override;
 
   SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
 };
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 85b9a02d742..6aba656fc68 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -618,16 +618,14 @@ bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
   return false;
 }
 
-port::Status DnnSupport::DoCtcLoss(Stream* stream, dnn::DataType element_type,
-                                   const RnnStateTensorDescriptor& probs_desc,
-                                   const DeviceMemoryBase probs_data,
-                                   absl::Span<const int> labels_data,
-                                   absl::Span<const int> labels_lengths_data,
-                                   absl::Span<const int> input_lengths_data,
-                                   DeviceMemoryBase costs_data,
-                                   const RnnStateTensorDescriptor& grads_desc,
-                                   DeviceMemoryBase grads_data,
-                                   DeviceMemory<uint8> scratch_memory) {
+port::Status DnnSupport::DoCtcLoss(
+    Stream* stream, dnn::DataType element_type,
+    const RnnStateTensorDescriptor& probs_desc,
+    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
+    absl::Span<const int> labels_lengths_data,
+    absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+    const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
+    DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
   return port::UnimplementedError("CtcLoss not implemented");
 }
 
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 8e7f8790ab9..7b45ec2cc87 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -2396,11 +2396,12 @@ class DnnSupport {
                                 absl::Span<const int> labels_lengths_data,
                                 absl::Span<const int> input_lengths_data,
                                 ScratchAllocator* workspace_allocator,
-                                DeviceMemory<uint8>* scratch_memory) {
-    return DoPrepareForCtcLoss(stream, ToDataType<ElementType>::value,
-                               probs_desc, grads_desc, labels_data,
-                               labels_lengths_data, input_lengths_data,
-                               workspace_allocator, scratch_memory);
+                                DeviceMemory<uint8>* scratch_memory,
+                                int* ctc_loss_algo_id) {
+    return DoPrepareForCtcLoss(
+        stream, ToDataType<ElementType>::value, probs_desc, grads_desc,
+        labels_data, labels_lengths_data, input_lengths_data,
+        workspace_allocator, scratch_memory, ctc_loss_algo_id);
   }
 
   // Enqueue a CTC Loss operation onto the stream.
@@ -2424,16 +2425,14 @@ class DnnSupport {
   // workspace memory used by this operation. The caller is responsible for
   // keeping the memory alive long enough for this operation, and recylces
   // afterwards.
-  virtual port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
-                                 const RnnStateTensorDescriptor& probs_desc,
-                                 const DeviceMemoryBase probs_data,
-                                 absl::Span<const int> labels_data,
-                                 absl::Span<const int> labels_lengths_data,
-                                 absl::Span<const int> input_lengths_data,
-                                 DeviceMemoryBase costs_data,
-                                 const RnnStateTensorDescriptor& grads_desc,
-                                 DeviceMemoryBase grads_data,
-                                 DeviceMemory<uint8> scratch_memory);
+  virtual port::Status DoCtcLoss(
+      Stream* stream, dnn::DataType element_type,
+      const RnnStateTensorDescriptor& probs_desc,
+      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+      const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
+      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
 
   template <typename ElementType>
   bool DoCtcLoss(Stream* stream,
@@ -2445,12 +2444,12 @@ class DnnSupport {
                  DeviceMemory<ElementType>* costs_data,
                  const dnn::RnnStateTensorDescriptor& grads_desc,
                  DeviceMemory<ElementType>* grads_data,
-                 DeviceMemory<uint8>* scratch_memory) {
+                 DeviceMemory<uint8>* scratch_memory, int ctc_loss_algo_id) {
     return IsStatusOk(
         DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
                   probs_data, labels_data, labels_lengths_data,
                   input_lengths_data, *costs_data, grads_desc, *grads_data,
-                  *scratch_memory),
+                  *scratch_memory, ctc_loss_algo_id),
         false);
   }
 
@@ -2716,8 +2715,8 @@ class DnnSupport {
       absl::Span<const int> labels_data,
       absl::Span<const int> labels_lengths_data,
       absl::Span<const int> input_lengths_data,
-      ScratchAllocator* scratch_allocator,
-      DeviceMemory<uint8>* scratch_memory) {
+      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+      int* ctc_loss_algo_id) {
     *scratch_memory = {};
     return port::Status::OK();
   }
diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc
index 7f138a4048b..e0ead6d57e8 100644
--- a/tensorflow/stream_executor/rocm/rocm_dnn.cc
+++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc
@@ -2393,7 +2393,8 @@ port::Status MIOpenSupport::DoPrepareForCtcLoss(
     absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
     absl::Span<const int> input_lengths_data,
-    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
+    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+    int* ctc_loss_algo_id) {
   auto miopen = miopen_->GetHandle(parent_, stream);
 
   MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
@@ -2456,7 +2457,7 @@ port::Status MIOpenSupport::DoCtcLossImpl(
     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
     const MIOpenRnnStateTensorDescriptor& grads_desc,
     DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
-    DeviceMemory<uint8> scratch_memory) {
+    DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
   auto miopen = miopen_->GetHandle(parent_, stream);
 
   int kNumTimestamps = probs_desc.num_layers();
@@ -2482,13 +2483,12 @@ port::Status MIOpenSupport::DoCtcLossImpl(
 port::Status MIOpenSupport::DoCtcLoss(
     Stream* stream, dnn::DataType element_type,
     const dnn::RnnStateTensorDescriptor& probs_desc,
-    const DeviceMemoryBase probs_data,
-
-    absl::Span<const int> labels_data,
+    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
     const dnn::RnnStateTensorDescriptor& grads_desc,
-    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
+    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
+    int ctc_loss_algo_id) {
   // Current MIOPen CTC Loss only supports the float datatype
   if (element_type != dnn::DataType::kFloat) {
     return port::Status(port::error::INVALID_ARGUMENT,
@@ -2507,7 +2507,7 @@ port::Status MIOpenSupport::DoCtcLoss(
   return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
                        labels_lengths_data, input_lengths_data, costs_data,
                        miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
-                       scratch_memory);
+                       scratch_memory, ctc_loss_algo_id);
 }
 
 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.h b/tensorflow/stream_executor/rocm/rocm_dnn.h
index 76b6606d8a5..40e156b5f74 100644
--- a/tensorflow/stream_executor/rocm/rocm_dnn.h
+++ b/tensorflow/stream_executor/rocm/rocm_dnn.h
@@ -648,7 +648,8 @@ class MIOpenSupport : public dnn::DnnSupport {
                          DeviceMemoryBase costs_data,
                          const dnn::RnnStateTensorDescriptor& grads_desc,
                          DeviceMemoryBase grads_data,
-                         DeviceMemory<uint8> scratch_memory) override;
+                         DeviceMemory<uint8> scratch_memory,
+                         int ctc_loss_algo_id) override;
 
  private:
   GpuExecutor* parent_;  // Parent executor object. Not owned.
@@ -812,7 +813,7 @@ class MIOpenSupport : public dnn::DnnSupport {
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const MIOpenRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
-      DeviceMemory<uint8> scratch_memory);
+      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
 
   port::Status DoPrepareForCtcLoss(
       Stream* stream, dnn::DataType element_type,
@@ -821,8 +822,8 @@ class MIOpenSupport : public dnn::DnnSupport {
       absl::Span<const int> labels_data,
       absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
-      ScratchAllocator* scratch_allocator,
-      DeviceMemory<uint8>* scratch_memory) override;
+      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+      int* ctc_loss_algo_id) override;
 
   bool GetMIOpenConvolveAlgorithmsImmediateMode(
       dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 8b50eab838c..c63565c65a8 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -5251,16 +5251,18 @@ Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
   if (ok()) {
     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
       DeviceMemory<uint8> scratch_memory;
-      auto status = dnn->PrepareForCtcLoss(
-                           this, probs_desc, probs_data, grads_desc,
-                           labels_data, labels_lengths_data, input_lengths_data,
-                           workspace_allocator, &scratch_memory)
-                        .ok();
+      int ctc_loss_algo_id;
+      auto status =
+          dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
+                                 labels_data, labels_lengths_data,
+                                 input_lengths_data, workspace_allocator,
+                                 &scratch_memory, &ctc_loss_algo_id)
+              .ok();
       if (status) {
-        status =
-            dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
-                           labels_lengths_data, input_lengths_data, costs_data,
-                           grads_desc, grads_data, &scratch_memory);
+        status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
+                                labels_lengths_data, input_lengths_data,
+                                costs_data, grads_desc, grads_data,
+                                &scratch_memory, ctc_loss_algo_id);
       }
       if (!status) {
         SetError();
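Note: the `stream.cc` hunk shows the invariant the new `ctc_loss_algo_id` parameter enforces: whatever algorithm `PrepareForCtcLoss` sizes the workspace for (possibly after falling back) is exactly the algorithm `DoCtcLoss` executes, instead of both sides recomputing the choice from `RequireCudnnDeterminism()`, which would disagree once the prepare step can fall back. The id stays a plain `int` at the backend-neutral `dnn.h` layer and is only cast back to `cudnnCTCLossAlgo_t` inside `DoCtcLossImpl`; the MIOpen path merely accepts it (unused in this diff). A minimal sketch of that two-phase contract, with hypothetical `Prepare`/`Run` stand-ins and the actual allocation elided:

```cpp
#include <cstddef>
#include <iostream>

struct Workspace {
  std::size_t bytes = 0;
};

// Phase 1: size/allocate scratch space and report which algorithm it was
// sized for (after any fallback).
bool Prepare(Workspace* scratch, int* algo_id) {
  scratch->bytes = 1024;  // real code would go through a ScratchAllocator
  *algo_id = 1;           // e.g. the deterministic algorithm, after fallback
  return true;
}

// Phase 2: runs exactly the algorithm the scratch space was sized for;
// recomputing the choice here could disagree with phase 1.
bool Run(const Workspace& scratch, int algo_id) {
  std::cout << "running algo " << algo_id << " with " << scratch.bytes
            << " bytes of scratch\n";
  return true;
}

int main() {
  Workspace scratch;
  int algo_id = -1;
  // Mirrors Stream::ThenCtcLoss: the id produced by the prepare step is
  // passed through unchanged to the execution step.
  if (Prepare(&scratch, &algo_id)) Run(scratch, algo_id);
  return 0;
}
```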