Merge pull request #32302 from houtoms:pr_cudnn_ctc_loss
PiperOrigin-RevId: 290387603 Change-Id: I28491f42a4559a9f79bd6a7b73d8e6b670f55368
This commit is contained in:
commit
bd4c38b3dc
72
tensorflow/core/api_def/base_api/api_def_CTCLossV2.pbtxt
Normal file
72
tensorflow/core/api_def/base_api/api_def_CTCLossV2.pbtxt
Normal file
@ -0,0 +1,72 @@
|
||||
op {
|
||||
graph_op_name: "CTCLossV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "inputs"
|
||||
description: <<END
|
||||
3-D, shape: `(max_time x batch_size x num_classes)`, the logits. Default blank
|
||||
label is 0 rather than num_classes - 1.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "labels_indices"
|
||||
description: <<END
|
||||
The indices of a `SparseTensor<int32, 2>`.
|
||||
`labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for
|
||||
`(batch b, time t)`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "labels_values"
|
||||
description: <<END
|
||||
The values (labels) associated with the given batch and time.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sequence_length"
|
||||
description: <<END
|
||||
A vector containing sequence lengths (batch).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "loss"
|
||||
description: <<END
|
||||
A vector (batch) containing log-probabilities.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "gradient"
|
||||
description: <<END
|
||||
The gradient of `loss`. 3-D, shape:
|
||||
`(max_time x batch_size x num_classes)`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "preprocess_collapse_repeated"
|
||||
description: <<END
|
||||
Scalar, if true then repeated labels are
|
||||
collapsed prior to the CTC calculation.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "ctc_merge_repeated"
|
||||
description: <<END
|
||||
Scalar. If set to false, *during* CTC calculation
|
||||
repeated non-blank labels will not be merged and are interpreted as
|
||||
individual labels. This is a simplified version of CTC.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "ignore_longer_outputs_than_inputs"
|
||||
description: <<END
|
||||
Scalar. If set to true, during CTC
|
||||
calculation, items that have longer output sequences than input sequences
|
||||
are skipped: they don't contribute to the loss term and have zero-gradient.
|
||||
END
|
||||
}
|
||||
summary: "Calculates the CTC Loss (log probability) for each batch entry. Also calculates"
|
||||
description: <<END
|
||||
the gradient. This class performs the softmax operation for you, so inputs
|
||||
should be e.g. linear projections of outputs by an LSTM.
|
||||
END
|
||||
}
|
@ -2358,7 +2358,11 @@ tf_kernel_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
|
||||
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
|
||||
],
|
||||
] + if_cuda([
|
||||
":gpu_utils",
|
||||
":conv_ops_gpu_hdrs",
|
||||
"@local_config_cuda//cuda:cudnn_header",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/gpu_utils.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -15,6 +15,10 @@ limitations under the License.
|
||||
|
||||
// See docs in ../ops/ctc_ops.cc.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -25,8 +29,39 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
|
||||
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cudnn/cudnn.h"
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
#if GOOGLE_CUDA
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
namespace {
|
||||
using se::Stream;
|
||||
using se::StreamExecutor;
|
||||
using se::dnn::RnnStateTensorDescriptor;
|
||||
using se::dnn::ToDataType;
|
||||
|
||||
template <typename T>
|
||||
void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
|
||||
int num_indices, int batch_size,
|
||||
std::vector<int>* labels_lengths) {
|
||||
const T* h_in = labels_indices->flat<T>().data();
|
||||
for (int i = 0; i < num_indices; i++) {
|
||||
const T& key = h_in[i * 2];
|
||||
(*labels_lengths)[key]++;
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
template <typename T>
|
||||
class CTCLossOp : public OpKernel {
|
||||
typedef Eigen::Map<
|
||||
@ -186,4 +221,150 @@ REGISTER_CPU(double);
|
||||
|
||||
#undef REGISTER_CPU
|
||||
|
||||
#if GOOGLE_CUDA && CUDNN_VERSION >= 7603
|
||||
class CTCLossOpGPU : public OpKernel {
|
||||
public:
|
||||
explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
bool preprocess_collapse_repeated;
|
||||
bool ctc_merge_repeated;
|
||||
bool ignore_longer_outputs_than_inputs;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
|
||||
&preprocess_collapse_repeated));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
|
||||
&ignore_longer_outputs_than_inputs));
|
||||
|
||||
OP_REQUIRES(ctx, !preprocess_collapse_repeated,
|
||||
errors::InvalidArgument("GPU CTCLossOp requires "
|
||||
"preprocess_collapse_repeated to be "
|
||||
"false"));
|
||||
OP_REQUIRES(ctx, ctc_merge_repeated,
|
||||
errors::InvalidArgument("GPU CTCLossOp requires "
|
||||
"ctc_merge_repeated to be "
|
||||
"true"));
|
||||
OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs,
|
||||
errors::InvalidArgument("GPU CTCLossOp requires "
|
||||
"ignore_longer_outputs_than_inputs to"
|
||||
"be false"));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor* inputs;
|
||||
const Tensor* labels_indices;
|
||||
const Tensor* labels_values;
|
||||
const Tensor* seq_len;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
|
||||
OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
|
||||
OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
|
||||
OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));
|
||||
|
||||
OP_REQUIRES(ctx, inputs->shape().dims() == 3,
|
||||
errors::InvalidArgument("inputs is not a 3-Tensor"));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
|
||||
errors::InvalidArgument("sequence_length is not a vector"));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
|
||||
errors::InvalidArgument("labels_indices is not a matrix"));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
|
||||
errors::InvalidArgument("labels_values is not a vector"));
|
||||
|
||||
const TensorShape& inputs_shape = inputs->shape();
|
||||
const int64 max_time_raw = inputs_shape.dim_size(0);
|
||||
const int64 batch_size_raw = inputs_shape.dim_size(1);
|
||||
const int64 num_classes_raw = inputs_shape.dim_size(2);
|
||||
OP_REQUIRES(ctx,
|
||||
FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("max_time_ cannot exceed max int"));
|
||||
OP_REQUIRES(
|
||||
ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("batch_size cannot exceed max int"));
|
||||
OP_REQUIRES(
|
||||
ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("num_classes cannot exceed max int"));
|
||||
const int max_time = static_cast<const int>(max_time_raw);
|
||||
const int batch_size = static_cast<const int>(batch_size_raw);
|
||||
const int num_classes = static_cast<const int>(num_classes_raw);
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, batch_size == seq_len->dim_size(0),
|
||||
errors::InvalidArgument("len(sequence_length) != batch_size. ",
|
||||
"len(sequence_length): ", seq_len->dim_size(0),
|
||||
" batch_size: ", batch_size));
|
||||
|
||||
OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"labels_indices and labels_values must contain the "
|
||||
"same number of rows, but saw shapes: ",
|
||||
labels_indices->shape().DebugString(), " vs. ",
|
||||
labels_values->shape().DebugString()));
|
||||
auto num_indices = labels_indices->dim_size(0);
|
||||
|
||||
OP_REQUIRES(ctx, batch_size != 0,
|
||||
errors::InvalidArgument("batch_size must not be 0"));
|
||||
|
||||
Tensor* loss = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
|
||||
|
||||
Tensor* gradient = nullptr;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output("gradient", inputs_shape, &gradient));
|
||||
|
||||
// Convert the labels_indices to labels_lengths.
|
||||
std::vector<int> labels_lengths(batch_size, 0);
|
||||
DoHistogram<int64>(ctx, labels_indices, num_indices, batch_size,
|
||||
&labels_lengths);
|
||||
|
||||
StreamExecutor* executor = ctx->op_device_context()->stream()->parent();
|
||||
se::dnn::DataType data_type = ToDataType<float>::value;
|
||||
|
||||
auto probs_desc_s = executor->createRnnStateTensorDescriptor(
|
||||
max_time, batch_size, num_classes, data_type);
|
||||
OP_REQUIRES_OK(ctx, probs_desc_s.status());
|
||||
std::unique_ptr<RnnStateTensorDescriptor> probs_desc =
|
||||
probs_desc_s.ConsumeValueOrDie();
|
||||
|
||||
auto grads_desc_s = executor->createRnnStateTensorDescriptor(
|
||||
max_time, batch_size, num_classes, data_type);
|
||||
OP_REQUIRES_OK(ctx, grads_desc_s.status());
|
||||
std::unique_ptr<RnnStateTensorDescriptor> grads_desc =
|
||||
grads_desc_s.ConsumeValueOrDie();
|
||||
|
||||
absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
|
||||
num_indices);
|
||||
absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
|
||||
batch_size);
|
||||
absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(),
|
||||
batch_size);
|
||||
|
||||
auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
|
||||
auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
|
||||
auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);
|
||||
|
||||
// Set the memory limitation to 4GB for workspace memory.
|
||||
DnnScratchAllocator workspace_allocator(1LL << 32, ctx);
|
||||
|
||||
Stream* stream = ctx->op_device_context()->stream();
|
||||
bool cudnn_launch_status =
|
||||
stream
|
||||
->ThenCtcLoss(*probs_desc, probs_data, labels_data,
|
||||
labels_lengths_data, input_lengths_data, &costs_data,
|
||||
*grads_desc, &grads_data, &workspace_allocator)
|
||||
.ok();
|
||||
|
||||
if (!cudnn_launch_status) {
|
||||
ctx->SetStatus(errors::Internal("cuDNN CTCLoss launch failure"));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
|
||||
};
|
||||
|
||||
// Registers CTCLossOpGPU as the GPU kernel for the "CTCLossV2" op. The
// labels_indices, labels_values and sequence_length inputs are pinned to host
// memory because the kernel dereferences their buffers on the CPU.
REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("labels_indices")
|
||||
.HostMemory("labels_values")
|
||||
.HostMemory("sequence_length"),
|
||||
CTCLossOpGPU);
|
||||
#endif // GOOGLE_CUDA && CUDNN_VERSION >= 7603
|
||||
} // end namespace tensorflow
|
||||
|
@ -62,6 +62,43 @@ REGISTER_OP("CTCLoss")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Op definition for CTCLossV2.
//
// Inputs: float logits of rank 3, int64 label indices of rank 2, int32 label
// values of rank 1 and int32 sequence lengths of rank 1. Outputs: a float
// `loss` vector of length batch_size and a float `gradient` with the shape of
// `inputs`.
REGISTER_OP("CTCLossV2")
    .Input("inputs: float")
    .Input("labels_indices: int64")
    .Input("labels_values: int32")
    .Input("sequence_length: int32")
    .Attr("preprocess_collapse_repeated: bool = false")
    .Attr("ctc_merge_repeated: bool = true")
    .Attr("ignore_longer_outputs_than_inputs: bool = false")
    .Output("loss: float")
    .Output("gradient: float")
    .SetShapeFn([](InferenceContext* c) {
      // Enforce the expected rank of every input, in input order.
      ShapeHandle logits_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &logits_shape));
      ShapeHandle indices_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices_shape));
      ShapeHandle values_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &values_shape));
      ShapeHandle seq_len_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &seq_len_shape));

      // labels_indices and labels_values must agree on the number of entries;
      // the merged dimension itself is not needed afterwards.
      DimensionHandle num_entries_unused;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices_shape, 0),
                                  c->Dim(values_shape, 0),
                                  &num_entries_unused));

      // The batch size is the merge of inputs' dim 1 and sequence_length's
      // dim 0. It is written back into the inputs shape because that shape is
      // returned as the gradient shape.
      DimensionHandle merged_batch;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(logits_shape, 1),
                                  c->Dim(seq_len_shape, 0), &merged_batch));
      TF_RETURN_IF_ERROR(
          c->ReplaceDim(logits_shape, 1, merged_batch, &logits_shape));

      c->set_output(0, c->Vector(merged_batch));
      c->set_output(1, logits_shape);
      return Status::OK();
    });
|
||||
|
||||
REGISTER_OP("CTCGreedyDecoder")
|
||||
.Input("inputs: T")
|
||||
.Input("sequence_length: int32")
|
||||
|
@ -18,10 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -29,6 +31,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import ctc_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
@ -839,5 +842,82 @@ class CTCLossTestV2(test.TestCase):
|
||||
self.assertAllEqual(
|
||||
[[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
class CTCLossTestV3(keras_parameterized.TestCase):
  """Compares ctc_loss_v3 evaluated on GPU with the same call on CPU."""

  @parameterized.parameters([False, True])
  @test_util.run_v2_only
  def testCtcLossV3(self, run_tf_func):
    """Testing GPU CTC loss.

    testing if GPU CTC loss will generate same result with CPU version
    """
    if not test.is_gpu_available():
      self.skipTest("Need GPU for testing.")
    random_seed.set_random_seed(5)

    batch_size = 8
    num_labels = 6
    max_label_length = 5
    num_frames = 12

    # The random ops are created in the same order as before (labels, logits,
    # label_length) so the seeded values do not change.
    labels = random_ops.random_uniform(
        [batch_size, max_label_length],
        minval=1,
        maxval=num_labels,
        dtype=dtypes.int64)
    logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
    label_length = random_ops.random_uniform(
        [batch_size], minval=2, maxval=max_label_length, dtype=dtypes.int64)

    # Zero out every label position at or beyond the per-example length.
    label_mask = array_ops.sequence_mask(
        label_length, maxlen=max_label_length, dtype=label_length.dtype)
    labels *= label_mask
    logit_length = [num_frames] * batch_size

    def make_loss_fn(use_gpu):
      """Builds a loss-and-gradient function pinned to CPU or GPU."""

      def loss_and_grad(labels, logits, label_length, logit_length):
        # NOTE(review): the nesting below (gradient taken after the tape
        # block, still inside the device scope) is reconstructed; confirm
        # against the original file.
        with test_util.device(use_gpu=use_gpu):
          sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
          with backprop.GradientTape() as tape:
            tape.watch(logits)
            loss_value = ctc_ops.ctc_loss_v3(
                labels=sparse_labels,
                logits=logits,
                label_length=label_length,
                logit_length=logit_length,
                blank_index=0)
          grad_value = tape.gradient(loss_value, [logits])
          return loss_value, grad_value

      return loss_and_grad

    ctc_loss_cpu = make_loss_fn(use_gpu=False)
    ctc_loss_gpu = make_loss_fn(use_gpu=True)
    if run_tf_func:
      ctc_loss_cpu = def_function.function(ctc_loss_cpu)
      ctc_loss_gpu = def_function.function(ctc_loss_gpu)

    ref_loss, ref_grad = ctc_loss_cpu(labels, logits, label_length,
                                      logit_length)
    loss, grad = ctc_loss_gpu(labels, logits, label_length, logit_length)

    self.assertAllClose(loss, ref_loss, atol=1e-6)
    self.assertAllClose(grad, ref_grad, atol=2e-6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -18,9 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import uuid
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function as function_eager
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
@ -42,6 +46,27 @@ from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
|
||||
_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
|
||||
_CPU_DEVICE_NAME = "CPU"
|
||||
_GPU_DEVICE_NAME = "GPU"
|
||||
|
||||
|
||||
def _get_context_device_type():
  """Returns the device type (e.g. CPU/GPU) of the current context device.

  Returns None when the context has no device name set.
  """
  device_name = context.context().device_name
  if device_name is not None:
    return device.DeviceSpec.from_string(device_name).device_type
  return None
|
||||
|
||||
|
||||
def _generate_defun_backend(unique_api_name, preferred_device, func):
  """Wraps `func` as a defun tagged with an API name and preferred device.

  Args:
    unique_api_name: value stored under the "api_implements" attribute.
    preferred_device: value stored under the "api_preferred_device" attribute.
    func: the Python function to wrap.

  Returns:
    The result of `defun_with_attributes` for `func`, with autograph disabled.
  """
  attributes = {}
  attributes[_DEFUN_API_NAME_ATTRIBUTE] = unique_api_name
  attributes[_DEFUN_DEVICE_ATTRIBUTE] = preferred_device
  return function_eager.defun_with_attributes(
      func=func, attributes=attributes, autograph=False)
|
||||
|
||||
# pylint: disable=protected-access, invalid-name
|
||||
@tf_export(v1=["nn.ctc_loss"])
|
||||
@ -156,6 +181,31 @@ def ctc_loss(labels,
|
||||
[Graves et al., 2016](https://dl.acm.org/citation.cfm?id=1143891)
|
||||
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
|
||||
"""
|
||||
return _ctc_loss_impl(
|
||||
labels,
|
||||
inputs,
|
||||
sequence_length,
|
||||
preprocess_collapse_repeated,
|
||||
ctc_merge_repeated,
|
||||
ignore_longer_outputs_than_inputs,
|
||||
time_major,
|
||||
logits,
|
||||
use_cudnn=False)
|
||||
|
||||
|
||||
def _ctc_loss_impl(labels,
|
||||
inputs=None,
|
||||
sequence_length=None,
|
||||
preprocess_collapse_repeated=False,
|
||||
ctc_merge_repeated=True,
|
||||
ignore_longer_outputs_than_inputs=False,
|
||||
time_major=True,
|
||||
logits=None,
|
||||
use_cudnn=False):
|
||||
# Helper function of ctc_loss with one additional param:
|
||||
# use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
|
||||
# index has to be 0.
|
||||
|
||||
# The second, third, etc output tensors contain the gradients. We use it in
|
||||
# _CTCLossGrad() below.
|
||||
if not isinstance(labels, sparse_tensor.SparseTensor):
|
||||
@ -167,7 +217,14 @@ def ctc_loss(labels,
|
||||
if not time_major:
|
||||
inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N)
|
||||
|
||||
loss, _ = gen_ctc_ops.ctc_loss(
|
||||
# gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
|
||||
# blank index to be 0, but v1 views it as the last index.
|
||||
if use_cudnn:
|
||||
ctc_loss_func = gen_ctc_ops.ctc_loss_v2
|
||||
else:
|
||||
ctc_loss_func = gen_ctc_ops.ctc_loss
|
||||
|
||||
loss, _ = ctc_loss_func(
|
||||
inputs,
|
||||
labels.indices,
|
||||
labels.values,
|
||||
@ -178,19 +235,8 @@ def ctc_loss(labels,
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@ops.RegisterGradient("CTCLoss")
|
||||
def _CTCLossGrad(op, grad_loss, _):
|
||||
"""The derivative provided by CTC Loss.
|
||||
|
||||
Args:
|
||||
op: the CTCLoss op.
|
||||
grad_loss: The backprop for cost.
|
||||
|
||||
Returns:
|
||||
The CTC Loss gradient.
|
||||
"""
|
||||
def _CTCLossGradImpl(op, grad_loss, _):
|
||||
# Outputs are: loss, grad
|
||||
#
|
||||
# Currently there is no way to take the second derivative of this op
|
||||
@ -207,6 +253,36 @@ def _CTCLossGrad(op, grad_loss, _):
|
||||
return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
  """Gradient function registered for the CTCLoss op.

  Forwards all arguments unchanged to `_CTCLossGradImpl`.

  Args:
    op: the CTCLoss op.
    grad_loss: The backprop for cost.

  Returns:
    The CTC Loss gradient.
  """
  ctc_gradient = _CTCLossGradImpl(op, grad_loss, _)
  return ctc_gradient
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
  """Gradient function registered for the CTCLossV2 op.

  Forwards all arguments unchanged to `_CTCLossGradImpl`.

  Args:
    op: the CTCLossV2 op.
    grad_loss: The backprop for cost.

  Returns:
    The CTC Loss V2 gradient.
  """
  ctc_gradient = _CTCLossGradImpl(op, grad_loss, _)
  return ctc_gradient
|
||||
|
||||
|
||||
@tf_export("nn.ctc_greedy_decoder")
|
||||
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
|
||||
"""Performs greedy decoding on the logits given in input (best path).
|
||||
@ -593,11 +669,48 @@ def _ctc_loss_grad(op, grad_loss, _):
|
||||
return grad
|
||||
|
||||
|
||||
def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
                          blank_index):
  """Calls `_ctc_loss_impl` (use_cudnn=False) with the blank class moved last.

  The class at `blank_index` is moved to the end of the last logits axis, and
  label values are renumbered to match: values below `blank_index` are kept,
  all others are decremented by one.
  """
  # Reorder the class axis to [classes before blank, classes after blank,
  # blank].
  blank_slice = logits[:, :, blank_index:blank_index + 1]
  leading = logits[:, :, :blank_index]
  trailing = logits[:, :, blank_index + 1:]
  reordered_logits = array_ops.concat([leading, trailing, blank_slice], axis=2)

  values = labels.values
  shifted_values = array_ops.where(values < blank_index, values, values - 1)
  remapped_labels = sparse_tensor.SparseTensor(labels.indices, shifted_values,
                                               labels.dense_shape)
  return _ctc_loss_impl(
      labels=remapped_labels,
      inputs=reordered_logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=False)
|
||||
|
||||
|
||||
def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
                       blank_index):
  """Calls `_ctc_loss_impl` (use_cudnn=True) with the blank class moved first.

  The class at `blank_index` is moved to position 0 of the last logits axis,
  and label values are renumbered to match: values below `blank_index` are
  incremented by one, all others are kept.
  """
  # Reorder the class axis to [blank, classes before blank, classes after
  # blank].
  blank_slice = logits[:, :, blank_index:blank_index + 1]
  leading = logits[:, :, :blank_index]
  trailing = logits[:, :, blank_index + 1:]
  reordered_logits = array_ops.concat([blank_slice, leading, trailing], axis=2)

  values = labels.values
  shifted_values = array_ops.where(values < blank_index, values + 1, values)
  remapped_labels = sparse_tensor.SparseTensor(labels.indices, shifted_values,
                                               labels.dense_shape)
  return _ctc_loss_impl(
      labels=remapped_labels,
      inputs=reordered_logits,
      sequence_length=logit_length,
      time_major=logits_time_major,
      use_cudnn=True)
|
||||
|
||||
|
||||
def _ctc_loss_shape(op):
|
||||
return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]
|
||||
|
||||
|
||||
@tf_export("nn.ctc_loss", v1=["nn.ctc_loss_v2"])
|
||||
# pylint: disable=protected-access, invalid-name
|
||||
@tf_export(v1=["nn.ctc_loss_v2"])
|
||||
def ctc_loss_v2(labels,
|
||||
logits,
|
||||
label_length,
|
||||
@ -691,6 +804,111 @@ def ctc_loss_v2(labels,
|
||||
name=name)
|
||||
|
||||
|
||||
@tf_export("nn.ctc_loss", v1=[])
def ctc_loss_v3(labels,
                logits,
                label_length,
                logit_length,
                logits_time_major=True,
                unique=None,
                blank_index=None,
                name=None):
  """Computes CTC (Connectionist Temporal Classification) loss.

  This op implements the CTC loss as presented in (Graves et al., 2016).

  Notes:

  - Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
    setting of preprocess_collapse_repeated=False, ctc_merge_repeated=True
  - Labels may be supplied as either a dense, zero-padded tensor with a
    vector of label sequence lengths OR as a SparseTensor.
  - On TPU and GPU: Only dense padded labels are supported.
  - On CPU: Caller may use SparseTensor or dense padded labels but calling with
    a SparseTensor will be significantly faster.
  - Default blank label is 0 rather than num_classes - 1, unless overridden by
    blank_index.

  Args:
    labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
    logits: tensor of shape [frames, batch_size, num_labels], if
      logits_time_major == False, shape is [batch_size, frames, num_labels].
    label_length: tensor of shape [batch_size], None if labels is SparseTensor
      Length of reference label sequence in labels.
    logit_length: tensor of shape [batch_size] Length of input sequence in
      logits.
    logits_time_major: (optional) If True (default), logits is shaped [time,
      batch, logits]. If False, shape is [batch, time, logits]
    unique: (optional) Unique label indices as computed by
      ctc_unique_labels(labels).  If supplied, enable a faster, memory efficient
      implementation on TPU.
    blank_index: (optional) Set the class index to use for the blank label.
      Negative values will start from num_classes, ie, -1 will reproduce the
      ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
      some memory/performance overhead to switching from the default of 0 as an
      additional shifted copy of the logits may be created.
    name: A name for this `Op`. Defaults to "ctc_loss_dense".

  Returns:
    loss: tensor of shape [batch_size], negative log probabilities.

  Raises:
    ValueError: if `labels` is a SparseTensor and `blank_index` is None.

  References:
      Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
      with Recurrent Neural Networks:
        [Graves et al., 2016](https://dl.acm.org/citation.cfm?id=1143891)
        ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
  """
  # Dense labels: delegate to ctc_loss_dense, defaulting the blank index to 0.
  if not isinstance(labels, sparse_tensor.SparseTensor):
    if blank_index is None:
      blank_index = 0
    return ctc_loss_dense(
        labels=labels,
        logits=logits,
        label_length=label_length,
        logit_length=logit_length,
        logits_time_major=logits_time_major,
        unique=unique,
        blank_index=blank_index,
        name=name)

  # Sparse labels from here on.
  if blank_index is None:
    raise ValueError(
        "blank_index must be given when using SparseTensor labels.")

  # Negative blank indices count back from the size of the last logits axis.
  if blank_index < 0:
    blank_index += _get_dim(logits, 2)

  op_kwargs = dict(
      labels=labels,
      logits=logits,
      logit_length=logit_length,
      logits_time_major=logits_time_major,
      blank_index=blank_index)

  if not context.executing_eagerly():
    # Graph mode: call the CPU implementation and register the GPU one under
    # the same generated API name.
    api_name = "ctc_loss_" + str(uuid.uuid4())
    standard_backend = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
                                               _ctc_loss_op_standard)
    cudnn_backend = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
                                            _ctc_loss_op_cudnn)
    graph_result = standard_backend(**op_kwargs)
    function_eager.register(cudnn_backend, **op_kwargs)
    return graph_result

  # Under eager context, check the device placement and prefer the cuDNN
  # implementation when either the user specified GPU, or no device was
  # specified but a GPU is available.
  device_type = _get_context_device_type()
  gpu_requested = device_type == _GPU_DEVICE_NAME
  gpu_available_by_default = device_type is None and context.num_gpus() > 0
  if gpu_requested or gpu_available_by_default:
    return _ctc_loss_op_cudnn(**op_kwargs)
  return _ctc_loss_op_standard(**op_kwargs)
|
||||
|
||||
|
||||
def ctc_loss_dense(labels,
|
||||
logits,
|
||||
label_length,
|
||||
|
@ -408,6 +408,13 @@ struct PersistentRnnPlanDeleter {
|
||||
CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
|
||||
}
|
||||
};
|
||||
#if CUDNN_VERSION >= 7603
|
||||
struct CtcLossDescriptorDeleter {
|
||||
void operator()(cudnnCTCLossDescriptor_t descriptor) const {
|
||||
CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// RAII wrappers for cuDNN types.
|
||||
using TensorDescriptor =
|
||||
@ -430,6 +437,10 @@ using DropoutDescriptor =
|
||||
using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
|
||||
using PersistentRnnPlan =
|
||||
std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
|
||||
#if CUDNN_VERSION >= 7603
|
||||
// Owning handle for a cuDNN CTC loss descriptor; the descriptor is released
// through CtcLossDescriptorDeleter when the handle goes out of scope.
using CtcLossDescriptor =
|
||||
std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
|
||||
#endif
|
||||
|
||||
// Factory methods for cuDNN types.
|
||||
TensorDescriptor CreateTensorDescriptor() {
|
||||
@ -479,6 +490,13 @@ RnnDescriptor CreateRnnDescriptor() {
|
||||
CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
|
||||
return RnnDescriptor(result);
|
||||
}
|
||||
#if CUDNN_VERSION >= 7603
|
||||
// Creates a new cuDNN CTC loss descriptor and returns it wrapped in its owning
// handle. A failed create call goes through CHECK_CUDNN_OK.
CtcLossDescriptor CreateCtcLossDescriptor() {
  cudnnCTCLossDescriptor_t raw_descriptor;
  CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&raw_descriptor));
  CtcLossDescriptor owned_descriptor(raw_descriptor);
  return owned_descriptor;
}
|
||||
#endif
|
||||
|
||||
port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
|
||||
cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
|
||||
@ -1193,6 +1211,33 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
|
||||
};
|
||||
|
||||
#if CUDNN_VERSION >= 7603
|
||||
class CudnnCtcLossDescriptor {
|
||||
public:
|
||||
explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
|
||||
: handle_(CreateCtcLossDescriptor()) {
|
||||
CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
|
||||
/*ctcLossDesc=*/handle_.get(),
|
||||
/*compType=*/data_type,
|
||||
/*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
|
||||
/*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
|
||||
}
|
||||
|
||||
cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
|
||||
|
||||
private:
|
||||
CtcLossDescriptor handle_; // Owned
|
||||
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
|
||||
};
|
||||
#else
|
||||
// dummy class
|
||||
class CudnnCtcLossDescriptor {
|
||||
public:
|
||||
CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
|
||||
};
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// Check if the LSTM projection is used. If yes, an additional weight matrix
|
||||
@ -1660,6 +1705,7 @@ port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
|
||||
}
|
||||
return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
@ -1973,6 +2019,43 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
// Invokes cudnnCTCLoss with the given probabilities, labels, label lengths
// and input lengths, writing into `costs_data` and `grads_data` and using
// `scratch_memory` as the workspace.
//
// When built with CUDNN_VERSION < 7603 no cuDNN call is made and an
// INVALID_ARGUMENT status is returned instead.
port::Status CudnnSupport::DoCtcLossImpl(
    Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
    absl::Span<const int> labels_lengths_data,
    absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
    const CudnnRnnStateTensorDescriptor& grads_desc,
    DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
    DeviceMemory<uint8> scratch_memory) {
  auto cudnn = cudnn_->GetHandle(parent_, stream);

  // The element count is computed but intentionally unused; the cast to void
  // silences the unused-variable warning.
  int num_timestamps = probs_desc.num_layers();
  int batch_size = probs_desc.batch_size();
  int num_labels = probs_desc.data_size();
  int element_count = num_labels * num_timestamps * batch_size;
  (void)element_count;

#if CUDNN_VERSION >= 7603
  RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
      /*handle=*/cudnn.handle(),
      /*probsDesc=*/probs_desc.handle(),
      /*probs=*/probs_data.opaque(),
      /*labels=*/labels_data.data(),
      /*labelLengths=*/labels_lengths_data.data(),
      /*inputLengths=*/input_lengths_data.data(),
      /*costs=*/costs_data.opaque(),
      /*gradientsDesc=*/grads_desc.handle(),
      /*gradients=*/grads_data.opaque(),
      /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
      /*ctcLossDesc=*/ctc_loss_desc.handle(),
      /*workspace=*/scratch_memory.opaque(),
      /*workSpaceSizeInBytes=*/scratch_memory.size()));
  return port::Status::OK();
#else
  return port::Status(port::error::INVALID_ARGUMENT,
                      "No supported cudnnCTCLoss when "
                      "CUDNN_VERSION < 7.6.3");
#endif
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
||||
CudnnSupport::createRnnDescriptor(
|
||||
int num_layers, int hidden_size, int input_size, int cell_size,
|
||||
@ -3832,6 +3915,79 @@ bool CudnnSupport::DoFusedConvolve(
|
||||
/*report_error=*/!output_profile_result);
|
||||
}
|
||||
|
||||
// Queries cuDNN for the workspace size needed by the CTC loss and allocates
// that workspace from `scratch_allocator` into `*scratch_memory`.
//
// Returns INVALID_ARGUMENT when compiled against cuDNN older than 7.6.3, and
// an internal error when the allocator cannot provide the workspace. A
// required size of zero yields an empty DeviceMemory and OK.
port::Status CudnnSupport::DoPrepareForCtcLoss(
    Stream* stream, dnn::DataType element_type,
    const dnn::RnnStateTensorDescriptor& probs_desc,
    const dnn::RnnStateTensorDescriptor& grads_desc,
    absl::Span<const int> labels_data,
    absl::Span<const int> labels_lengths_data,
    absl::Span<const int> input_lengths_data,
    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
  auto cudnn = cudnn_->GetHandle(parent_, stream);
  CudnnCtcLossDescriptor ctc_desc(ToCudnnDataType(element_type));
  const auto& probs =
      static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
  const auto& grads =
      static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);

  // Ask cuDNN how many workspace bytes the loss computation needs.
  size_t required_bytes = 0;
#if CUDNN_VERSION >= 7603
  RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
      /*handle=*/cudnn.handle(), /*probsDesc=*/probs.handle(),
      /*gradientsDesc=*/grads.handle(),
      /*labels=*/labels_data.data(),
      /*labelLengths=*/labels_lengths_data.data(),
      /*inputLengths=*/input_lengths_data.data(),
      /*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
      /*ctcLossDesc=*/ctc_desc.handle(),
      /*sizeInBytes=*/&required_bytes));
#else
  return port::Status(port::error::INVALID_ARGUMENT,
                      "No supported cudnnGetCTCLossWorkspaceSize when "
                      "CUDNN_VERSION < 7.6.3");
#endif

  // Nothing to allocate: hand back an empty workspace.
  if (required_bytes == 0) {
    *scratch_memory = DeviceMemory<uint8>();
    return port::Status::OK();
  }

  const auto allocation = scratch_allocator->AllocateBytes(required_bytes);
  if (!allocation.ok()) {
    return port::InternalError(
        "Failed to allocate scratch memory for the CuDNN CTC Loss");
  }
  *scratch_memory = allocation.ValueOrDie();
  return port::Status::OK();
}
|
||||
|
||||
// Validates the request and forwards it to DoCtcLossImpl.
//
// Only float inputs with cuDNN >= 7.6.3 are accepted; anything else is
// rejected with INVALID_ARGUMENT before any cuDNN object is created. The
// generic RNN state descriptors are downcast to their cuDNN implementations.
port::Status CudnnSupport::DoCtcLoss(
    Stream* stream, dnn::DataType element_type,
    const dnn::RnnStateTensorDescriptor& probs_desc,
    const DeviceMemoryBase probs_data,
    absl::Span<const int> labels_data,
    absl::Span<const int> labels_lengths_data,
    absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
    const dnn::RnnStateTensorDescriptor& grads_desc,
    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
  // Current cuDNN CTC Loss only supports the float datatype
  const bool is_supported =
      CUDNN_VERSION >= 7603 && element_type == dnn::DataType::kFloat;
  if (!is_supported) {
    return port::Status(port::error::INVALID_ARGUMENT,
                        "CudnnCtcLossDescriptor is supported only when the "
                        "CUDNN_VERSION >= 7.6.3 and DataType is float");
  }

  CudnnCtcLossDescriptor ctc_desc(ToCudnnDataType(element_type));
  const auto& probs =
      static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
  const auto& grads =
      static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
  return DoCtcLossImpl(stream, probs, probs_data, labels_data,
                       labels_lengths_data, input_lengths_data, costs_data,
                       grads, grads_data, ctc_desc, scratch_memory);
}
|
||||
|
||||
bool CudnnSupport::DoTransformTensor(Stream* stream,
|
||||
const dnn::BatchDescriptor& input_desc,
|
||||
dnn::DataType input_type,
|
||||
|
@ -33,6 +33,7 @@ class GpuExecutor;
|
||||
class CudnnRnnDescriptor;
|
||||
class CudnnRnnSequenceTensorDescriptor;
|
||||
class CudnnRnnStateTensorDescriptor;
|
||||
class CudnnCtcLossDescriptor;
|
||||
|
||||
// Opaque and unique identifier for the cuDNN plugin.
|
||||
extern const PluginId kCuDnnPlugin;
|
||||
@ -562,6 +563,17 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
dnn::BatchDescriptor* output_batch_descriptor);
|
||||
|
||||
port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
|
||||
const dnn::RnnStateTensorDescriptor& probs_desc,
|
||||
const DeviceMemoryBase probs_data,
|
||||
absl::Span<const int> labels_data,
|
||||
absl::Span<const int> labels_lengths_data,
|
||||
absl::Span<const int> input_lengths_data,
|
||||
DeviceMemoryBase costs_data,
|
||||
const dnn::RnnStateTensorDescriptor& grads_desc,
|
||||
DeviceMemoryBase grads_data,
|
||||
DeviceMemory<uint8> scratch_memory) override;
|
||||
|
||||
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
|
||||
dnn::DataType input_type,
|
||||
const DeviceMemoryBase& input_data,
|
||||
@ -673,6 +685,15 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
ScratchAllocator* workspace_allocator,
|
||||
dnn::ProfileResult* output_profile_result);
|
||||
|
||||
port::Status DoCtcLossImpl(
|
||||
Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
|
||||
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
|
||||
absl::Span<const int> labels_lengths_data,
|
||||
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
|
||||
const CudnnRnnStateTensorDescriptor& grads_desc,
|
||||
DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
|
||||
DeviceMemory<uint8> scratch_memory);
|
||||
|
||||
private:
|
||||
port::Status DoPrepareForConvolution(
|
||||
dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
|
||||
@ -686,6 +707,16 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
|
||||
DeviceMemory<uint8>* scratch_memory) override;
|
||||
|
||||
port::Status DoPrepareForCtcLoss(
|
||||
Stream* stream, dnn::DataType element_type,
|
||||
const dnn::RnnStateTensorDescriptor& probs_desc,
|
||||
const dnn::RnnStateTensorDescriptor& grads_desc,
|
||||
absl::Span<const int> labels_data,
|
||||
absl::Span<const int> labels_lengths_data,
|
||||
absl::Span<const int> input_lengths_data,
|
||||
ScratchAllocator* scratch_allocator,
|
||||
DeviceMemory<uint8>* scratch_memory) override;
|
||||
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
|
||||
};
|
||||
|
||||
|
@ -615,5 +615,18 @@ bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Base implementation of the CTC loss entry point: it ignores every argument
// and reports that the operation is unimplemented.
port::Status DnnSupport::DoCtcLoss(Stream* stream, dnn::DataType element_type,
                                   const RnnStateTensorDescriptor& probs_desc,
                                   const DeviceMemoryBase probs_data,
                                   absl::Span<const int> labels_data,
                                   absl::Span<const int> labels_lengths_data,
                                   absl::Span<const int> input_lengths_data,
                                   DeviceMemoryBase costs_data,
                                   const RnnStateTensorDescriptor& grads_desc,
                                   DeviceMemoryBase grads_data,
                                   DeviceMemory<uint8> scratch_memory) {
  port::Status not_implemented =
      port::UnimplementedError("CtcLoss not implemented");
  return not_implemented;
}
|
||||
|
||||
} // namespace dnn
|
||||
} // namespace stream_executor
|
||||
|
@ -2391,6 +2391,73 @@ class DnnSupport {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
port::Status PrepareForCtcLoss(Stream* stream,
|
||||
const RnnStateTensorDescriptor& probs_desc,
|
||||
DeviceMemory<ElementType> probs_data,
|
||||
const RnnStateTensorDescriptor& grads_desc,
|
||||
absl::Span<const int> labels_data,
|
||||
absl::Span<const int> labels_lengths_data,
|
||||
absl::Span<const int> input_lengths_data,
|
||||
ScratchAllocator* workspace_allocator,
|
||||
DeviceMemory<uint8>* scratch_memory) {
|
||||
return DoPrepareForCtcLoss(stream, ToDataType<ElementType>::value,
|
||||
probs_desc, grads_desc, labels_data,
|
||||
labels_lengths_data, input_lengths_data,
|
||||
workspace_allocator, scratch_memory);
|
||||
}
|
||||
|
||||
// Enqueue a CTC Loss operation onto the stream.
|
||||
//
|
||||
// Arguments:
|
||||
// stream: pointer to the stream where this operation should be enqueued to.
|
||||
// element_type: data type of the input tensors
|
||||
// probs_desc: specifies the shape and the data layout of the input tensor.
|
||||
// probs_data: the device memory region that contains the input tensor.
|
||||
// labels_data: the device memory region that contains the labels_value
|
||||
// tensor.
|
||||
// labels_lengths_data: the device memory region that contains the
|
||||
// labels_lengths tensor
|
||||
// input_lengths_data: the device memory region that contains the seq_lengths
|
||||
// tensor
|
||||
// costs_data: the device memory region that contains the costs tensor.
|
||||
// grads_desc: specifies the shape and the data layout of the grads tensor.
|
||||
// grads_data: the device memory region that contains the grads tensor.
|
||||
// scratch_memory: the temporary workspace memory used by this operation
|
||||
// (see PrepareForCtcLoss). The caller is responsible for keeping the
|
||||
// memory alive long enough for this operation, and recycles it afterwards.
|
||||
virtual port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
|
||||
const RnnStateTensorDescriptor& probs_desc,
|
||||
const DeviceMemoryBase probs_data,
|
||||
absl::Span<const int> labels_data,
|
||||
absl::Span<const int> labels_lengths_data,
|
||||
absl::Span<const int> input_lengths_data,
|
||||
DeviceMemoryBase costs_data,
|
||||
const RnnStateTensorDescriptor& grads_desc,
|
||||
DeviceMemoryBase grads_data,
|
||||
DeviceMemory<uint8> scratch_memory);
|
||||
|
||||
template <typename ElementType>
|
||||
bool DoCtcLoss(Stream* stream,
|
||||
const dnn::RnnStateTensorDescriptor& probs_desc,
|
||||
const DeviceMemory<ElementType>& probs_data,
|
||||
absl::Span<const int> labels_data,
|
||||
absl::Span<const int> labels_lengths_data,
|
||||
absl::Span<const int> input_lengths_data,
|
||||
DeviceMemory<ElementType>* costs_data,
|
||||
const dnn::RnnStateTensorDescriptor& grads_desc,
|
||||
DeviceMemory<ElementType>* grads_data,
|
||||
DeviceMemory<uint8>* scratch_memory) {
|
||||
return IsStatusOk(
|
||||
DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
|
||||
probs_data, labels_data, labels_lengths_data,
|
||||
input_lengths_data, *costs_data, grads_desc, *grads_data,
|
||||
*scratch_memory),
|
||||
false);
|
||||
}
|
||||
|
||||
// Transforms a tensor into another tensor with a different layout and/or data
|
||||
// type.
|
||||
//
|
||||
@ -2646,6 +2713,19 @@ class DnnSupport {
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
// Default preparation step for the CTC loss: no workspace is needed, so
// `*scratch_memory` is reset to an empty DeviceMemory and OK is returned.
virtual port::Status DoPrepareForCtcLoss(
    Stream* stream, DataType element_type,
    const RnnStateTensorDescriptor& probs_desc,
    const RnnStateTensorDescriptor& grads_desc,
    absl::Span<const int> labels_data,
    absl::Span<const int> labels_lengths_data,
    absl::Span<const int> input_lengths_data,
    ScratchAllocator* scratch_allocator,
    DeviceMemory<uint8>* scratch_memory) {
  *scratch_memory = DeviceMemory<uint8>();
  return port::Status::OK();
}
|
||||
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
|
||||
};
|
||||
|
||||
|
@ -5230,6 +5230,39 @@ Stream &Stream::ThenRnnBackward(
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Enqueues a CTC loss computation on this stream.
//
// Does nothing when the stream is already in an error state. Otherwise the
// workspace is prepared through the DNN support object and, if that works,
// the loss itself is run; a failure in either step puts the stream into the
// error state. A missing DNN support object is logged and also flags an error.
Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                            const DeviceMemory<float> &probs_data,
                            absl::Span<const int> labels_data,
                            absl::Span<const int> labels_lengths_data,
                            absl::Span<const int> input_lengths_data,
                            DeviceMemory<float> *costs_data,
                            const dnn::RnnStateTensorDescriptor &grads_desc,
                            DeviceMemory<float> *grads_data,
                            ScratchAllocator *workspace_allocator) {
  if (!ok()) {
    return *this;
  }

  dnn::DnnSupport *dnn_support = parent_->AsDnn();
  if (dnn_support == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }

  DeviceMemory<uint8> scratch_memory;
  const bool prepared =
      dnn_support
          ->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
                              labels_data, labels_lengths_data,
                              input_lengths_data, workspace_allocator,
                              &scratch_memory)
          .ok();
  // The loss is only attempted when the workspace preparation succeeded.
  const bool succeeded =
      prepared &&
      dnn_support->DoCtcLoss(this, probs_desc, probs_data, labels_data,
                             labels_lengths_data, input_lengths_data,
                             costs_data, grads_desc, grads_data,
                             &scratch_memory);
  if (!succeeded) {
    SetError();
  }
  return *this;
}
|
||||
|
||||
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
|
||||
dnn::DataType input_type,
|
||||
const DeviceMemoryBase &input_data,
|
||||
|
@ -1912,6 +1912,18 @@ class Stream {
|
||||
ScratchAllocator *workspace_allocator,
|
||||
dnn::ProfileResult *output_profile_result);
|
||||
|
||||
// Enqueue a CTCLoss operation onto the stream.
|
||||
// See DnnSupport::DoCtcLoss for more details.
|
||||
Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
|
||||
const DeviceMemory<float> &probs_data,
|
||||
absl::Span<const int> labels_data,
|
||||
absl::Span<const int> labels_lengths_data,
|
||||
absl::Span<const int> input_lengths_data,
|
||||
DeviceMemory<float> *costs_data,
|
||||
const dnn::RnnStateTensorDescriptor &grads_desc,
|
||||
DeviceMemory<float> *grads_data,
|
||||
ScratchAllocator *workspace_allocator);
|
||||
|
||||
// Enqueue onto the stream an operation that transforms a tensor.
|
||||
// See DnnSupport::DoTransformTensor for more details.
|
||||
Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
|
||||
|
@ -636,6 +636,10 @@ tf_module {
|
||||
name: "CTCLoss"
|
||||
argspec: "args=[\'inputs\', \'labels_indices\', \'labels_values\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CTCLossV2"
|
||||
argspec: "args=[\'inputs\', \'labels_indices\', \'labels_values\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CacheDataset"
|
||||
argspec: "args=[\'input_dataset\', \'filename\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -636,6 +636,10 @@ tf_module {
|
||||
name: "CTCLoss"
|
||||
argspec: "args=[\'inputs\', \'labels_indices\', \'labels_values\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CTCLossV2"
|
||||
argspec: "args=[\'inputs\', \'labels_indices\', \'labels_values\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CacheDataset"
|
||||
argspec: "args=[\'input_dataset\', \'filename\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user