PR #37679: Support two CUDNN CTC Loss algorithms
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/37679

This PR enables cuDNN CTC Loss to support both the deterministic and the non-deterministic algorithm. If determinism is required, we stick to the deterministic algorithm. Otherwise, the faster non-deterministic algorithm is tried first in cudnnGetCTCLossWorkspaceSize(); if that query fails, we fall back to the deterministic algorithm. fyi @nluehr @sanjoy

Copybara import of the project:

-- e338f0bad8a10c6b0e6284e609dfe2b6ebeab9f0 by Kaixi Hou <kaixih@nvidia.com>: CTC loss support two algos
-- c90ebbd533202c791e342b5716c3410c4557ea1a by Kaixi Hou <kaixih@nvidia.com>: Reuse macros of cudnn launch
-- c9fd60fa4abc37674923acb802e46a7ac8055796 by Kaixi Hou <kaixih@nvidia.com>: Add unittest for ctc_loss fallback algo
-- e29100d0202e2a519845b818d6f9603505d2c83e by Kaixi Hou <kaixih@nvidia.com>: Shorten the comment sentence

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/37679 from houtoms:ctc_loss_workspace_check e29100d0202e2a519845b818d6f9603505d2c83e
PiperOrigin-RevId: 305942210
Change-Id: I57062bcb5f04097a3280eb6eeb5de51bae6ef3ca
This commit is contained in:
  parent b43ff5b8dd
  commit d779e8431a
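The gist of the change, before the diffs below: the workspace-size query now doubles as a capability probe for the preferred algorithm. The following is a minimal standalone sketch of that selection logic under stated assumptions, not the StreamExecutor code itself; the helper name, the `require_determinism` flag, and the raw cuDNN descriptor parameters are placeholders standing in for TensorFlow's wrappers and its RequireCudnnDeterminism() check (cuDNN >= 7.6.3 assumed).

    #include <cudnn.h>

    // Sketch: pick a CTC loss algorithm by probing cudnnGetCTCLossWorkspaceSize()
    // with the preferred algorithm first and falling back to the deterministic
    // one if that probe fails.
    cudnnStatus_t PickCtcLossAlgoAndWorkspaceSize(
        cudnnHandle_t handle, cudnnTensorDescriptor_t probs_desc,
        cudnnTensorDescriptor_t grads_desc, const int* labels,
        const int* label_lengths, const int* input_lengths,
        cudnnCTCLossDescriptor_t ctc_loss_desc, bool require_determinism,
        cudnnCTCLossAlgo_t* algo, size_t* workspace_bytes) {
      // Prefer the faster non-deterministic algorithm unless determinism is
      // required.
      *algo = require_determinism ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
                                  : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
      cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
          handle, probs_desc, grads_desc, labels, label_lengths, input_lengths,
          *algo, ctc_loss_desc, workspace_bytes);
      if (status != CUDNN_STATUS_SUCCESS && !require_determinism) {
        // The non-deterministic algorithm cannot handle this problem; fall back
        // to the deterministic algorithm and query the workspace size again.
        *algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
        status = cudnnGetCTCLossWorkspaceSize(
            handle, probs_desc, grads_desc, labels, label_lengths, input_lengths,
            *algo, ctc_loss_desc, workspace_bytes);
      }
      // The caller records *algo and later passes the same value to
      // cudnnCTCLoss(), so the loss kernel runs with the algorithm that sized
      // the workspace.
      return status;
    }

In the actual change, DoPrepareForCtcLoss performs this probe and returns the chosen algorithm through the new ctc_loss_algo_id out-parameter, which Stream::ThenCtcLoss then threads into DoCtcLoss, as the diffs below show.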
@@ -941,6 +941,21 @@ class CTCLossTestV2(test.TestCase):
         [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
 
 
+def _ctc_loss_v3(labels, logits, label_length, logit_length, use_gpu):
+  with test_util.device(use_gpu=use_gpu):
+    sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
+    with backprop.GradientTape() as t:
+      t.watch(logits)
+      ref_loss = ctc_ops.ctc_loss_v3(
+          labels=sparse_labels,
+          logits=logits,
+          label_length=label_length,
+          logit_length=logit_length,
+          blank_index=0)
+    ref_grad = t.gradient(ref_loss, [logits])
+    return ref_loss, ref_grad
+
+
 @test_util.run_all_in_graph_and_eager_modes
 class CTCLossTestV3(test.TestCase, parameterized.TestCase):
 
@@ -978,42 +993,47 @@ class CTCLossTestV3(test.TestCase, parameterized.TestCase):
     labels *= label_mask
     logit_length = [num_frames] * batch_size
 
-    def ctc_loss_cpu(labels, logits, label_length, logit_length):
-      with test_util.device(use_gpu=False):
-        sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
-        with backprop.GradientTape() as t:
-          t.watch(logits)
-          ref_loss = ctc_ops.ctc_loss_v3(
-              labels=sparse_labels,
-              logits=logits,
-              label_length=label_length,
-              logit_length=logit_length,
-              blank_index=0)
-        ref_grad = t.gradient(ref_loss, [logits])
-        return ref_loss, ref_grad
-
-    def ctc_loss_gpu(labels, logits, label_length, logit_length):
-      with test_util.device(use_gpu=True):
-        sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
-        with backprop.GradientTape() as t:
-          t.watch(logits)
-          loss = ctc_ops.ctc_loss_v3(
-              labels=sparse_labels,
-              logits=logits,
-              label_length=label_length,
-              logit_length=logit_length,
-              blank_index=0)
-        grad = t.gradient(loss, [logits])
-
-        return loss, grad
-
     if run_tf_func:
-      ctc_loss_cpu = def_function.function(ctc_loss_cpu)
-      ctc_loss_gpu = def_function.function(ctc_loss_gpu)
+      ctc_loss = def_function.function(_ctc_loss_v3)
+    else:
+      ctc_loss = _ctc_loss_v3
 
-    ref_loss, ref_grad = ctc_loss_cpu(labels, logits, label_length,
-                                      logit_length)
-    loss, grad = ctc_loss_gpu(labels, logits, label_length, logit_length)
+    ref_loss, ref_grad = ctc_loss(labels, logits, label_length, logit_length,
+                                  False)
+    loss, grad = ctc_loss(labels, logits, label_length, logit_length, True)
 
     self.assertAllClose(loss, ref_loss, atol=1e-6)
     self.assertAllClose(grad, ref_grad, atol=2e-6)
+
+  @test_util.run_v2_only
+  def testCtcLossAlgorithmFallback(self):
+    """Test if GPU CTC loss can fallback to the correct algorithm."""
+    if not test.is_gpu_available():
+      self.skipTest("Need GPU for testing.")
+    if not context.executing_eagerly():
+      self.skipTest("Need eager execution for testing.")
+    random_seed.set_random_seed(5)
+
+    batch_size = 1
+    num_labels = 11777
+    max_label_length = 2
+    num_frames = 1
+
+    labels = random_ops.random_uniform([batch_size, max_label_length],
+                                       minval=1,
+                                       maxval=num_labels,
+                                       dtype=dtypes.int64)
+    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+
+    label_length = random_ops.random_uniform([batch_size],
+                                             minval=1,
+                                             maxval=max_label_length,
+                                             dtype=dtypes.int64)
+    logit_length = [num_frames] * batch_size
+
+    loss, grad = _ctc_loss_v3(labels, logits, label_length, logit_length, True)
+    ref_loss, ref_grad = _ctc_loss_v3(labels, logits, label_length,
+                                      logit_length, False)
+
+    self.assertAllClose(loss, ref_loss, atol=1e-6)
+    self.assertAllClose(grad, ref_grad, atol=2e-6)

@@ -2045,7 +2045,7 @@ port::Status CudnnSupport::DoCtcLossImpl(
     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
     const CudnnRnnStateTensorDescriptor& grads_desc,
     DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
-    DeviceMemory<uint8> scratch_memory) {
+    DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
   auto cudnn = cudnn_->GetHandle(parent_, stream);
 
   int kNumTimestamps = probs_desc.num_layers();
@@ -2055,6 +2055,8 @@ port::Status CudnnSupport::DoCtcLossImpl(
   (void)total_size;
 
 #if CUDNN_VERSION >= 7603
+  cudnnCTCLossAlgo_t ctc_loss_algo =
+      static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
   RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
       /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
       /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
@@ -2062,9 +2064,7 @@ port::Status CudnnSupport::DoCtcLossImpl(
       /*inputLengths=*/input_lengths_data.data(),
       /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
       /*gradients=*/grads_data.opaque(),
-      /*algo=*/
-      RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
-                                : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
+      /*algo=*/ctc_loss_algo,
       /*ctcLossDesc=*/ctc_loss_desc.handle(),
       /*workspace=*/scratch_memory.opaque(),
       /*workSpaceSizeInBytes=*/scratch_memory.size()));
@@ -3935,7 +3935,8 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
     absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
     absl::Span<const int> input_lengths_data,
-    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
+    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+    int* ctc_loss_algo_id) {
   auto cudnn = cudnn_->GetHandle(parent_, stream);
   // Query the workspace size.
   size_t workspace_size_in_bytes = 0;
@@ -3945,17 +3946,38 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
+
+  // Try running with `algo`, if successful then pick it. The non-deterministic
+  // algorithm is first and thus preferentially picked when determinism is not
+  // required.
+  auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
+                                        : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
+  cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
+      /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
+      /*gradientsDesc=*/cudnn_grads_desc.handle(),
+      /*labels=*/labels_data.data(),
+      /*labelLengths=*/labels_lengths_data.data(),
+      /*inputLengths=*/input_lengths_data.data(),
+      /*algo=*/algo,
+      /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
+      /*sizeInBytes=*/&workspace_size_in_bytes);
+  if (RequireCudnnDeterminism()) {
+    RETURN_IF_CUDNN_ERROR(status);
+  }
+
+  if (status != CUDNN_STATUS_SUCCESS) {
+    algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
   RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
       /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
       /*gradientsDesc=*/cudnn_grads_desc.handle(),
       /*labels=*/labels_data.data(),
       /*labelLengths=*/labels_lengths_data.data(),
      /*inputLengths=*/input_lengths_data.data(),
-      /*algo=*/
-      RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
-                                : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
+      /*algo=*/algo,
       /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
       /*sizeInBytes=*/&workspace_size_in_bytes));
+  }
+  *ctc_loss_algo_id = algo;
 #else
   return port::Status(port::error::INVALID_ARGUMENT,
                       "No supported cudnnGetCTCLossWorkspaceSize when "
@@ -3979,13 +4001,12 @@ port::Status CudnnSupport::DoPrepareForCtcLoss(
 port::Status CudnnSupport::DoCtcLoss(
     Stream* stream, dnn::DataType element_type,
     const dnn::RnnStateTensorDescriptor& probs_desc,
-    const DeviceMemoryBase probs_data,
-
-    absl::Span<const int> labels_data,
+    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
     const dnn::RnnStateTensorDescriptor& grads_desc,
-    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
+    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
+    int ctc_loss_algo_id) {
   // Current cuDNN CTC Loss only supports the float datatype
   if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
     return port::Status(port::error::INVALID_ARGUMENT,
@@ -4000,7 +4021,7 @@ port::Status CudnnSupport::DoCtcLoss(
   return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
                        labels_lengths_data, input_lengths_data, costs_data,
                        cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
-                       scratch_memory);
+                       scratch_memory, ctc_loss_algo_id);
 }
 
 bool CudnnSupport::DoTransformTensor(Stream* stream,

@@ -570,7 +570,8 @@ class CudnnSupport : public dnn::DnnSupport {
                          DeviceMemoryBase costs_data,
                          const dnn::RnnStateTensorDescriptor& grads_desc,
                          DeviceMemoryBase grads_data,
-                         DeviceMemory<uint8> scratch_memory) override;
+                         DeviceMemory<uint8> scratch_memory,
+                         int ctc_loss_algo_id) override;
 
   bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                          dnn::DataType input_type,
@@ -689,7 +690,7 @@ class CudnnSupport : public dnn::DnnSupport {
       absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
       const CudnnRnnStateTensorDescriptor& grads_desc,
       DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
-      DeviceMemory<uint8> scratch_memory);
+      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
 
  private:
  port::Status DoPrepareForConvolution(
@@ -711,8 +712,8 @@ class CudnnSupport : public dnn::DnnSupport {
       absl::Span<const int> labels_data,
       absl::Span<const int> labels_lengths_data,
       absl::Span<const int> input_lengths_data,
-      ScratchAllocator* scratch_allocator,
-      DeviceMemory<uint8>* scratch_memory) override;
+      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+      int* ctc_loss_algo_id) override;
 
   SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
 };

@@ -618,16 +618,14 @@ bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
   return false;
 }
 
-port::Status DnnSupport::DoCtcLoss(Stream* stream, dnn::DataType element_type,
+port::Status DnnSupport::DoCtcLoss(
+    Stream* stream, dnn::DataType element_type,
     const RnnStateTensorDescriptor& probs_desc,
-    const DeviceMemoryBase probs_data,
-    absl::Span<const int> labels_data,
+    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
-    absl::Span<const int> input_lengths_data,
-    DeviceMemoryBase costs_data,
-    const RnnStateTensorDescriptor& grads_desc,
-    DeviceMemoryBase grads_data,
-    DeviceMemory<uint8> scratch_memory) {
+    absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+    const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
+    DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
   return port::UnimplementedError("CtcLoss not implemented");
 }
 

@@ -2396,11 +2396,12 @@ class DnnSupport {
                          absl::Span<const int> labels_lengths_data,
                          absl::Span<const int> input_lengths_data,
                          ScratchAllocator* workspace_allocator,
-                         DeviceMemory<uint8>* scratch_memory) {
-    return DoPrepareForCtcLoss(stream, ToDataType<ElementType>::value,
-                               probs_desc, grads_desc, labels_data,
-                               labels_lengths_data, input_lengths_data,
-                               workspace_allocator, scratch_memory);
+                         DeviceMemory<uint8>* scratch_memory,
+                         int* ctc_loss_algo_id) {
+    return DoPrepareForCtcLoss(
+        stream, ToDataType<ElementType>::value, probs_desc, grads_desc,
+        labels_data, labels_lengths_data, input_lengths_data,
+        workspace_allocator, scratch_memory, ctc_loss_algo_id);
   }
 
   // Enqueue a CTC Loss operation onto the stream.
@@ -2424,16 +2425,14 @@ class DnnSupport {
   // workspace memory used by this operation. The caller is responsible for
   // keeping the memory alive long enough for this operation, and recylces
   // afterwards.
-  virtual port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
+  virtual port::Status DoCtcLoss(
+      Stream* stream, dnn::DataType element_type,
       const RnnStateTensorDescriptor& probs_desc,
-      const DeviceMemoryBase probs_data,
-      absl::Span<const int> labels_data,
+      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
       absl::Span<const int> labels_lengths_data,
-      absl::Span<const int> input_lengths_data,
-      DeviceMemoryBase costs_data,
-      const RnnStateTensorDescriptor& grads_desc,
-      DeviceMemoryBase grads_data,
-      DeviceMemory<uint8> scratch_memory);
+      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+      const RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data,
+      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
 
   template <typename ElementType>
   bool DoCtcLoss(Stream* stream,
@@ -2445,12 +2444,12 @@ class DnnSupport {
                  DeviceMemory<ElementType>* costs_data,
                  const dnn::RnnStateTensorDescriptor& grads_desc,
                  DeviceMemory<ElementType>* grads_data,
-                 DeviceMemory<uint8>* scratch_memory) {
+                 DeviceMemory<uint8>* scratch_memory, int ctc_loss_algo_id) {
     return IsStatusOk(
         DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
                   probs_data, labels_data, labels_lengths_data,
                   input_lengths_data, *costs_data, grads_desc, *grads_data,
-                  *scratch_memory),
+                  *scratch_memory, ctc_loss_algo_id),
         false);
   }
 
@@ -2716,8 +2715,8 @@ class DnnSupport {
       absl::Span<const int> labels_data,
       absl::Span<const int> labels_lengths_data,
       absl::Span<const int> input_lengths_data,
-      ScratchAllocator* scratch_allocator,
-      DeviceMemory<uint8>* scratch_memory) {
+      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+      int* ctc_loss_algo_id) {
     *scratch_memory = {};
     return port::Status::OK();
   }

@@ -2393,7 +2393,8 @@ port::Status MIOpenSupport::DoPrepareForCtcLoss(
     absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
     absl::Span<const int> input_lengths_data,
-    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
+    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+    int* ctc_loss_algo_id) {
   auto miopen = miopen_->GetHandle(parent_, stream);
 
   MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
@@ -2456,7 +2457,7 @@ port::Status MIOpenSupport::DoCtcLossImpl(
     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
     const MIOpenRnnStateTensorDescriptor& grads_desc,
     DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
-    DeviceMemory<uint8> scratch_memory) {
+    DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
   auto miopen = miopen_->GetHandle(parent_, stream);
 
   int kNumTimestamps = probs_desc.num_layers();
@@ -2482,13 +2483,12 @@ port::Status MIOpenSupport::DoCtcLossImpl(
 port::Status MIOpenSupport::DoCtcLoss(
     Stream* stream, dnn::DataType element_type,
     const dnn::RnnStateTensorDescriptor& probs_desc,
-    const DeviceMemoryBase probs_data,
-
-    absl::Span<const int> labels_data,
+    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
     absl::Span<const int> labels_lengths_data,
     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
     const dnn::RnnStateTensorDescriptor& grads_desc,
-    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
+    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
+    int ctc_loss_algo_id) {
   // Current MIOPen CTC Loss only supports the float datatype
   if (element_type != dnn::DataType::kFloat) {
     return port::Status(port::error::INVALID_ARGUMENT,
@@ -2507,7 +2507,7 @@ port::Status MIOpenSupport::DoCtcLoss(
   return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
                        labels_lengths_data, input_lengths_data, costs_data,
                        miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
-                       scratch_memory);
+                       scratch_memory, ctc_loss_algo_id);
 }
 
 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>

@@ -648,7 +648,8 @@ class MIOpenSupport : public dnn::DnnSupport {
                          DeviceMemoryBase costs_data,
                          const dnn::RnnStateTensorDescriptor& grads_desc,
                          DeviceMemoryBase grads_data,
-                         DeviceMemory<uint8> scratch_memory) override;
+                         DeviceMemory<uint8> scratch_memory,
+                         int ctc_loss_algo_id) override;
 
  private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.
@@ -812,7 +813,7 @@ class MIOpenSupport : public dnn::DnnSupport {
       absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
       const MIOpenRnnStateTensorDescriptor& grads_desc,
       DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
-      DeviceMemory<uint8> scratch_memory);
+      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);
 
   port::Status DoPrepareForCtcLoss(
       Stream* stream, dnn::DataType element_type,
@@ -821,8 +822,8 @@ class MIOpenSupport : public dnn::DnnSupport {
       absl::Span<const int> labels_data,
       absl::Span<const int> labels_lengths_data,
       absl::Span<const int> input_lengths_data,
-      ScratchAllocator* scratch_allocator,
-      DeviceMemory<uint8>* scratch_memory) override;
+      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
+      int* ctc_loss_algo_id) override;
 
   bool GetMIOpenConvolveAlgorithmsImmediateMode(
       dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,

@@ -5251,16 +5251,18 @@ Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
   if (ok()) {
     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
       DeviceMemory<uint8> scratch_memory;
-      auto status = dnn->PrepareForCtcLoss(
-                           this, probs_desc, probs_data, grads_desc,
-                           labels_data, labels_lengths_data, input_lengths_data,
-                           workspace_allocator, &scratch_memory)
+      int ctc_loss_algo_id;
+      auto status =
+          dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
+                                 labels_data, labels_lengths_data,
+                                 input_lengths_data, workspace_allocator,
+                                 &scratch_memory, &ctc_loss_algo_id)
               .ok();
       if (status) {
-        status =
-            dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
-                           labels_lengths_data, input_lengths_data, costs_data,
-                           grads_desc, grads_data, &scratch_memory);
+        status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
+                                labels_lengths_data, input_lengths_data,
+                                costs_data, grads_desc, grads_data,
+                                &scratch_memory, ctc_loss_algo_id);
       }
       if (!status) {
         SetError();