Merge pull request #32302 from houtoms:pr_cudnn_ctc_loss

PiperOrigin-RevId: 290387603
Change-Id: I28491f42a4559a9f79bd6a7b73d8e6b670f55368
TensorFlower Gardener 2020-01-17 20:32:44 -08:00
commit bd4c38b3dc
15 changed files with 941 additions and 15 deletions

View File

@@ -0,0 +1,72 @@
op {
graph_op_name: "CTCLossV2"
visibility: HIDDEN
in_arg {
name: "inputs"
description: <<END
3-D, shape: `(max_time x batch_size x num_classes)`, the logits. The default
blank label is 0 rather than num_classes - 1.
END
}
in_arg {
name: "labels_indices"
description: <<END
The indices of a `SparseTensor<int32, 2>`.
`labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for
`(batch b, time t)`.
END
}
in_arg {
name: "labels_values"
description: <<END
The values (labels) associated with the given batch and time.
END
}
in_arg {
name: "sequence_length"
description: <<END
A vector containing sequence lengths (batch).
END
}
out_arg {
name: "loss"
description: <<END
A vector (batch) containing log-probabilities.
END
}
out_arg {
name: "gradient"
description: <<END
The gradient of `loss`. 3-D, shape:
`(max_time x batch_size x num_classes)`.
END
}
attr {
name: "preprocess_collapse_repeated"
description: <<END
Scalar. If true, repeated labels are
collapsed prior to the CTC calculation.
END
}
attr {
name: "ctc_merge_repeated"
description: <<END
Scalar. If set to false, *during* CTC calculation
repeated non-blank labels will not be merged and are interpreted as
individual labels. This is a simplified version of CTC.
END
}
attr {
name: "ignore_longer_outputs_than_inputs"
description: <<END
Scalar. If set to true, during CTC
calculation, items that have longer output sequences than input sequences
are skipped: they don't contribute to the loss term and have zero-gradient.
END
}
summary: "Calculates the CTC Loss (log probability) for each batch entry. Also calculates the gradient."
description: <<END
This class performs the softmax operation for you, so inputs should be
e.g. linear projections of LSTM outputs.
END
}
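A minimal sketch of driving this op through `tf.raw_ops.CTCLossV2` (toy data invented for illustration; assumes a TF 2.x build with a cuDNN-capable GPU, since the only kernel registered for this op below is the GPU one):

import tensorflow as tf

max_time, batch_size, num_classes = 4, 2, 5
logits = tf.random.normal([max_time, batch_size, num_classes])  # blank id is 0
# SparseTensor-style labels: batch 0 owns [1, 2]; batch 1 owns [3].
labels_indices = tf.constant([[0, 0], [0, 1], [1, 0]], dtype=tf.int64)
labels_values = tf.constant([1, 2, 3], dtype=tf.int32)
sequence_length = tf.constant([max_time, max_time], dtype=tf.int32)

loss, gradient = tf.raw_ops.CTCLossV2(
    inputs=logits,
    labels_indices=labels_indices,
    labels_values=labels_values,
    sequence_length=sequence_length)
print(loss.shape)      # (2,): one value per batch entry
print(gradient.shape)  # (4, 2, 5): same shape as inputs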

View File

@@ -2358,7 +2358,11 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
],
] + if_cuda([
":gpu_utils",
":conv_ops_gpu_hdrs",
"@local_config_cuda//cuda:cudnn_header",
]),
)
tf_cc_test(

View File

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {

View File

@@ -15,6 +15,10 @@ limitations under the License.
// See docs in ../ops/ctc_ops.cc.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -25,8 +29,39 @@ limitations under the License.
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/core/util/tensor_format.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;
namespace {
using se::Stream;
using se::StreamExecutor;
using se::dnn::RnnStateTensorDescriptor;
using se::dnn::ToDataType;
// Builds a histogram over the batch dimension of labels_indices: for each
// batch element b, counts the (b, t) rows it owns, i.e. its label length.
template <typename T>
void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
int num_indices, int batch_size,
std::vector<int>* labels_lengths) {
const T* h_in = labels_indices->flat<T>().data();
for (int i = 0; i < num_indices; i++) {
const T& key = h_in[i * 2];  // Column 0 of row i holds the batch index.
(*labels_lengths)[key]++;
}
}
} // end namespace
#endif // GOOGLE_CUDA
template <typename T>
class CTCLossOp : public OpKernel {
typedef Eigen::Map<
@@ -186,4 +221,150 @@ REGISTER_CPU(double);
#undef REGISTER_CPU
#if GOOGLE_CUDA && CUDNN_VERSION >= 7603
class CTCLossOpGPU : public OpKernel {
public:
explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
bool preprocess_collapse_repeated;
bool ctc_merge_repeated;
bool ignore_longer_outputs_than_inputs;
OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
&preprocess_collapse_repeated));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated));
OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
&ignore_longer_outputs_than_inputs));
OP_REQUIRES(ctx, !preprocess_collapse_repeated,
errors::InvalidArgument("GPU CTCLossOp requires "
"preprocess_collapse_repeated to be "
"false"));
OP_REQUIRES(ctx, ctc_merge_repeated,
errors::InvalidArgument("GPU CTCLossOp requires "
"ctc_merge_repeated to be "
"true"));
OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs,
errors::InvalidArgument("GPU CTCLossOp requires "
"ignore_longer_outputs_than_inputs to "
"be false"));
}
void Compute(OpKernelContext* ctx) override {
const Tensor* inputs;
const Tensor* labels_indices;
const Tensor* labels_values;
const Tensor* seq_len;
OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));
OP_REQUIRES(ctx, inputs->shape().dims() == 3,
errors::InvalidArgument("inputs is not a 3-Tensor"));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
errors::InvalidArgument("sequence_length is not a vector"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
errors::InvalidArgument("labels_indices is not a matrix"));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
errors::InvalidArgument("labels_values is not a vector"));
const TensorShape& inputs_shape = inputs->shape();
const int64 max_time_raw = inputs_shape.dim_size(0);
const int64 batch_size_raw = inputs_shape.dim_size(1);
const int64 num_classes_raw = inputs_shape.dim_size(2);
OP_REQUIRES(ctx,
FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("max_time_ cannot exceed max int"));
OP_REQUIRES(
ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("batch_size cannot exceed max int"));
OP_REQUIRES(
ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("num_classes cannot exceed max int"));
const int max_time = static_cast<const int>(max_time_raw);
const int batch_size = static_cast<const int>(batch_size_raw);
const int num_classes = static_cast<const int>(num_classes_raw);
OP_REQUIRES(
ctx, batch_size == seq_len->dim_size(0),
errors::InvalidArgument("len(sequence_length) != batch_size. ",
"len(sequence_length): ", seq_len->dim_size(0),
" batch_size: ", batch_size));
OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
errors::InvalidArgument(
"labels_indices and labels_values must contain the "
"same number of rows, but saw shapes: ",
labels_indices->shape().DebugString(), " vs. ",
labels_values->shape().DebugString()));
auto num_indices = labels_indices->dim_size(0);
OP_REQUIRES(ctx, batch_size != 0,
errors::InvalidArgument("batch_size must not be 0"));
Tensor* loss = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
Tensor* gradient = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("gradient", inputs_shape, &gradient));
// Convert the labels_indices to labels_lengths.
std::vector<int> labels_lengths(batch_size, 0);
DoHistogram<int64>(ctx, labels_indices, num_indices, batch_size,
&labels_lengths);
StreamExecutor* executor = ctx->op_device_context()->stream()->parent();
se::dnn::DataType data_type = ToDataType<float>::value;
auto probs_desc_s = executor->createRnnStateTensorDescriptor(
max_time, batch_size, num_classes, data_type);
OP_REQUIRES_OK(ctx, probs_desc_s.status());
std::unique_ptr<RnnStateTensorDescriptor> probs_desc =
probs_desc_s.ConsumeValueOrDie();
auto grads_desc_s = executor->createRnnStateTensorDescriptor(
max_time, batch_size, num_classes, data_type);
OP_REQUIRES_OK(ctx, grads_desc_s.status());
std::unique_ptr<RnnStateTensorDescriptor> grads_desc =
grads_desc_s.ConsumeValueOrDie();
absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
num_indices);
absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
batch_size);
absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(),
batch_size);
auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);
// Set the memory limitation to 4GB for workspace memory.
DnnScratchAllocator workspace_allocator(1LL << 32, ctx);
Stream* stream = ctx->op_device_context()->stream();
bool cudnn_launch_status =
stream
->ThenCtcLoss(*probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data, &costs_data,
*grads_desc, &grads_data, &workspace_allocator)
.ok();
if (!cudnn_launch_status) {
ctx->SetStatus(errors::Internal("cuDNN CTCLoss launch failure"));
}
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
};
REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
.Device(DEVICE_GPU)
.HostMemory("labels_indices")
.HostMemory("labels_values")
.HostMemory("sequence_length"),
CTCLossOpGPU);
#endif // GOOGLE_CUDA && CUDNN_VERSION >= 7603
} // end namespace tensorflow
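The DoHistogram helper above is the only host-side preprocessing this kernel performs: it turns the [num_indices, 2] labels_indices matrix into the per-batch label lengths that cuDNN expects. An illustrative numpy re-statement (not part of the change):

import numpy as np

def labels_lengths_from_indices(labels_indices, batch_size):
    """Counts how many (batch, time) rows each batch element owns."""
    lengths = np.zeros(batch_size, dtype=np.int32)
    for batch_id in labels_indices[:, 0]:  # column 0 is the batch index
        lengths[batch_id] += 1
    return lengths

# Batch 0 has two labels, batch 1 has one:
print(labels_lengths_from_indices(np.array([[0, 0], [0, 1], [1, 0]]), 2))
# -> [2 1]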

View File

@@ -62,6 +62,43 @@ REGISTER_OP("CTCLoss")
return Status::OK();
});
REGISTER_OP("CTCLossV2")
.Input("inputs: float")
.Input("labels_indices: int64")
.Input("labels_values: int32")
.Input("sequence_length: int32")
.Attr("preprocess_collapse_repeated: bool = false")
.Attr("ctc_merge_repeated: bool = true")
.Attr("ignore_longer_outputs_than_inputs: bool = false")
.Output("loss: float")
.Output("gradient: float")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle labels_indices;
ShapeHandle labels_values;
ShapeHandle sequence_length;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0),
c->Dim(labels_values, 0), &unused));
// Get batch size from inputs and sequence_length, and update inputs
// with the merged batch_size since it is returned.
DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs));
c->set_output(0, c->Vector(batch_size));
c->set_output(1, inputs);
return Status::OK();
});
REGISTER_OP("CTCGreedyDecoder")
.Input("inputs: T")
.Input("sequence_length: int32")

View File

@@ -18,10 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -29,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import ctc_ops
from tensorflow.python.ops import gradients_impl
@@ -839,5 +842,82 @@ class CTCLossTestV2(test.TestCase):
self.assertAllEqual(
[[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
@keras_parameterized.run_all_keras_modes
class CTCLossTestV3(keras_parameterized.TestCase):
@parameterized.parameters([False, True])
@test_util.run_v2_only
def testCtcLossV3(self, run_tf_func):
"""Testing GPU CTC loss.
testing if GPU CTC loss will generate same result with CPU version
"""
if not test.is_gpu_available():
self.skipTest("Need GPU for testing.")
random_seed.set_random_seed(5)
batch_size = 8
num_labels = 6
max_label_length = 5
num_frames = 12
labels = random_ops.random_uniform([batch_size, max_label_length],
minval=1,
maxval=num_labels,
dtype=dtypes.int64)
logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
label_length = random_ops.random_uniform([batch_size],
minval=2,
maxval=max_label_length,
dtype=dtypes.int64)
label_mask = array_ops.sequence_mask(
label_length, maxlen=max_label_length, dtype=label_length.dtype)
labels *= label_mask
logit_length = [num_frames] * batch_size
def ctc_loss_cpu(labels, logits, label_length, logit_length):
with test_util.device(use_gpu=False):
sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
with backprop.GradientTape() as t:
t.watch(logits)
ref_loss = ctc_ops.ctc_loss_v3(
labels=sparse_labels,
logits=logits,
label_length=label_length,
logit_length=logit_length,
blank_index=0)
ref_grad = t.gradient(ref_loss, [logits])
return ref_loss, ref_grad
def ctc_loss_gpu(labels, logits, label_length, logit_length):
with test_util.device(use_gpu=True):
sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
with backprop.GradientTape() as t:
t.watch(logits)
loss = ctc_ops.ctc_loss_v3(
labels=sparse_labels,
logits=logits,
label_length=label_length,
logit_length=logit_length,
blank_index=0)
grad = t.gradient(loss, [logits])
return loss, grad
if run_tf_func:
ctc_loss_cpu = def_function.function(ctc_loss_cpu)
ctc_loss_gpu = def_function.function(ctc_loss_gpu)
ref_loss, ref_grad = ctc_loss_cpu(labels, logits, label_length,
logit_length)
loss, grad = ctc_loss_gpu(labels, logits, label_length, logit_length)
self.assertAllClose(loss, ref_loss, atol=1e-6)
self.assertAllClose(grad, ref_grad, atol=2e-6)
if __name__ == "__main__":
test.main()
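The test above exercises ctc_loss_v3 with SparseTensor labels; the same public API also accepts dense zero-padded labels, which per the docstring below is the only form supported on GPU/TPU. An illustrative sketch (assumes TF 2.x, where tf.nn.ctc_loss resolves to ctc_loss_v3):

import tensorflow as tf

frames, batch, num_labels = 12, 2, 6
logits = tf.random.normal([frames, batch, num_labels])
labels = tf.constant([[1, 2, 3, 0, 0],   # dense labels, zero-padded
                      [4, 5, 0, 0, 0]], dtype=tf.int64)
label_length = tf.constant([3, 2], dtype=tf.int32)
logit_length = tf.constant([frames] * batch, dtype=tf.int32)

loss = tf.nn.ctc_loss(labels=labels, logits=logits,
                      label_length=label_length, logit_length=logit_length,
                      logits_time_major=True, blank_index=0)
print(loss.shape)  # (2,): negative log probability per batch entry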

View File

@@ -18,9 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import uuid
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_eager
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
@@ -42,6 +46,27 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
_DEFUN_API_NAME_ATTRIBUTE = "api_implements"
_DEFUN_DEVICE_ATTRIBUTE = "api_preferred_device"
_CPU_DEVICE_NAME = "CPU"
_GPU_DEVICE_NAME = "GPU"
def _get_context_device_type():
"""Parse the current context and return the device type, eg CPU/GPU."""
current_device = context.context().device_name
if current_device is None:
return None
return device.DeviceSpec.from_string(current_device).device_type
def _generate_defun_backend(unique_api_name, preferred_device, func):
function_attributes = {
_DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
_DEFUN_DEVICE_ATTRIBUTE: preferred_device,
}
return function_eager.defun_with_attributes(
func=func, attributes=function_attributes, autograph=False)
# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
@@ -156,6 +181,31 @@ def ctc_loss(labels,
[Graves et al., 2016](https://dl.acm.org/citation.cfm?id=1143891)
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
"""
return _ctc_loss_impl(
labels,
inputs,
sequence_length,
preprocess_collapse_repeated,
ctc_merge_repeated,
ignore_longer_outputs_than_inputs,
time_major,
logits,
use_cudnn=False)
def _ctc_loss_impl(labels,
inputs=None,
sequence_length=None,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
ignore_longer_outputs_than_inputs=False,
time_major=True,
logits=None,
use_cudnn=False):
# Helper function of ctc_loss with one additional param:
# use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
# index has to be 0.
# The second, third, etc output tensors contain the gradients. We use it in
# _CTCLossGrad() below.
if not isinstance(labels, sparse_tensor.SparseTensor):
@@ -167,7 +217,14 @@ def ctc_loss(labels,
if not time_major:
inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N)
loss, _ = gen_ctc_ops.ctc_loss(
# gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
# blank index to be 0, but v1 views it as the last index.
if use_cudnn:
ctc_loss_func = gen_ctc_ops.ctc_loss_v2
else:
ctc_loss_func = gen_ctc_ops.ctc_loss
loss, _ = ctc_loss_func(
inputs,
labels.indices,
labels.values,
@@ -178,19 +235,8 @@ def ctc_loss(labels,
return loss
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
"""The derivative provided by CTC Loss.
Args:
op: the CTCLoss op.
grad_loss: The backprop for cost.
Returns:
The CTC Loss gradient.
"""
def _CTCLossGradImpl(op, grad_loss, _):
# Outputs are: loss, grad
#
# Currently there is no way to take the second derivative of this op
@@ -207,6 +253,36 @@ def _CTCLossGrad(op, grad_loss, _):
return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
"""The derivative provided by CTC Loss.
Args:
op: the CTCLoss op.
grad_loss: The backprop for cost.
Returns:
The CTC Loss gradient.
"""
return _CTCLossGradImpl(op, grad_loss, _)
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
"""The derivative provided by CTC Loss V2.
Args:
op: the CTCLossV2 op.
grad_loss: The backprop for cost.
Returns:
The CTC Loss V2 gradient.
"""
return _CTCLossGradImpl(op, grad_loss, _)
@tf_export("nn.ctc_greedy_decoder")
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
"""Performs greedy decoding on the logits given in input (best path).
@@ -593,11 +669,48 @@ def _ctc_loss_grad(op, grad_loss, _):
return grad
def _ctc_loss_op_standard(labels, logits, logit_length, logits_time_major,
blank_index):
part_before = logits[:, :, :blank_index]
part_after = logits[:, :, blank_index + 1:]
part_blank = logits[:, :, blank_index:blank_index + 1]
logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
labels = sparse_tensor.SparseTensor(
labels.indices,
array_ops.where(labels.values < blank_index, labels.values,
labels.values - 1), labels.dense_shape)
return _ctc_loss_impl(
labels=labels,
inputs=logits,
sequence_length=logit_length,
time_major=logits_time_major,
use_cudnn=False)
def _ctc_loss_op_cudnn(labels, logits, logit_length, logits_time_major,
blank_index):
part_before = logits[:, :, :blank_index]
part_after = logits[:, :, blank_index + 1:]
part_blank = logits[:, :, blank_index:blank_index + 1]
logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
labels = sparse_tensor.SparseTensor(
labels.indices,
array_ops.where(labels.values < blank_index, labels.values + 1,
labels.values), labels.dense_shape)
return _ctc_loss_impl(
labels=labels,
inputs=logits,
sequence_length=logit_length,
time_major=logits_time_major,
use_cudnn=True)
def _ctc_loss_shape(op):
return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]
@tf_export("nn.ctc_loss", v1=["nn.ctc_loss_v2"])
# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss_v2"])
def ctc_loss_v2(labels,
logits,
label_length,
@@ -691,6 +804,111 @@ def ctc_loss_v2(labels,
name=name)
@tf_export("nn.ctc_loss", v1=[])
def ctc_loss_v3(labels,
logits,
label_length,
logit_length,
logits_time_major=True,
unique=None,
blank_index=None,
name=None):
"""Computes CTC (Connectionist Temporal Classification) loss.
This op implements the CTC loss as presented in (Graves et al., 2016).
Notes:
- Same as the "Classic CTC" in TensorFlow 1.x's tf.compat.v1.nn.ctc_loss
with the setting preprocess_collapse_repeated=False, ctc_merge_repeated=True.
- Labels may be supplied as either a dense, zero-padded tensor with a
vector of label sequence lengths OR as a SparseTensor.
- On TPU and GPU: Only dense padded labels are supported.
- On CPU: Caller may use SparseTensor or dense padded labels but calling with
a SparseTensor will be significantly faster.
- Default blank label is 0 rather than num_classes - 1, unless overridden by
blank_index.
Args:
labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
logits: tensor of shape [frames, batch_size, num_labels], if
logits_time_major == False, shape is [batch_size, frames, num_labels].
label_length: tensor of shape [batch_size]; None if labels is a SparseTensor.
Length of the reference label sequence in labels.
logit_length: tensor of shape [batch_size]. Length of the input sequence in
logits.
logits_time_major: (optional) If True (default), logits is shaped [time,
batch, logits]. If False, shape is [batch, time, logits].
unique: (optional) Unique label indices as computed by
ctc_unique_labels(labels). If supplied, enables a faster, memory-efficient
implementation on TPU.
blank_index: (optional) Set the class index to use for the blank label.
Negative values will start from num_classes, i.e., -1 will reproduce the
ctc_loss behavior of using num_classes - 1 for the blank symbol. There is
some memory/performance overhead to switching from the default of 0, as an
additional shifted copy of the logits may be created.
name: A name for this `Op`. Defaults to "ctc_loss_dense".
Returns:
loss: tensor of shape [batch_size], negative log probabilities.
References:
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data
with Recurrent Neural Networks:
[Graves et al., 2016](https://dl.acm.org/citation.cfm?id=1143891)
([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf))
"""
if isinstance(labels, sparse_tensor.SparseTensor):
if blank_index is None:
raise ValueError(
"blank_index must be given when using SparseTensor labels.")
if blank_index < 0:
blank_index += _get_dim(logits, 2)
params = {
"labels": labels,
"logits": logits,
"logit_length": logit_length,
"logits_time_major": logits_time_major,
"blank_index": blank_index
}
if context.executing_eagerly():
device_type = _get_context_device_type()
can_use_gpu = (
# Either user specified GPU or unspecified but GPU is available.
(device_type == _GPU_DEVICE_NAME or
(device_type is None and context.num_gpus() > 0)))
# Under eager context, check the device placement and prefer the
# GPU implementation when a GPU is available.
if can_use_gpu:
res = _ctc_loss_op_cudnn(**params)
else:
res = _ctc_loss_op_standard(**params)
else:
api_name = "ctc_loss_" + str(uuid.uuid4())
ctc_loss_op_standard = _generate_defun_backend(api_name, _CPU_DEVICE_NAME,
_ctc_loss_op_standard)
ctc_loss_op_cudnn = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
_ctc_loss_op_cudnn)
res = ctc_loss_op_standard(**params)
function_eager.register(ctc_loss_op_cudnn, **params)
return res
if blank_index is None:
blank_index = 0
return ctc_loss_dense(
labels=labels,
logits=logits,
label_length=label_length,
logit_length=logit_length,
logits_time_major=logits_time_major,
unique=unique,
blank_index=blank_index,
name=name)
def ctc_loss_dense(labels,
logits,
label_length,
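Both wrappers above rotate the class axis so each backend sees its preferred blank position: the cuDNN variant moves the blank class to index 0 and shifts label ids below the old blank up by one. A numpy re-statement of that permutation (illustrative only):

import numpy as np

def rotate_blank_to_front(logits, labels_values, blank_index):
    """Mirrors the preprocessing in _ctc_loss_op_cudnn for any blank_index."""
    before = logits[:, :, :blank_index]
    blank = logits[:, :, blank_index:blank_index + 1]
    after = logits[:, :, blank_index + 1:]
    # [before | blank | after] -> [blank | before | after]
    logits = np.concatenate([blank, before, after], axis=2)
    # Class c < blank_index now lives at c + 1; classes above keep their id.
    labels_values = np.where(labels_values < blank_index,
                             labels_values + 1, labels_values)
    return logits, labels_values

When blank_index is already 0, the slices reduce to an identity concatenation, which is why the docstring notes that only a non-default blank may create an extra shifted copy of the logits.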

View File

@@ -408,6 +408,13 @@ struct PersistentRnnPlanDeleter {
CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
}
};
#if CUDNN_VERSION >= 7603
struct CtcLossDescriptorDeleter {
void operator()(cudnnCTCLossDescriptor_t descriptor) const {
CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
}
};
#endif
// RAII wrappers for cuDNN types.
using TensorDescriptor =
@@ -430,6 +437,10 @@ using DropoutDescriptor =
using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
using PersistentRnnPlan =
std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
#if CUDNN_VERSION >= 7603
using CtcLossDescriptor =
std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
#endif
// Factory methods for cuDNN types.
TensorDescriptor CreateTensorDescriptor() {
@@ -479,6 +490,13 @@ RnnDescriptor CreateRnnDescriptor() {
CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
return RnnDescriptor(result);
}
#if CUDNN_VERSION >= 7603
CtcLossDescriptor CreateCtcLossDescriptor() {
cudnnCTCLossDescriptor_t result;
CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
return CtcLossDescriptor(result);
}
#endif
port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
@@ -1193,6 +1211,33 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
};
#if CUDNN_VERSION >= 7603
class CudnnCtcLossDescriptor {
public:
explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
: handle_(CreateCtcLossDescriptor()) {
CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
/*ctcLossDesc=*/handle_.get(),
/*compType=*/data_type,
/*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
/*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
}
cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
private:
CtcLossDescriptor handle_; // Owned
SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
};
#else
// Stub used when building against cuDNN < 7.6.3; the CTC loss path is
// compiled out in that case.
class CudnnCtcLossDescriptor {
public:
CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
};
#endif
namespace {
// Check if the LSTM projection is used. If yes, an additional weight matrix
@@ -1660,6 +1705,7 @@ port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
}
return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
}
#endif
} // namespace
@@ -1973,6 +2019,43 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
return port::Status::OK();
}
port::Status CudnnSupport::DoCtcLossImpl(
Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const CudnnRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
DeviceMemory<uint8> scratch_memory) {
auto cudnn = cudnn_->GetHandle(parent_, stream);
// The RNN state tensor descriptor is reused here: num_layers() carries
// max_time and data_size() carries num_classes.
int kNumTimestamps = probs_desc.num_layers();
int kBatchSize = probs_desc.batch_size();
int kNumLabels = probs_desc.data_size();
int total_size = kNumLabels * kNumTimestamps * kBatchSize;
(void)total_size;  // Intentionally unused; documents the tensor extent.
#if CUDNN_VERSION >= 7603
RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
/*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
/*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
/*labelLengths=*/labels_lengths_data.data(),
/*inputLengths=*/input_lengths_data.data(),
/*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
/*gradients=*/grads_data.opaque(),
/*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
/*ctcLossDesc=*/ctc_loss_desc.handle(),
/*workspace=*/scratch_memory.opaque(),
/*workSpaceSizeInBytes=*/scratch_memory.size()));
#else
return port::Status(port::error::INVALID_ARGUMENT,
"cudnnCTCLoss is not supported when "
"CUDNN_VERSION < 7.6.3");
#endif
return port::Status::OK();
}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
CudnnSupport::createRnnDescriptor(
int num_layers, int hidden_size, int input_size, int cell_size,
@@ -3832,6 +3915,79 @@ bool CudnnSupport::DoFusedConvolve(
/*report_error=*/!output_profile_result);
}
port::Status CudnnSupport::DoPrepareForCtcLoss(
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor& probs_desc,
const dnn::RnnStateTensorDescriptor& grads_desc,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
auto cudnn = cudnn_->GetHandle(parent_, stream);
CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
// Query the workspace size.
size_t workspace_size_in_bytes = 0;
#if CUDNN_VERSION >= 7603
RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
/*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
/*gradientsDesc=*/cudnn_grads_desc.handle(),
/*labels=*/labels_data.data(),
/*labelLengths=*/labels_lengths_data.data(),
/*inputLengths=*/input_lengths_data.data(),
/*algo=*/CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC,
/*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
/*sizeInBytes=*/&workspace_size_in_bytes));
#else
return port::Status(port::error::INVALID_ARGUMENT,
"cudnnGetCTCLossWorkspaceSize is not supported when "
"CUDNN_VERSION < 7.6.3");
#endif
// Allocate the workspace.
if (workspace_size_in_bytes == 0) {
*scratch_memory = DeviceMemory<uint8>();
return port::Status::OK();
}
const auto scratch_or =
scratch_allocator->AllocateBytes(workspace_size_in_bytes);
if (scratch_or.ok()) {
*scratch_memory = scratch_or.ValueOrDie();
return port::Status::OK();
}
return port::InternalError(
"Failed to allocate scratch memory for the CuDNN CTC Loss");
}
port::Status CudnnSupport::DoCtcLoss(
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const dnn::RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
// The current cuDNN CTC Loss only supports the float data type.
if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
return port::Status(port::error::INVALID_ARGUMENT,
"CudnnCtcLossDescriptor is supported only when "
"CUDNN_VERSION >= 7.6.3 and the data type is float");
}
CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data, costs_data,
cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
scratch_memory);
}
bool CudnnSupport::DoTransformTensor(Stream* stream,
const dnn::BatchDescriptor& input_desc,
dnn::DataType input_type,

View File

@@ -33,6 +33,7 @@ class GpuExecutor;
class CudnnRnnDescriptor;
class CudnnRnnSequenceTensorDescriptor;
class CudnnRnnStateTensorDescriptor;
class CudnnCtcLossDescriptor;
// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
@@ -562,6 +563,17 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::ConvolutionDescriptor& convolution_descriptor,
dnn::BatchDescriptor* output_batch_descriptor);
port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const dnn::RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory) override;
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
dnn::DataType input_type,
const DeviceMemoryBase& input_data,
@@ -673,6 +685,15 @@ class CudnnSupport : public dnn::DnnSupport {
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result);
port::Status DoCtcLossImpl(
Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
const CudnnRnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
DeviceMemory<uint8> scratch_memory);
private:
port::Status DoPrepareForConvolution(
dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
@@ -686,6 +707,16 @@ class CudnnSupport : public dnn::DnnSupport {
ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
DeviceMemory<uint8>* scratch_memory) override;
port::Status DoPrepareForCtcLoss(
Stream* stream, dnn::DataType element_type,
const dnn::RnnStateTensorDescriptor& probs_desc,
const dnn::RnnStateTensorDescriptor& grads_desc,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch_memory) override;
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};

View File

@@ -615,5 +615,18 @@ bool DnnSupport::IsStatusOk(const port::Status& status, bool report_error) {
return false;
}
port::Status DnnSupport::DoCtcLoss(Stream* stream, dnn::DataType element_type,
const RnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory) {
return port::UnimplementedError("CtcLoss not implemented");
}
} // namespace dnn
} // namespace stream_executor

View File

@@ -2391,6 +2391,73 @@ class DnnSupport {
return false;
}
template <typename ElementType>
port::Status PrepareForCtcLoss(Stream* stream,
const RnnStateTensorDescriptor& probs_desc,
DeviceMemory<ElementType> probs_data,
const RnnStateTensorDescriptor& grads_desc,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* workspace_allocator,
DeviceMemory<uint8>* scratch_memory) {
return DoPrepareForCtcLoss(stream, ToDataType<ElementType>::value,
probs_desc, grads_desc, labels_data,
labels_lengths_data, input_lengths_data,
workspace_allocator, scratch_memory);
}
// Enqueues a CTC Loss operation onto the stream.
//
// Arguments:
// stream: pointer to the stream where this operation should be enqueued to.
// element_type: data type of the input tensors.
// probs_desc: specifies the shape and the data layout of the input tensor.
// probs_data: the device memory region that contains the input tensor.
// labels_data: the host memory region that contains the labels_values
// tensor.
// labels_lengths_data: the host memory region that contains the
// labels_lengths tensor.
// input_lengths_data: the host memory region that contains the seq_lengths
// tensor.
// costs_data: the device memory region that contains the costs tensor.
// grads_desc: specifies the shape and the data layout of the grads tensor.
// grads_data: the device memory region that contains the grads tensor.
// scratch_memory: the workspace memory, sized and allocated beforehand by
// DoPrepareForCtcLoss through the caller-supplied scratch allocator.
virtual port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
const RnnStateTensorDescriptor& probs_desc,
const DeviceMemoryBase probs_data,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
DeviceMemoryBase costs_data,
const RnnStateTensorDescriptor& grads_desc,
DeviceMemoryBase grads_data,
DeviceMemory<uint8> scratch_memory);
template <typename ElementType>
bool DoCtcLoss(Stream* stream,
const dnn::RnnStateTensorDescriptor& probs_desc,
const DeviceMemory<ElementType>& probs_data,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
DeviceMemory<ElementType>* costs_data,
const dnn::RnnStateTensorDescriptor& grads_desc,
DeviceMemory<ElementType>* grads_data,
DeviceMemory<uint8>* scratch_memory) {
return IsStatusOk(
DoCtcLoss(stream, ToDataType<ElementType>::value, probs_desc,
probs_data, labels_data, labels_lengths_data,
input_lengths_data, *costs_data, grads_desc, *grads_data,
*scratch_memory),
false);
}
// Transforms a tensor into another tensor with a different layout and/or data
// type.
//
@@ -2646,6 +2713,19 @@ class DnnSupport {
return port::Status::OK();
}
virtual port::Status DoPrepareForCtcLoss(
Stream* stream, DataType element_type,
const RnnStateTensorDescriptor& probs_desc,
const RnnStateTensorDescriptor& grads_desc,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
ScratchAllocator* scratch_allocator,
DeviceMemory<uint8>* scratch_memory) {
*scratch_memory = {};
return port::Status::OK();
}
SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
};

View File

@@ -5230,6 +5230,39 @@ Stream &Stream::ThenRnnBackward(
return *this;
}
Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,
ScratchAllocator *workspace_allocator) {
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
DeviceMemory<uint8> scratch_memory;
auto status = dnn->PrepareForCtcLoss(
this, probs_desc, probs_data, grads_desc,
labels_data, labels_lengths_data, input_lengths_data,
workspace_allocator, &scratch_memory)
.ok();
if (status) {
status =
dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data, costs_data,
grads_desc, grads_data, &scratch_memory);
}
if (!status) {
SetError();
}
} else {
SetErrorAndLogNoDnnSupport();
}
}
return *this;
}
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
dnn::DataType input_type,
const DeviceMemoryBase &input_data,

View File

@@ -1912,6 +1912,18 @@ class Stream {
ScratchAllocator *workspace_allocator,
dnn::ProfileResult *output_profile_result);
// Enqueue a CTCLoss operation onto the stream.
// See DnnSupport::DoCtcLoss for more details.
Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
absl::Span<const int> labels_data,
absl::Span<const int> labels_lengths_data,
absl::Span<const int> input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,
ScratchAllocator *workspace_allocator);
// Enqueue onto the stream an operation that transforms a tensor.
// See DnnSupport::DoTransformTensor for more details.
Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,

View File

@@ -636,6 +636,10 @@ tf_module {
name: "CTCLoss"
argspec: "args=[\'inputs\', \'labels_indices\', \'labels_values\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'None\'], "
}
member_method {
name: "CTCLossV2"
argspec: "args=[\'inputs\', \'labels_indices\', \'labels_values\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'None\'], "
}
member_method {
name: "CacheDataset"
argspec: "args=[\'input_dataset\', \'filename\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -636,6 +636,10 @@ tf_module {
name: "CTCLoss"
argspec: "args=[\'inputs\', \'labels_indices\', \'labels_values\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'None\'], "
}
member_method {
name: "CTCLossV2"
argspec: "args=[\'inputs\', \'labels_indices\', \'labels_values\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'None\'], "
}
member_method {
name: "CacheDataset"
argspec: "args=[\'input_dataset\', \'filename\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "