Add changes to support cuDNN CTC loss

Kaixi Hou 2019-07-10 11:28:46 -07:00
parent 1e96bba8bf
commit a98e8ca0fb
12 changed files with 669 additions and 24 deletions

View File

@ -2296,7 +2296,9 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
] + if_cuda([
"//tensorflow/core:stream_executor",
]),
)
tf_cc_test(

View File

@ -15,6 +15,10 @@ limitations under the License.
// See docs in ../ops/ctc_ops.cc.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -25,8 +29,89 @@ limitations under the License.
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;
namespace {
using se::DeviceMemory;
using se::Stream;
using se::StreamExecutor;
using se::ScratchAllocator;
using se::dnn::CtcLossDescriptor;
using se::dnn::RnnStateTensorDescriptor;
using se::dnn::ToDataType;
using se::port::StatusOr;
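// Converts the sparse labels_indices (rows of [batch, time]) into per-batch
// label lengths. For example, with batch_size = 2 and index rows
// {0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}, the result is
// labels_lengths = {3, 2}.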
template <typename T>
void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
int num_indices, int batch_size,
std::vector<int>* labels_lengths) {
const T* h_in = labels_indices->flat<T>().data();
for (int i = 0; i < num_indices; i++) {
T key = h_in[i * 2];
(*labels_lengths)[key]++;
OP_REQUIRES(ctx, (*labels_lengths)[key] < 256,
errors::InvalidArgument("Label lengths must be less than 256 "
"for the GPU implementation"));
}
}
// A helper to allocate temporary scratch memory for the cudnnCTCLoss op. It
// takes ownership of the underlying memory. The expectation is that the
// memory should be alive for the span of the cudnnCTCLoss call itself.
template <typename T>
class CudnnCtcLossAllocatorInTemp : public ScratchAllocator {
public:
~CudnnCtcLossAllocatorInTemp() override = default;
explicit CudnnCtcLossAllocatorInTemp(OpKernelContext* context)
: context_(context) {}
int64 GetMemoryLimitInBytes() override {
return std::numeric_limits<int64>::max();
}
StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
Tensor temporary_memory;
const DataType tf_data_type = DataTypeToEnum<T>::v();
int64 allocate_count =
Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
Status allocation_status(context_->allocate_temp(
tf_data_type, TensorShape({allocate_count}), &temporary_memory));
if (!allocation_status.ok()) {
return allocation_status;
}
// Hold a reference to the allocated tensors until the allocator is
// destroyed.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
return DeviceMemory<uint8>::MakeFromByteSize(
temporary_memory.template flat<T>().data(),
temporary_memory.template flat<T>().size() * sizeof(T));
}
int64 TotalByteSize() const { return total_byte_size_; }
Tensor get_allocated_tensor(int index) const {
return allocated_tensors_[index];
}
private:
int64 total_byte_size_ = 0;
OpKernelContext* context_; // not owned
std::vector<Tensor> allocated_tensors_;
};
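// Typical usage (see CTCLossOpGPU::Compute below): construct one allocator
// per Compute call and pass it to Stream::ThenCtcLoss, which allocates its
// cuDNN workspace through it; the scratch tensors then stay alive until the
// allocator goes out of scope at the end of Compute.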
} // end namespace
#endif // GOOGLE_CUDA
template <typename T>
class CTCLossOp : public OpKernel {
typedef Eigen::Map<
@ -186,4 +271,156 @@ REGISTER_CPU(double);
#undef REGISTER_CPU
#if GOOGLE_CUDA
class CTCLossOpGPU : public OpKernel {
public:
explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
&preprocess_collapse_repeated_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
&ignore_longer_outputs_than_inputs_));
}
void Compute(OpKernelContext* ctx) override {
const Tensor* inputs;
const Tensor* labels_indices;
const Tensor* labels_values;
const Tensor* seq_len;
OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));
OP_REQUIRES(ctx, inputs->shape().dims() == 3,
errors::InvalidArgument("inputs is not a 3-Tensor"));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
errors::InvalidArgument("sequence_length is not a vector"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
errors::InvalidArgument("labels_indices is not a matrix"));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
errors::InvalidArgument("labels_values is not a vector"));
const TensorShape& inputs_shape = inputs->shape();
const int64 max_time_raw = inputs_shape.dim_size(0);
const int64 batch_size_raw = inputs_shape.dim_size(1);
const int64 num_classes_raw = inputs_shape.dim_size(2);
OP_REQUIRES(
ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("num_classes cannot exceed max int"));
const int max_time = static_cast<int>(max_time_raw);
const int batch_size = static_cast<int>(batch_size_raw);
const int num_classes = static_cast<int>(num_classes_raw);
OP_REQUIRES(
ctx, batch_size == seq_len->dim_size(0),
errors::InvalidArgument("len(sequence_length) != batch_size. ",
"len(sequence_length): ", seq_len->dim_size(0),
" batch_size: ", batch_size));
OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
errors::InvalidArgument(
"labels_indices and labels_values must contain the "
"same number of rows, but saw shapes: ",
labels_indices->shape().DebugString(), " vs. ",
labels_values->shape().DebugString()));
auto num_indices = labels_indices->dim_size(0);
OP_REQUIRES(ctx, batch_size != 0,
errors::InvalidArgument("batch_size must not be 0"));
Tensor* loss = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
Tensor* gradient = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output("gradient", inputs_shape, &gradient));
OP_REQUIRES(ctx, !preprocess_collapse_repeated_,
errors::InvalidArgument("GPU CTCLossOp requires "
"preprocess_collapse_repeated to be "
"false"));
OP_REQUIRES(ctx, ctc_merge_repeated_,
errors::InvalidArgument("GPU CTCLossOp requires "
"ctc_merge_repeated to be true"));
OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs_,
errors::InvalidArgument("GPU CTCLossOp requires "
"ignore_longer_outputs_than_inputs to "
"be false"));
// Convert the labels_indices to labels_lengths
std::vector<int> labels_lengths(batch_size, 0);
DoHistogram<int64>(ctx, labels_indices, num_indices, batch_size,
&labels_lengths);
StreamExecutor* executor = ctx->op_device_context()->stream()->parent();
se::dnn::DataType data_type = ToDataType<float>::value;
std::unique_ptr<CtcLossDescriptor> ctc_loss_desc;
std::unique_ptr<RnnStateTensorDescriptor> probs_desc;
std::unique_ptr<RnnStateTensorDescriptor> grads_desc;
auto ctc_loss_desc_s = executor->createCtcLossDescriptor(data_type);
OP_REQUIRES_OK(ctx, ctc_loss_desc_s.status());
ctc_loss_desc = ctc_loss_desc_s.ConsumeValueOrDie();
auto probs_desc_s = executor->createRnnStateTensorDescriptor(
max_time, batch_size, num_classes, data_type);
OP_REQUIRES_OK(ctx, probs_desc_s.status());
probs_desc = probs_desc_s.ConsumeValueOrDie();
auto grads_desc_s = executor->createRnnStateTensorDescriptor(
max_time, batch_size, num_classes, data_type);
OP_REQUIRES_OK(ctx, grads_desc_s.status());
grads_desc = grads_desc_s.ConsumeValueOrDie();
absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
num_indices);
absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
batch_size);
absl::Span<const int32> input_lengths_data(seq_len->flat<int32>().data(),
batch_size);
auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);
CudnnCtcLossAllocatorInTemp<uint8> workspace_allocator(ctx);
Stream* stream = ctx->op_device_context()->stream();
bool cudnn_launch_status =
stream
->ThenCtcLoss(
*probs_desc, probs_data, labels_data, labels_lengths_data,
input_lengths_data, &costs_data, *grads_desc, &grads_data,
*ctc_loss_desc, &workspace_allocator)
.ok();
if (!cudnn_launch_status) {
ctx->SetStatus(
errors::Internal("cuDNN CTCLoss launch failure"));
}
}
private:
bool preprocess_collapse_repeated_;
bool ctc_merge_repeated_;
bool ignore_longer_outputs_than_inputs_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
};
REGISTER_KERNEL_BUILDER(Name("CTCLossV2").Device(DEVICE_GPU)
.HostMemory("labels_indices")
.HostMemory("labels_values")
.HostMemory("sequence_length"),
CTCLossOpGPU);
#endif // GOOGLE_CUDA
} // end namespace tensorflow

View File

@ -62,6 +62,43 @@ REGISTER_OP("CTCLoss")
return Status::OK();
});
REGISTER_OP("CTCLossV2")
.Input("inputs: float")
.Input("labels_indices: int64")
.Input("labels_values: int32")
.Input("sequence_length: int32")
.Attr("preprocess_collapse_repeated: bool = false")
.Attr("ctc_merge_repeated: bool = true")
.Attr("ignore_longer_outputs_than_inputs: bool = false")
.Output("loss: float")
.Output("gradient: float")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle inputs;
ShapeHandle labels_indices;
ShapeHandle labels_values;
ShapeHandle sequence_length;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0),
c->Dim(labels_values, 0), &unused));
// Get batch size from inputs and sequence_length, and update inputs
// with the merged batch_size since it is returned.
DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs));
c->set_output(0, c->Vector(batch_size));
c->set_output(1, inputs);
return Status::OK();
});
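// For CTCLossV2 above: an inputs tensor of shape
// [max_time, batch_size, num_classes] yields a loss of shape [batch_size]
// and a gradient with the same shape as inputs.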
REGISTER_OP("CTCGreedyDecoder")
.Input("inputs: T")
.Input("sequence_length: int32")

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import os
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@ -840,4 +841,5 @@ class CTCLossTestV2(test.TestCase):
[[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
if __name__ == "__main__":
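# TF_CUDNN_CTC_LOSS=1 routes ctc_loss_v2 through the cuDNN-backed CTCLossV2
# op for the whole test file.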
os.environ['TF_CUDNN_CTC_LOSS'] = '1'
test.main()

View File

@ -42,6 +42,7 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
import os
# pylint: disable=protected-access, invalid-name
@tf_export(v1=["nn.ctc_loss"])
@ -155,6 +156,24 @@ def ctc_loss(labels,
Raises:
TypeError: if labels is not a `SparseTensor`.
"""
return _ctc_loss_impl(labels, inputs, sequence_length,
preprocess_collapse_repeated, ctc_merge_repeated,
ignore_longer_outputs_than_inputs, time_major, logits,
use_cudnn=False)
def _ctc_loss_impl(labels,
inputs=None,
sequence_length=None,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
ignore_longer_outputs_than_inputs=False,
time_major=True,
logits=None,
use_cudnn=False):
# Helper function of ctc_loss with one additional param:
# use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
# index has to be 0.
# The second output tensor contains the gradients. We use it in
# _CTCLossGrad() below.
if not isinstance(labels, sparse_tensor.SparseTensor):
@ -166,7 +185,14 @@ def ctc_loss(labels,
if not time_major:
inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N)
# gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
# blank index to be 0, but v1 views it as the last index.
if use_cudnn:
ctc_loss_func = gen_ctc_ops.ctc_loss_v2
else:
ctc_loss_func = gen_ctc_ops.ctc_loss
loss, _ = ctc_loss_func(
inputs,
labels.indices,
labels.values,
@ -177,19 +203,8 @@ def ctc_loss(labels,
return loss
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
"""The derivative provided by CTC Loss.
Args:
op: the CTCLoss op.
grad_loss: The backprop for cost.
Returns:
The CTC Loss gradient.
"""
def _CTCLossGradImpl(op, grad_loss, _):
# Outputs are: loss, grad
#
# Currently there is no way to take the second derivative of this op
@ -205,6 +220,33 @@ def _CTCLossGrad(op, grad_loss, _):
# labels_indices, labels_values and sequence_length
return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLoss")
def _CTCLossGrad(op, grad_loss, _):
"""The derivative provided by CTC Loss.
Args:
op: the CTCLoss op.
grad_loss: The backprop for cost.
Returns:
The CTC Loss gradient.
"""
return _CTCLossGradImpl(op, grad_loss, _)
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
"""The derivative provided by CTC Loss V2.
Args:
op: the CTCLossV2 op.
grad_loss: The backprop for cost.
Returns:
The CTC Loss V2 gradient.
"""
return _CTCLossGradImpl(op, grad_loss, _)
@tf_export("nn.ctc_greedy_decoder")
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
@ -654,26 +696,36 @@ def ctc_loss_v2(labels,
raise ValueError(
"blank_index must be given when using SparseTensor labels.")
use_cudnn = os.environ.get("TF_CUDNN_CTC_LOSS", "0") == "1"
if blank_index < 0:
blank_index += _get_dim(logits, 2)
part_before = logits[:, :, :blank_index]
part_after = logits[:, :, blank_index + 1:]
part_blank = logits[:, :, blank_index:blank_index + 1]
if use_cudnn:
logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
labels = sparse_tensor.SparseTensor(
labels.indices,
array_ops.where(labels.values < blank_index, labels.values + 1,
labels.values), labels.dense_shape)
else:
logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
labels = sparse_tensor.SparseTensor(
labels.indices,
array_ops.where(labels.values < blank_index, labels.values,
labels.values - 1), labels.dense_shape)
return _ctc_loss_impl(
labels=labels,
inputs=logits,
sequence_length=logit_length,
time_major=logits_time_major,
use_cudnn=use_cudnn)
if blank_index is None:
blank_index = 0
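A toy NumPy illustration of the remapping above (not part of this change; shapes and values are made up): with blank as the last of four classes, the cuDNN path moves the blank to class 0 and shifts the smaller label values up by one.

import numpy as np
logits = np.arange(8.0).reshape(1, 2, 4)  # (T=1, B=2, N=4); blank = class 3
cudnn_logits = np.concatenate(
    [logits[:, :, 3:4], logits[:, :, :3]], axis=2)  # blank becomes class 0
labels = np.array([0, 2, 1])
cudnn_labels = np.where(labels < 3, labels + 1, labels)  # -> [1, 3, 2]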

View File

@ -408,6 +408,13 @@ struct PersistentRnnPlanDeleter {
CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
}
};
#if CUDNN_VERSION >= 7601
struct CtcLossDescriptorDeleter {
void operator()(cudnnCTCLossDescriptor_t descriptor) const {
CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
}
};
#endif
// RAII wrappers for cuDNN types.
using TensorDescriptor =
@ -430,6 +437,10 @@ using DropoutDescriptor =
using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
using PersistentRnnPlan =
std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
#if CUDNN_VERSION >= 7601
using CtcLossDescriptor =
std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
#endif
// Factory methods for cuDNN types.
TensorDescriptor CreateTensorDescriptor() {
@ -479,6 +490,13 @@ RnnDescriptor CreateRnnDescriptor() {
CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
return RnnDescriptor(result);
}
#if CUDNN_VERSION >= 7601
CtcLossDescriptor CreateCtcLossDescriptor() {
cudnnCTCLossDescriptor_t result;
CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
return CtcLossDescriptor(result);
}
#endif
port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
@ -1189,6 +1207,53 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
};
class CudnnCtcLossDescriptor : public dnn::CtcLossDescriptor {
CudnnCtcLossDescriptor(gpu::CtcLossDescriptor ctc_loss_desc,
cudnnDataType_t data_type,
cudnnLossNormalizationMode_t norm_mode,
cudnnNanPropagation_t grad_mode)
: ctc_loss_desc_(std::move(ctc_loss_desc)),
data_type_(data_type),
norm_mode_(norm_mode),
grad_mode_(grad_mode) {}
public:
CudnnCtcLossDescriptor(CudnnCtcLossDescriptor&& other) = default;
static port::StatusOr<CudnnCtcLossDescriptor> Create(
cudnnDataType_t data_type,
cudnnLossNormalizationMode_t norm_mode = CUDNN_LOSS_NORMALIZATION_SOFTMAX,
cudnnNanPropagation_t grad_mode = CUDNN_NOT_PROPAGATE_NAN) {
#if CUDNN_VERSION >= 7601
gpu::CtcLossDescriptor ctc_loss_desc = CreateCtcLossDescriptor();
RETURN_IF_CUDNN_ERROR(cudnnSetCTCLossDescriptorEx(
/*ctcLossDesc=*/ctc_loss_desc.get(),
/*compType=*/data_type,
/*normMode=*/norm_mode,
/*gradMode=*/grad_mode));
return CudnnCtcLossDescriptor(std::move(ctc_loss_desc), data_type,
norm_mode, grad_mode);
#else
return port::Status(port::error::INVALID_ARGUMENT,
"cudnnSetCTCLossDescriptorEx is not supported when "
"CUDNN_VERSION < 7.6.3");
#endif
}
cudnnCTCLossDescriptor_t handle() const { return ctc_loss_desc_.get(); }
cudnnDataType_t data_type() const { return data_type_; }
cudnnLossNormalizationMode_t lnorm_mode() const { return norm_mode_; }
cudnnNanPropagation_t grad_mode() const { return grad_mode_; }
private:
gpu::CtcLossDescriptor ctc_loss_desc_;
cudnnDataType_t data_type_;
cudnnLossNormalizationMode_t norm_mode_;
cudnnNanPropagation_t grad_mode_;
SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
};
namespace {
// Check if the LSTM projection is used. If yes, an additional weight matrix
@ -1656,6 +1721,39 @@ port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
}
return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
}
port::StatusOr<DeviceMemory<uint8>> CreateCtcLossWorkspace(
Stream* stream, const CudnnHandle& cudnn,
const CudnnCtcLossDescriptor& ctc_loss_desc,
const CudnnRnnStateTensorDescriptor& probs_desc,
const CudnnRnnStateTensorDescriptor& grads_desc,
const absl::Span<const int32>& labels_data,
const absl::Span<const int32>& labels_lengths_data,
const absl::Span<const int32>& input_lengths_data,
ScratchAllocator* workspace_allocator) {
// Query the workspace size.
size_t workspace_size_in_bytes = 0;
#if CUDNN_VERSION >= 7601
RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
/*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
/*gradientsDesc=*/grads_desc.handle(),
/*labels=*/labels_data.data(),
/*labelLengths=*/labels_lengths_data.data(),
/*inputLengths=*/input_lengths_data.data(),
/*algo=*/CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
/*ctcLossDesc=*/ctc_loss_desc.handle(),
/*sizeInBytes=*/&workspace_size_in_bytes));
#else
return port::Status(port::error::INVALID_ARGUMENT,
"cudnnGetCTCLossWorkspaceSize is not supported when "
"CUDNN_VERSION < 7.6.3");
#endif
// Allocate the workspace.
if (workspace_size_in_bytes == 0) {
return DeviceMemory<uint8>();
}
return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
}
#endif
} // namespace
@ -1969,6 +2067,51 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
return port::Status::OK();
}
port::Status CudnnSupport::DoCtcLossImpl(
Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
const DeviceMemory<float>& probs_data,
const absl::Span<const int32>& labels_data,
const absl::Span<const int32>& labels_lengths_data,
const absl::Span<const int32>& input_lengths_data,
DeviceMemory<float>* costs_data,
const CudnnRnnStateTensorDescriptor& grads_desc,
DeviceMemory<float>* grads_data,
const CudnnCtcLossDescriptor& ctc_loss_desc,
ScratchAllocator* workspace_allocator) {
auto cudnn = cudnn_->GetHandle(parent_, stream);
SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
CreateCtcLossWorkspace(stream, cudnn, ctc_loss_desc,
probs_desc, grads_desc,
labels_data, labels_lengths_data,
input_lengths_data,
workspace_allocator));
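// The RNN state tensor descriptor is reused for CTC loss: num_layers() maps
// to max_time and data_size() maps to num_classes (see the
// createRnnStateTensorDescriptor calls in the CTCLossV2 kernel).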
int kNumTimestamps = probs_desc.num_layers();
int kBatchSize = probs_desc.batch_size();
int kNumLabels = probs_desc.data_size();
int total_size = kNumLabels * kNumTimestamps * kBatchSize;
#if CUDNN_VERSION >= 7601
RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
/*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
/*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
/*labelsLengths=*/labels_lengths_data.data(),
/*inputLengths=*/input_lengths_data.data(),
/*costs=*/costs_data->opaque(), /*gradientsDesc=*/grads_desc.handle(),
/*gradients=*/grads_data->opaque(),
/*algo=*/CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
/*ctcLossDesc=*/ctc_loss_desc.handle(),
/*workspace=*/workspace.opaque(),
/*workSpaceSizeInBytes=*/workspace.size()));
#else
return port::Status(port::error::INVALID_ARGUMENT,
"cudnnCTCLoss is not supported when "
"CUDNN_VERSION < 7.6.3");
#endif
return port::Status::OK();
}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
CudnnSupport::createRnnDescriptor(
int num_layers, int hidden_size, int input_size, int cell_size,
@ -1992,6 +2135,16 @@ CudnnSupport::createRnnDescriptor(
new CudnnRnnDescriptor(std::move(rnn_desc)));
}
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
CudnnSupport::createCtcLossDescriptor(
dnn::DataType data_type) {
SE_ASSIGN_OR_RETURN(CudnnCtcLossDescriptor ctc_loss_desc,
CudnnCtcLossDescriptor::Create(
ToCudnnDataType(data_type)));
return std::unique_ptr<dnn::CtcLossDescriptor>(
new CudnnCtcLossDescriptor(std::move(ctc_loss_desc)));
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
int batch_size, int data_size,
@ -3828,6 +3981,31 @@ bool CudnnSupport::DoFusedConvolve(
/*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoCtcLoss(
Stream* stream, const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
const absl::Span<const int32> &labels_data,
const absl::Span<const int32> &labels_lengths_data,
const absl::Span<const int32> &input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,
const dnn::CtcLossDescriptor &ctc_loss_desc,
ScratchAllocator *workspace_allocator) {
const CudnnCtcLossDescriptor& cudnn_ctc_loss_desc =
static_cast<const CudnnCtcLossDescriptor&>(ctc_loss_desc);
const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
return IsStatusOk(
DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
labels_lengths_data, input_lengths_data, costs_data,
cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
workspace_allocator),
/*report_error=*/true);
}
bool CudnnSupport::DoTransformTensor(Stream* stream,
const dnn::BatchDescriptor& input_desc,
dnn::DataType input_type,

View File

@ -33,6 +33,7 @@ class GpuExecutor;
class CudnnRnnDescriptor;
class CudnnRnnSequenceTensorDescriptor;
class CudnnRnnStateTensorDescriptor;
class CudnnCtcLossDescriptor;
// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
@ -54,6 +55,9 @@ class CudnnSupport : public dnn::DnnSupport {
float dropout, uint64 seed, ScratchAllocator* state_allocator,
bool use_padded_io) override;
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
createCtcLossDescriptor(dnn::DataType data_type) override;
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
int data_size,
@ -562,6 +566,18 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::ConvolutionDescriptor& convolution_descriptor,
dnn::BatchDescriptor* output_batch_descriptor);
bool DoCtcLoss(
Stream* stream, const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
const absl::Span<const int32> &labels_data,
const absl::Span<const int32> &labels_lengths_data,
const absl::Span<const int32> &input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,
const dnn::CtcLossDescriptor &ctc_loss_desc,
ScratchAllocator *workspace_allocator);
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
dnn::DataType input_type,
const DeviceMemoryBase& input_data,
@ -673,6 +689,18 @@ class CudnnSupport : public dnn::DnnSupport {
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result);
port::Status DoCtcLossImpl(
Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
const DeviceMemory<float>& probs_data,
const absl::Span<const int32>& labels_data,
const absl::Span<const int32>& labels_lengths_data,
const absl::Span<const int32>& input_lengths_data,
DeviceMemory<float>* costs_data,
const CudnnRnnStateTensorDescriptor& grads_desc,
DeviceMemory<float>* grads_data,
const CudnnCtcLossDescriptor& ctc_loss_desc,
ScratchAllocator* workspace_allocator);
private:
port::Status DoPrepareForConvolution(
dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,

View File

@ -190,6 +190,15 @@ class RnnDescriptor {
virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
};
// Specifies the CTC Loss computation.
//
// The user is responsible for releasing this descriptor when it is no longer
// in use. The destructor releases the underlying descriptors.
class CtcLossDescriptor {
public:
virtual ~CtcLossDescriptor() {}
};
// Specifies the sequence in a RNN model.
//
// The user is responsible for releasing this descriptor when it is no longer
@ -2133,6 +2142,16 @@ class DnnSupport {
"createRnnDescriptor is unimplemented");
}
// Create a CTC Loss descriptor.
//
// Arguments:
// data_type: an enum to specify the data types used in this model.
virtual port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
createCtcLossDescriptor(dnn::DataType data_type) {
return port::Status(port::error::UNIMPLEMENTED,
"createCtcLossDescriptor is unimplemented");
}
// Create a RNN sequence descriptor that specifies either the input or output
// sequence. The caller retains the ownership of the returned descriptor.
//
@ -2383,6 +2402,40 @@ class DnnSupport {
return false;
}
// Enqueue a CTC Loss operation onto the stream.
//
// Arguments:
// stream: pointer to the stream where this operation should be enqueued to.
// probs_desc: specifies the shape and the data layout of the input tensor.
// probs_data: the device memory region that contains the input tensor.
// labels_data: the host memory region that contains the labels_values
// tensor.
// labels_lengths_data: the host memory region that contains the
// labels_lengths tensor.
// input_lengths_data: the host memory region that contains the seq_lengths
// tensor.
// costs_data: the device memory region that contains the costs tensor.
// grads_desc: specifies the shape and the data layout of the grads tensor.
// grads_data: the device memory region that contains the grads tensor.
// ctc_loss_desc: a CTCLoss descriptor created by createCTCLossDescriptor.
// workspace_allocator: a memory allocator that creates the temporary
// workspace memory used by this operation. The caller is responsible for
// keeping the memory alive for the duration of this operation and for
// recycling it afterwards.
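//
// As used by the CTCLossV2 kernel: probs and grads have shape
// [max_time, batch_size, num_classes], labels_data concatenates the label
// values of all batch members, and labels_lengths_data, input_lengths_data
// and costs_data each hold batch_size entries.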
virtual bool DoCtcLoss(Stream* stream,
const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
const absl::Span<const int32> &labels_data,
const absl::Span<const int32> &labels_lengths_data,
const absl::Span<const int32> &input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,
const dnn::CtcLossDescriptor &ctc_loss_desc,
ScratchAllocator *workspace_allocator) {
return false;
}
// Transforms a tensor into another tensor with a different layout and/or data
// type.
//

View File

@ -5230,6 +5230,33 @@ Stream &Stream::ThenRnnBackward(
return *this;
}
Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
const absl::Span<const int32> &labels_data,
const absl::Span<const int32> &labels_lengths_data,
const absl::Span<const int32> &input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,
const dnn::CtcLossDescriptor &ctc_loss_desc,
ScratchAllocator *workspace_allocator) {
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
auto status = dnn->DoCtcLoss(
this, probs_desc, probs_data, labels_data, labels_lengths_data,
input_lengths_data, costs_data, grads_desc, grads_data, ctc_loss_desc,
workspace_allocator);
if (!status) {
SetError();
}
} else {
SetErrorAndLogNoDnnSupport();
}
}
return *this;
}
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
dnn::DataType input_type,
const DeviceMemoryBase &input_data,

View File

@ -1912,6 +1912,20 @@ class Stream {
ScratchAllocator *workspace_allocator,
dnn::ProfileResult *output_profile_result);
// Enqueue a CTCLoss operation onto the stream.
// See DnnSupport::DoCtcLoss for more details.
Stream &ThenCtcLoss(
const dnn::RnnStateTensorDescriptor &probs_desc,
const DeviceMemory<float> &probs_data,
const absl::Span<const int32> &labels_data,
const absl::Span<const int32> &labels_lengths_data,
const absl::Span<const int32> &input_lengths_data,
DeviceMemory<float> *costs_data,
const dnn::RnnStateTensorDescriptor &grads_desc,
DeviceMemory<float> *grads_data,
const dnn::CtcLossDescriptor &ctc_loss_desc,
ScratchAllocator *workspace_allocator);
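//
// A sketch of a typical call, mirroring the CTCLossV2 kernel (descriptor and
// device-memory setup omitted; names follow that kernel):
//
//   bool ok = stream
//                 ->ThenCtcLoss(*probs_desc, probs_data, labels_data,
//                               labels_lengths_data, input_lengths_data,
//                               &costs_data, *grads_desc, &grads_data,
//                               *ctc_loss_desc, &workspace_allocator)
//                 .ok();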
// Enqueue onto the stream an operation that transforms a tensor.
// See DnnSupport::DoTransformTensor for more details.
Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,

View File

@ -353,6 +353,16 @@ StreamExecutor::createRnnDescriptor(
state_allocator, use_padded_io);
}
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
StreamExecutor::createCtcLossDescriptor(dnn::DataType data_type) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return port::Status(port::error::UNKNOWN,
"Fail to find the dnn implementation.");
}
return dnn_support->createCtcLossDescriptor(data_type);
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
int batch_size, int data_size,

View File

@ -399,6 +399,11 @@ class StreamExecutor {
float dropout, uint64 seed, ScratchAllocator *state_allocator,
bool use_padded_io);
// Create a CTC loss descriptor. The caller retains the ownership of the
// descriptor.
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
createCtcLossDescriptor(dnn::DataType data_type);
// Create a RNN sequence descriptor that specifies either the input or output
// sequence. The caller retains the ownership of the returned descriptor.
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>