Add changes to support cuDNN CTC loss
This commit is contained in:
parent
1e96bba8bf
commit
a98e8ca0fb
@ -2296,7 +2296,9 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
|
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
|
||||||
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
|
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
|
||||||
],
|
] + if_cuda([
|
||||||
|
"//tensorflow/core:stream_executor",
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
|
@ -15,6 +15,10 @@ limitations under the License.
|
|||||||
|
|
||||||
// See docs in ../ops/ctc_ops.cc.
|
// See docs in ../ops/ctc_ops.cc.
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
@ -25,8 +29,89 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
|
#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
|
||||||
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
|
#include "tensorflow/core/util/stream_executor_util.h"
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
using GPUDevice = Eigen::GpuDevice;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using se::DeviceMemory;
|
||||||
|
using se::Stream;
|
||||||
|
using se::StreamExecutor;
|
||||||
|
using se::ScratchAllocator;
|
||||||
|
using se::dnn::CtcLossDescriptor;
|
||||||
|
using se::dnn::RnnStateTensorDescriptor;
|
||||||
|
using se::dnn::ToDataType;
|
||||||
|
using se::port::StatusOr;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
|
||||||
|
int num_indices, int batch_size,
|
||||||
|
std::vector<int> *labels_lengths) {
|
||||||
|
const T* h_in = labels_indices->flat<T>().data();
|
||||||
|
for(int i = 0; i < num_indices; i++) {
|
||||||
|
T key = h_in[i * 2];
|
||||||
|
(*labels_lengths)[key]++;
|
||||||
|
OP_REQUIRES(ctx, (*labels_lengths)[key] < 256,
|
||||||
|
errors::InvalidArgument("Label lengths cannot exceed 256"
|
||||||
|
"for GPU implementation"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A helper to allocate temporary scratch memory for cudnnCTCLoss ops. It
|
||||||
|
// takes the ownership of the underlying memory. The expectation is that the
|
||||||
|
// memory should be alive for the span of the cudnnCTCLoss itself.
|
||||||
|
template <typename T>
|
||||||
|
class CudnnCtcLossAllocatorInTemp : public ScratchAllocator {
|
||||||
|
public:
|
||||||
|
~CudnnCtcLossAllocatorInTemp() override = default;
|
||||||
|
|
||||||
|
explicit CudnnCtcLossAllocatorInTemp(OpKernelContext* context)
|
||||||
|
: context_(context) {}
|
||||||
|
|
||||||
|
int64 GetMemoryLimitInBytes() override {
|
||||||
|
return std::numeric_limits<int64>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
|
||||||
|
Tensor temporary_memory;
|
||||||
|
const DataType tf_data_type = DataTypeToEnum<T>::v();
|
||||||
|
int64 allocate_count =
|
||||||
|
Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
|
||||||
|
Status allocation_status(context_->allocate_temp(
|
||||||
|
tf_data_type, TensorShape({allocate_count}), &temporary_memory));
|
||||||
|
if (!allocation_status.ok()) {
|
||||||
|
return allocation_status;
|
||||||
|
}
|
||||||
|
// Hold the reference of the allocated tensors until the end of the
|
||||||
|
// allocator.
|
||||||
|
allocated_tensors_.push_back(temporary_memory);
|
||||||
|
total_byte_size_ += byte_size;
|
||||||
|
return DeviceMemory<uint8>::MakeFromByteSize(
|
||||||
|
temporary_memory.template flat<T>().data(),
|
||||||
|
temporary_memory.template flat<T>().size() * sizeof(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 TotalByteSize() const { return total_byte_size_; }
|
||||||
|
|
||||||
|
Tensor get_allocated_tensor(int index) const {
|
||||||
|
return allocated_tensors_[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64 total_byte_size_ = 0;
|
||||||
|
OpKernelContext* context_; // not owned
|
||||||
|
std::vector<Tensor> allocated_tensors_;
|
||||||
|
};
|
||||||
|
} // end namespace
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class CTCLossOp : public OpKernel {
|
class CTCLossOp : public OpKernel {
|
||||||
typedef Eigen::Map<
|
typedef Eigen::Map<
|
||||||
@ -186,4 +271,156 @@ REGISTER_CPU(double);
|
|||||||
|
|
||||||
#undef REGISTER_CPU
|
#undef REGISTER_CPU
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
// GPU kernel registered for the "CTCLossV2" op. It validates the inputs on the
// host, derives per-batch label lengths from `labels_indices`, and then
// computes the loss and gradient through Stream::ThenCtcLoss.
//
// `labels_indices`, `labels_values` and `sequence_length` are registered as
// host-memory inputs, so they are read directly on the CPU here.
class CTCLossOpGPU : public OpKernel {
 public:
  explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
                                     &preprocess_collapse_repeated_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
                                     &ignore_longer_outputs_than_inputs_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* inputs;
    const Tensor* labels_indices;
    const Tensor* labels_values;
    const Tensor* seq_len;
    OP_REQUIRES_OK(ctx, ctx->input("inputs", &inputs));
    OP_REQUIRES_OK(ctx, ctx->input("labels_indices", &labels_indices));
    OP_REQUIRES_OK(ctx, ctx->input("labels_values", &labels_values));
    OP_REQUIRES_OK(ctx, ctx->input("sequence_length", &seq_len));

    OP_REQUIRES(ctx, inputs->shape().dims() == 3,
                errors::InvalidArgument("inputs is not a 3-Tensor"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(seq_len->shape()),
                errors::InvalidArgument("sequence_length is not a vector"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(labels_indices->shape()),
                errors::InvalidArgument("labels_indices is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_values->shape()),
                errors::InvalidArgument("labels_values is not a vector"));
    // DoHistogram walks labels_indices with a fixed stride of two elements
    // per row, so any other column count would be misread.
    OP_REQUIRES(ctx, labels_indices->dim_size(1) == 2,
                errors::InvalidArgument(
                    "labels_indices must have exactly 2 columns, but saw "
                    "shape: ",
                    labels_indices->shape().DebugString()));

    // Every dimension is narrowed to int below, so each one is range-checked
    // first.
    const TensorShape& inputs_shape = inputs->shape();
    const int64 max_time_raw = inputs_shape.dim_size(0);
    const int64 batch_size_raw = inputs_shape.dim_size(1);
    const int64 num_classes_raw = inputs_shape.dim_size(2);
    OP_REQUIRES(
        ctx, FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("max_time cannot exceed max int"));
    OP_REQUIRES(
        ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("batch_size cannot exceed max int"));
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("num_classes cannot exceed max int"));
    const int max_time = static_cast<int>(max_time_raw);
    const int batch_size = static_cast<int>(batch_size_raw);
    const int num_classes = static_cast<int>(num_classes_raw);

    OP_REQUIRES(
        ctx, batch_size == seq_len->dim_size(0),
        errors::InvalidArgument("len(sequence_length) != batch_size.  ",
                                "len(sequence_length):  ", seq_len->dim_size(0),
                                " batch_size: ", batch_size));

    OP_REQUIRES(ctx, labels_indices->dim_size(0) == labels_values->dim_size(0),
                errors::InvalidArgument(
                    "labels_indices and labels_values must contain the "
                    "same number of rows, but saw shapes: ",
                    labels_indices->shape().DebugString(), " vs. ",
                    labels_values->shape().DebugString()));
    const int64 num_indices_raw = labels_indices->dim_size(0);
    OP_REQUIRES(
        ctx, FastBoundsCheck(num_indices_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("number of labels cannot exceed max int"));
    const int num_indices = static_cast<int>(num_indices_raw);

    OP_REQUIRES(ctx, batch_size != 0,
                errors::InvalidArgument("batch_size must not be 0"));

    Tensor* loss = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));

    Tensor* gradient = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("gradient", inputs_shape, &gradient));

    // This kernel only accepts the following attribute combination.
    OP_REQUIRES(ctx, !preprocess_collapse_repeated_,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "preprocess_collapse_repeated to be "
                                        "false"));
    OP_REQUIRES(ctx, ctc_merge_repeated_,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "ctc_merge_repeated_ to be "
                                        "true"));
    OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs_,
                errors::InvalidArgument("GPU CTCLossOp requires "
                                        "ignore_longer_outputs_than_inputs_ to "
                                        "be false"));

    // Convert the labels_indices to labels_lengths.
    std::vector<int> labels_lengths(batch_size, 0);
    DoHistogram<int64>(ctx, labels_indices, num_indices, batch_size,
                       &labels_lengths);
    // OP_REQUIRES inside the helper only returns from the helper itself, so a
    // recorded failure has to be propagated here before launching cuDNN.
    if (!ctx->status().ok()) {
      return;
    }

    Stream* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream != nullptr,
                errors::Internal("No GPU stream available."));
    StreamExecutor* executor = stream->parent();
    se::dnn::DataType data_type = ToDataType<float>::value;

    auto ctc_loss_desc_s = executor->createCtcLossDescriptor(data_type);
    OP_REQUIRES_OK(ctx, ctc_loss_desc_s.status());
    std::unique_ptr<CtcLossDescriptor> ctc_loss_desc =
        ctc_loss_desc_s.ConsumeValueOrDie();

    auto probs_desc_s = executor->createRnnStateTensorDescriptor(
        max_time, batch_size, num_classes, data_type);
    OP_REQUIRES_OK(ctx, probs_desc_s.status());
    std::unique_ptr<RnnStateTensorDescriptor> probs_desc =
        probs_desc_s.ConsumeValueOrDie();

    auto grads_desc_s = executor->createRnnStateTensorDescriptor(
        max_time, batch_size, num_classes, data_type);
    OP_REQUIRES_OK(ctx, grads_desc_s.status());
    std::unique_ptr<RnnStateTensorDescriptor> grads_desc =
        grads_desc_s.ConsumeValueOrDie();

    // Host-side views of the label and length arrays.
    const absl::Span<const int32> labels_data(
        labels_values->flat<int32>().data(), num_indices);
    const absl::Span<const int32> labels_lengths_data(labels_lengths.data(),
                                                      batch_size);
    const absl::Span<const int32> input_lengths_data(
        seq_len->flat<int32>().data(), batch_size);

    auto probs_data = StreamExecutorUtil::AsDeviceMemory<float>(*inputs);
    auto costs_data = StreamExecutorUtil::AsDeviceMemory<float>(*loss);
    auto grads_data = StreamExecutorUtil::AsDeviceMemory<float>(*gradient);

    CudnnCtcLossAllocatorInTemp<uint8> workspace_allocator(ctx);

    const bool cudnn_launch_status =
        stream
            ->ThenCtcLoss(*probs_desc, probs_data, labels_data,
                          labels_lengths_data, input_lengths_data, &costs_data,
                          *grads_desc, &grads_data, *ctc_loss_desc,
                          &workspace_allocator)
            .ok();

    if (!cudnn_launch_status) {
      ctx->SetStatus(errors::Internal("cuDNN CTCLoss launch failure"));
    }
  }

 private:
  bool preprocess_collapse_repeated_;
  bool ctc_merge_repeated_;
  bool ignore_longer_outputs_than_inputs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
};

REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("labels_indices")
                            .HostMemory("labels_values")
                            .HostMemory("sequence_length"),
                        CTCLossOpGPU);
#endif  // GOOGLE_CUDA
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -62,6 +62,43 @@ REGISTER_OP("CTCLoss")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("CTCLossV2")
    .Input("inputs: float")
    .Input("labels_indices: int64")
    .Input("labels_values: int32")
    .Input("sequence_length: int32")
    .Attr("preprocess_collapse_repeated: bool = false")
    .Attr("ctc_merge_repeated: bool = true")
    .Attr("ignore_longer_outputs_than_inputs: bool = false")
    .Output("loss: float")
    .Output("gradient: float")
    .SetShapeFn([](InferenceContext* c) {
      // Rank constraints: inputs is rank 3, labels_indices rank 2, and both
      // labels_values and sequence_length are rank 1.
      ShapeHandle inputs;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
      ShapeHandle labels_indices;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices));
      ShapeHandle labels_values;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values));
      ShapeHandle sequence_length;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length));

      // labels_indices and labels_values must agree on their first dimension.
      DimensionHandle merged_label_count;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0),
                                  c->Dim(labels_values, 0),
                                  &merged_label_count));

      // The batch size is merged from inputs and sequence_length; inputs is
      // updated with the merged value because it is returned as the gradient
      // shape.
      DimensionHandle batch_size;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
      TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs));

      c->set_output(0, c->Vector(batch_size));
      c->set_output(1, inputs);
      return Status::OK();
    });
|
||||||
|
|
||||||
REGISTER_OP("CTCGreedyDecoder")
|
REGISTER_OP("CTCGreedyDecoder")
|
||||||
.Input("inputs: T")
|
.Input("inputs: T")
|
||||||
.Input("sequence_length: int32")
|
.Input("sequence_length: int32")
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -840,4 +841,5 @@ class CTCLossTestV2(test.TestCase):
|
|||||||
[[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
|
[[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
os.environ['TF_CUDNN_CTC_LOSS'] = '1'
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -42,6 +42,7 @@ from tensorflow.python.util import deprecation
|
|||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
# pylint: disable=protected-access, invalid-name
|
# pylint: disable=protected-access, invalid-name
|
||||||
@tf_export(v1=["nn.ctc_loss"])
|
@tf_export(v1=["nn.ctc_loss"])
|
||||||
@ -155,6 +156,24 @@ def ctc_loss(labels,
|
|||||||
Raises:
|
Raises:
|
||||||
TypeError: if labels is not a `SparseTensor`.
|
TypeError: if labels is not a `SparseTensor`.
|
||||||
"""
|
"""
|
||||||
|
return _ctc_loss_impl(labels, inputs, sequence_length,
|
||||||
|
preprocess_collapse_repeated, ctc_merge_repeated,
|
||||||
|
ignore_longer_outputs_than_inputs, time_major, logits,
|
||||||
|
use_cudnn=False)
|
||||||
|
|
||||||
|
def _ctc_loss_impl(labels,
|
||||||
|
inputs=None,
|
||||||
|
sequence_length=None,
|
||||||
|
preprocess_collapse_repeated=False,
|
||||||
|
ctc_merge_repeated=True,
|
||||||
|
ignore_longer_outputs_than_inputs=False,
|
||||||
|
time_major=True,
|
||||||
|
logits=None,
|
||||||
|
use_cudnn=False):
|
||||||
|
# Helper function of ctc_loss with one additional param:
|
||||||
|
# use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank
|
||||||
|
# index has to be 0.
|
||||||
|
|
||||||
# The second, third, etc output tensors contain the gradients. We use it in
|
# The second, third, etc output tensors contain the gradients. We use it in
|
||||||
# _CTCLossGrad() below.
|
# _CTCLossGrad() below.
|
||||||
if not isinstance(labels, sparse_tensor.SparseTensor):
|
if not isinstance(labels, sparse_tensor.SparseTensor):
|
||||||
@ -166,7 +185,14 @@ def ctc_loss(labels,
|
|||||||
if not time_major:
|
if not time_major:
|
||||||
inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N)
|
inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N)
|
||||||
|
|
||||||
loss, _ = gen_ctc_ops.ctc_loss(
|
# gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. v2 assumes the
|
||||||
|
# blank index to be 0, but v1 views it as the last index.
|
||||||
|
if use_cudnn:
|
||||||
|
ctc_loss_func = gen_ctc_ops.ctc_loss_v2
|
||||||
|
else:
|
||||||
|
ctc_loss_func = gen_ctc_ops.ctc_loss
|
||||||
|
|
||||||
|
loss, _ = ctc_loss_func(
|
||||||
inputs,
|
inputs,
|
||||||
labels.indices,
|
labels.indices,
|
||||||
labels.values,
|
labels.values,
|
||||||
@ -177,19 +203,8 @@ def ctc_loss(labels,
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
@ops.RegisterGradient("CTCLoss")
|
def _CTCLossGradImpl(op, grad_loss, _):
|
||||||
def _CTCLossGrad(op, grad_loss, _):
|
|
||||||
"""The derivative provided by CTC Loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op: the CTCLoss op.
|
|
||||||
grad_loss: The backprop for cost.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The CTC Loss gradient.
|
|
||||||
"""
|
|
||||||
# Outputs are: loss, grad
|
# Outputs are: loss, grad
|
||||||
#
|
#
|
||||||
# Currently there is no way to take the second derivative of this op
|
# Currently there is no way to take the second derivative of this op
|
||||||
@ -205,7 +220,34 @@ def _CTCLossGrad(op, grad_loss, _):
|
|||||||
# labels_indices, labels_values and sequence_length
|
# labels_indices, labels_values and sequence_length
|
||||||
return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
|
return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
@ops.RegisterGradient("CTCLoss")
|
||||||
|
def _CTCLossGrad(op, grad_loss, _):
|
||||||
|
"""The derivative provided by CTC Loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: the CTCLoss op.
|
||||||
|
grad_loss: The backprop for cost.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The CTC Loss gradient.
|
||||||
|
"""
|
||||||
|
return _CTCLossGradImpl(op, grad_loss, _)
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
@ops.RegisterGradient("CTCLossV2")
def _CTCLossV2Grad(op, grad_loss, _):
  """The derivative provided by CTC Loss V2.

  Registered under its own name so that it does not shadow the module-level
  `_CTCLossGrad` function registered for the "CTCLoss" op.

  Args:
     op: the CTCLossV2 op.
     grad_loss: The backprop for cost.

  Returns:
     The CTC Loss V2 gradient.
  """
  return _CTCLossGradImpl(op, grad_loss, _)
|
||||||
|
|
||||||
@tf_export("nn.ctc_greedy_decoder")
|
@tf_export("nn.ctc_greedy_decoder")
|
||||||
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
|
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
|
||||||
"""Performs greedy decoding on the logits given in input (best path).
|
"""Performs greedy decoding on the logits given in input (best path).
|
||||||
@ -654,26 +696,36 @@ def ctc_loss_v2(labels,
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"blank_index must be given when using SparseTensor labels.")
|
"blank_index must be given when using SparseTensor labels.")
|
||||||
|
|
||||||
|
_ctc_use_cudnn = os.environ.get("TF_CUDNN_CTC_LOSS", "0")
|
||||||
|
if _ctc_use_cudnn == "1":
|
||||||
|
use_cudnn = True
|
||||||
|
else:
|
||||||
|
use_cudnn = False
|
||||||
|
|
||||||
if blank_index < 0:
|
if blank_index < 0:
|
||||||
blank_index += _get_dim(logits, 2)
|
blank_index += _get_dim(logits, 2)
|
||||||
|
|
||||||
if blank_index != _get_dim(logits, 2) - 1:
|
part_before = logits[:, :, :blank_index]
|
||||||
logits = array_ops.concat([
|
part_after = logits[:, :, blank_index + 1:]
|
||||||
logits[:, :, :blank_index],
|
part_blank = logits[:, :, blank_index:blank_index + 1]
|
||||||
logits[:, :, blank_index + 1:],
|
if use_cudnn:
|
||||||
logits[:, :, blank_index:blank_index + 1],
|
logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
|
||||||
],
|
labels = sparse_tensor.SparseTensor(
|
||||||
axis=2)
|
labels.indices,
|
||||||
|
array_ops.where(labels.values < blank_index, labels.values + 1,
|
||||||
|
labels.values), labels.dense_shape)
|
||||||
|
else:
|
||||||
|
logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
|
||||||
labels = sparse_tensor.SparseTensor(
|
labels = sparse_tensor.SparseTensor(
|
||||||
labels.indices,
|
labels.indices,
|
||||||
array_ops.where(labels.values < blank_index, labels.values,
|
array_ops.where(labels.values < blank_index, labels.values,
|
||||||
labels.values - 1), labels.dense_shape)
|
labels.values - 1), labels.dense_shape)
|
||||||
|
return _ctc_loss_impl(
|
||||||
return ctc_loss(
|
|
||||||
labels=labels,
|
labels=labels,
|
||||||
inputs=logits,
|
inputs=logits,
|
||||||
sequence_length=logit_length,
|
sequence_length=logit_length,
|
||||||
time_major=logits_time_major)
|
time_major=logits_time_major,
|
||||||
|
use_cudnn=use_cudnn)
|
||||||
|
|
||||||
if blank_index is None:
|
if blank_index is None:
|
||||||
blank_index = 0
|
blank_index = 0
|
||||||
|
@ -408,6 +408,13 @@ struct PersistentRnnPlanDeleter {
|
|||||||
CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
|
CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#if CUDNN_VERSION >= 7601
// Deleter functor that destroys a cuDNN CTC loss descriptor.
struct CtcLossDescriptorDeleter {
  void operator()(cudnnCTCLossDescriptor_t ctc_loss_descriptor) const {
    CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(ctc_loss_descriptor));
  }
};
#endif
|
||||||
|
|
||||||
// RAII wrappers for cuDNN types.
|
// RAII wrappers for cuDNN types.
|
||||||
using TensorDescriptor =
|
using TensorDescriptor =
|
||||||
@ -430,6 +437,10 @@ using DropoutDescriptor =
|
|||||||
using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
|
using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
|
||||||
using PersistentRnnPlan =
|
using PersistentRnnPlan =
|
||||||
std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
|
std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
|
||||||
|
#if CUDNN_VERSION >= 7601
|
||||||
|
using CtcLossDescriptor =
|
||||||
|
std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Factory methods for cuDNN types.
|
// Factory methods for cuDNN types.
|
||||||
TensorDescriptor CreateTensorDescriptor() {
|
TensorDescriptor CreateTensorDescriptor() {
|
||||||
@ -479,6 +490,13 @@ RnnDescriptor CreateRnnDescriptor() {
|
|||||||
CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
|
CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
|
||||||
return RnnDescriptor(result);
|
return RnnDescriptor(result);
|
||||||
}
|
}
|
||||||
|
#if CUDNN_VERSION >= 7601
// Creates a cuDNN CTC loss descriptor and wraps it in its owning RAII type.
CtcLossDescriptor CreateCtcLossDescriptor() {
  cudnnCTCLossDescriptor_t raw_descriptor;
  CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&raw_descriptor));
  return CtcLossDescriptor(raw_descriptor);
}
#endif
|
||||||
|
|
||||||
port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
|
port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
|
||||||
cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
|
cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
|
||||||
@ -1189,6 +1207,53 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
|||||||
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class CudnnCtcLossDescriptor : public dnn::CtcLossDescriptor {
|
||||||
|
CudnnCtcLossDescriptor(gpu::CtcLossDescriptor ctc_loss_desc,
|
||||||
|
cudnnDataType_t data_type,
|
||||||
|
cudnnLossNormalizationMode_t norm_mode,
|
||||||
|
cudnnNanPropagation_t grad_mode)
|
||||||
|
: ctc_loss_desc_(std::move(ctc_loss_desc)),
|
||||||
|
data_type_(data_type),
|
||||||
|
norm_mode_(norm_mode),
|
||||||
|
grad_mode_(grad_mode){}
|
||||||
|
|
||||||
|
public:
|
||||||
|
CudnnCtcLossDescriptor(CudnnCtcLossDescriptor&& other) = default;
|
||||||
|
|
||||||
|
static port::StatusOr<CudnnCtcLossDescriptor> Create(
|
||||||
|
cudnnDataType_t data_type,
|
||||||
|
cudnnLossNormalizationMode_t norm_mode=CUDNN_LOSS_NORMALIZATION_SOFTMAX,
|
||||||
|
cudnnNanPropagation_t grad_mode=CUDNN_NOT_PROPAGATE_NAN) {
|
||||||
|
gpu::CtcLossDescriptor ctc_loss_desc = CreateCtcLossDescriptor();
|
||||||
|
#if CUDNN_VERSION >= 7601
|
||||||
|
RETURN_IF_CUDNN_ERROR(cudnnSetCTCLossDescriptorEx(
|
||||||
|
/*ctcLossDesc=*/ctc_loss_desc.get(),
|
||||||
|
/*compType=*/data_type,
|
||||||
|
/*normMode=*/norm_mode,
|
||||||
|
/*gradMode=*/grad_mode));
|
||||||
|
#else
|
||||||
|
return port::Status(port::error::INVALID_ARGUMENT,
|
||||||
|
"No supported cudnnSetCTCLossDescriptorEx when "
|
||||||
|
"CUDNN_VERSION < 7.6.3");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return CudnnCtcLossDescriptor(std::move(ctc_loss_desc), data_type,
|
||||||
|
norm_mode, grad_mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
cudnnCTCLossDescriptor_t handle() const { return ctc_loss_desc_.get(); }
|
||||||
|
cudnnDataType_t data_type() const { return data_type_; }
|
||||||
|
cudnnLossNormalizationMode_t lnorm_mode() const { return norm_mode_; }
|
||||||
|
cudnnNanPropagation_t grad_mode() const { return grad_mode_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
gpu::CtcLossDescriptor ctc_loss_desc_;
|
||||||
|
cudnnDataType_t data_type_;
|
||||||
|
cudnnLossNormalizationMode_t norm_mode_;
|
||||||
|
cudnnNanPropagation_t grad_mode_;
|
||||||
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
|
||||||
|
};
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Check if the LSTM projection is used. If yes, an additional weigth matrix
|
// Check if the LSTM projection is used. If yes, an additional weigth matrix
|
||||||
@ -1656,6 +1721,39 @@ port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
|
|||||||
}
|
}
|
||||||
return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
|
return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Queries cuDNN for the workspace size required by the CTC loss computation
// and allocates that many bytes from `workspace_allocator`. Returns an empty
// DeviceMemory when cuDNN reports a size of zero.
port::StatusOr<DeviceMemory<uint8>> CreateCtcLossWorkspace(
    Stream* stream, const CudnnHandle& cudnn,
    const CudnnCtcLossDescriptor& ctc_loss_desc,
    const CudnnRnnStateTensorDescriptor& probs_desc,
    const CudnnRnnStateTensorDescriptor& grads_desc,
    const absl::Span<const int32>& labels_data,
    const absl::Span<const int32>& labels_lengths_data,
    const absl::Span<const int32>& input_lengths_data,
    ScratchAllocator* workspace_allocator) {
  // Ask cuDNN how much scratch space is needed.
  size_t required_bytes = 0;
#if CUDNN_VERSION >= 7601
  RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
      /*handle=*/cudnn.handle(),
      /*probsDesc=*/probs_desc.handle(),
      /*gradientsDesc=*/grads_desc.handle(),
      /*labels=*/labels_data.data(),
      /*labelLengths=*/labels_lengths_data.data(),
      /*inputLengths=*/input_lengths_data.data(),
      /*algo=*/CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
      /*ctcLossDesc=*/ctc_loss_desc.handle(),
      /*sizeInBytes=*/&required_bytes));
#else
  return port::Status(port::error::INVALID_ARGUMENT,
                      "No supported cudnnGetCTCLossWorkspaceSize when "
                      "CUDNN_VERSION < 7.6.3");
#endif
  // Only go through the allocator when there is something to allocate.
  if (required_bytes != 0) {
    return workspace_allocator->AllocateBytes(required_bytes);
  }
  return DeviceMemory<uint8>();
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -1969,6 +2067,51 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
|
|||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs the cuDNN CTC loss computation on `stream`.
//
// A workspace is first obtained through CreateCtcLossWorkspace; cudnnCTCLoss
// is then called with the deterministic algorithm, writing the costs into
// `costs_data` and the gradients into `grads_data`. Returns an error status
// if the workspace cannot be created, if cuDNN reports a failure, or if the
// build's cuDNN version does not satisfy the preprocessor guard below.
port::Status CudnnSupport::DoCtcLossImpl(
    Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
    const DeviceMemory<float>& probs_data,
    const absl::Span<const int32>& labels_data,
    const absl::Span<const int32>& labels_lengths_data,
    const absl::Span<const int32>& input_lengths_data,
    DeviceMemory<float>* costs_data,
    const CudnnRnnStateTensorDescriptor& grads_desc,
    DeviceMemory<float>* grads_data,
    const CudnnCtcLossDescriptor& ctc_loss_desc,
    ScratchAllocator* workspace_allocator) {
  auto cudnn = cudnn_->GetHandle(parent_, stream);

  SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
                      CreateCtcLossWorkspace(stream, cudnn, ctc_loss_desc,
                                             probs_desc, grads_desc,
                                             labels_data, labels_lengths_data,
                                             input_lengths_data,
                                             workspace_allocator));

#if CUDNN_VERSION >= 7601
  RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
      /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
      /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
      /*labelsLengths=*/labels_lengths_data.data(),
      /*inputLengths=*/input_lengths_data.data(),
      /*costs=*/costs_data->opaque(), /*gradientsDesc=*/grads_desc.handle(),
      /*gradients=*/grads_data->opaque(),
      /*algo=*/CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
      /*ctcLossDesc=*/ctc_loss_desc.handle(),
      /*workspace=*/workspace.opaque(),
      /*workSpaceSizeInBytes=*/workspace.size()));
#else
  return port::Status(port::error::INVALID_ARGUMENT,
                      "No supported cudnnCTCLoss when "
                      "CUDNN_VERSION < 7.6.3");
#endif

  return port::Status::OK();
}
|
||||||
|
|
||||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
||||||
CudnnSupport::createRnnDescriptor(
|
CudnnSupport::createRnnDescriptor(
|
||||||
int num_layers, int hidden_size, int input_size, int cell_size,
|
int num_layers, int hidden_size, int input_size, int cell_size,
|
||||||
@ -1992,6 +2135,16 @@ CudnnSupport::createRnnDescriptor(
|
|||||||
new CudnnRnnDescriptor(std::move(rnn_desc)));
|
new CudnnRnnDescriptor(std::move(rnn_desc)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a cuDNN CTC loss descriptor for `data_type` and hands it back behind
// the backend-agnostic dnn::CtcLossDescriptor interface. If
// CudnnCtcLossDescriptor::Create fails, its status is returned to the caller.
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
CudnnSupport::createCtcLossDescriptor(dnn::DataType data_type) {
  const auto cudnn_data_type = ToCudnnDataType(data_type);
  SE_ASSIGN_OR_RETURN(CudnnCtcLossDescriptor descriptor,
                      CudnnCtcLossDescriptor::Create(cudnn_data_type));
  // Move the stack descriptor into a heap-allocated one owned by the caller.
  return std::unique_ptr<dnn::CtcLossDescriptor>(
      new CudnnCtcLossDescriptor(std::move(descriptor)));
}
|
||||||
|
|
||||||
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||||
CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
|
CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
|
||||||
int batch_size, int data_size,
|
int batch_size, int data_size,
|
||||||
@ -3828,6 +3981,31 @@ bool CudnnSupport::DoFusedConvolve(
|
|||||||
/*report_error=*/!output_profile_result);
|
/*report_error=*/!output_profile_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs the CTC loss through cuDNN on `stream`.
//
// The backend-agnostic descriptors are downcast to their cuDNN counterparts
// (the casts are unchecked, so callers must pass descriptors created by this
// backend) and the work is delegated to DoCtcLossImpl. The resulting status is
// converted to a bool by IsStatusOk with error reporting enabled.
bool CudnnSupport::DoCtcLoss(
    Stream* stream, const dnn::RnnStateTensorDescriptor& probs_desc,
    const DeviceMemory<float>& probs_data,
    const absl::Span<const int32>& labels_data,
    const absl::Span<const int32>& labels_lengths_data,
    const absl::Span<const int32>& input_lengths_data,
    DeviceMemory<float>* costs_data,
    const dnn::RnnStateTensorDescriptor& grads_desc,
    DeviceMemory<float>* grads_data,
    const dnn::CtcLossDescriptor& ctc_loss_desc,
    ScratchAllocator* workspace_allocator) {
  const auto& probs =
      static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
  const auto& grads =
      static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
  const auto& ctc_loss =
      static_cast<const CudnnCtcLossDescriptor&>(ctc_loss_desc);

  const port::Status status = DoCtcLossImpl(
      stream, probs, probs_data, labels_data, labels_lengths_data,
      input_lengths_data, costs_data, grads, grads_data, ctc_loss,
      workspace_allocator);
  return IsStatusOk(status, /*report_error=*/true);
}
|
||||||
|
|
||||||
bool CudnnSupport::DoTransformTensor(Stream* stream,
|
bool CudnnSupport::DoTransformTensor(Stream* stream,
|
||||||
const dnn::BatchDescriptor& input_desc,
|
const dnn::BatchDescriptor& input_desc,
|
||||||
dnn::DataType input_type,
|
dnn::DataType input_type,
|
||||||
|
@ -33,6 +33,7 @@ class GpuExecutor;
|
|||||||
class CudnnRnnDescriptor;
|
class CudnnRnnDescriptor;
|
||||||
class CudnnRnnSequenceTensorDescriptor;
|
class CudnnRnnSequenceTensorDescriptor;
|
||||||
class CudnnRnnStateTensorDescriptor;
|
class CudnnRnnStateTensorDescriptor;
|
||||||
|
class CudnnCtcLossDescriptor;
|
||||||
|
|
||||||
// Opaque and unique identifier for the cuDNN plugin.
|
// Opaque and unique identifier for the cuDNN plugin.
|
||||||
extern const PluginId kCuDnnPlugin;
|
extern const PluginId kCuDnnPlugin;
|
||||||
@ -54,6 +55,9 @@ class CudnnSupport : public dnn::DnnSupport {
|
|||||||
float dropout, uint64 seed, ScratchAllocator* state_allocator,
|
float dropout, uint64 seed, ScratchAllocator* state_allocator,
|
||||||
bool use_padded_io) override;
|
bool use_padded_io) override;
|
||||||
|
|
||||||
|
// cuDNN override of DnnSupport::createCtcLossDescriptor: returns a descriptor
// backed by CudnnCtcLossDescriptor, or the error status from creating it.
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
|
||||||
|
createCtcLossDescriptor(dnn::DataType data_type) override;
|
||||||
|
|
||||||
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||||
createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
|
createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
|
||||||
int data_size,
|
int data_size,
|
||||||
@ -562,6 +566,18 @@ class CudnnSupport : public dnn::DnnSupport {
|
|||||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||||
dnn::BatchDescriptor* output_batch_descriptor);
|
dnn::BatchDescriptor* output_batch_descriptor);
|
||||||
|
|
||||||
|
// Enqueues a CTC loss computation on `stream` using cuDNN; see
// DnnSupport::DoCtcLoss for the meaning of each argument. Returns true on
// success. Declared `override` so that any drift from the base-class virtual's
// signature is a compile error rather than a silently separate overload.
bool DoCtcLoss(Stream* stream,
               const dnn::RnnStateTensorDescriptor& probs_desc,
               const DeviceMemory<float>& probs_data,
               const absl::Span<const int32>& labels_data,
               const absl::Span<const int32>& labels_lengths_data,
               const absl::Span<const int32>& input_lengths_data,
               DeviceMemory<float>* costs_data,
               const dnn::RnnStateTensorDescriptor& grads_desc,
               DeviceMemory<float>* grads_data,
               const dnn::CtcLossDescriptor& ctc_loss_desc,
               ScratchAllocator* workspace_allocator) override;
|
||||||
|
|
||||||
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
|
bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
|
||||||
dnn::DataType input_type,
|
dnn::DataType input_type,
|
||||||
const DeviceMemoryBase& input_data,
|
const DeviceMemoryBase& input_data,
|
||||||
@ -673,6 +689,18 @@ class CudnnSupport : public dnn::DnnSupport {
|
|||||||
ScratchAllocator* workspace_allocator,
|
ScratchAllocator* workspace_allocator,
|
||||||
dnn::ProfileResult* output_profile_result);
|
dnn::ProfileResult* output_profile_result);
|
||||||
|
|
||||||
|
// Status-returning implementation behind DoCtcLoss. Takes the descriptors
// already downcast to their cuDNN types and reports failure as a port::Status
// rather than a bool; DoCtcLoss converts that status with IsStatusOk.
port::Status DoCtcLossImpl(
|
||||||
|
Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
|
||||||
|
const DeviceMemory<float>& probs_data,
|
||||||
|
const absl::Span<const int32>& labels_data,
|
||||||
|
const absl::Span<const int32>& labels_lengths_data,
|
||||||
|
const absl::Span<const int32>& input_lengths_data,
|
||||||
|
DeviceMemory<float>* costs_data,
|
||||||
|
const CudnnRnnStateTensorDescriptor& grads_desc,
|
||||||
|
DeviceMemory<float>* grads_data,
|
||||||
|
const CudnnCtcLossDescriptor& ctc_loss_desc,
|
||||||
|
ScratchAllocator* workspace_allocator);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
port::Status DoPrepareForConvolution(
|
port::Status DoPrepareForConvolution(
|
||||||
dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
|
dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
|
||||||
|
@ -190,6 +190,15 @@ class RnnDescriptor {
|
|||||||
virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
|
virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Describes a CTC Loss computation.
//
// This base type carries no state of its own; it only provides a virtual
// destructor so that backend-specific descriptors can be destroyed through a
// pointer to it. The user is responsible for releasing the descriptor when it
// is no longer in use; the destructor releases the underlying descriptors.
class CtcLossDescriptor {
 public:
  virtual ~CtcLossDescriptor() = default;
};
|
||||||
|
|
||||||
// Specifies the sequence in a RNN model.
|
// Specifies the sequence in a RNN model.
|
||||||
//
|
//
|
||||||
// The user is responsible for releasing this descriptor when it is no longer
|
// The user is responsible for releasing this descriptor when it is no longer
|
||||||
@ -2133,6 +2142,16 @@ class DnnSupport {
|
|||||||
"createRnnDescriptor is unimplemented");
|
"createRnnDescriptor is unimplemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create an CTC Loss descriptor.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
// data_type: an enum to specify the data types used in this model.
|
||||||
|
virtual port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
|
||||||
|
createCtcLossDescriptor(dnn::DataType data_type) {
|
||||||
|
return port::Status(port::error::UNIMPLEMENTED,
|
||||||
|
"createCtcLossDescriptor is unimplemented");
|
||||||
|
}
|
||||||
|
|
||||||
// Create a RNN sequence descriptor that specifies either the input or output
|
// Create a RNN sequence descriptor that specifies either the input or output
|
||||||
// sequence. The caller retains the ownership of the returned descriptor.
|
// sequence. The caller retains the ownership of the returned descriptor.
|
||||||
//
|
//
|
||||||
@ -2383,6 +2402,40 @@ class DnnSupport {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enqueue a CTC Loss operation onto the stream.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
// stream: pointer to the stream where this operation should be enqueued to.
|
||||||
|
// probs_desc: specifies the shape and the data layout of the input tensor.
|
||||||
|
// probs_data: the device memory region that contains the input tensor.
|
||||||
|
// labels_data: the device memory region that contains the labels_value
|
||||||
|
// tensor.
|
||||||
|
// labels_lengths_data: the device memory region that contains the
|
||||||
|
// labels_lengths tensor
|
||||||
|
// input_lengths_data: the device memory region that contains the seq_lengths
|
||||||
|
// tensor
|
||||||
|
// costs_data: the device memory region that contains the costs tensor.
|
||||||
|
// grads_desc: specifies the shape and the data layout of the grads tensor.
|
||||||
|
// grads_data: the device memory region that contains the grads tensor.
|
||||||
|
// ctc_loss_desc: a CTCLoss descriptor created by createCTCLossDescriptor.
|
||||||
|
// workspace_allocator: a memory allocator that creates the temporary
|
||||||
|
// workspace memory used by this operation. The caller is responsible for
|
||||||
|
// keeping the memory alive long enough for this operation, and recylces
|
||||||
|
// afterwards.
|
||||||
|
virtual bool DoCtcLoss(Stream* stream,
|
||||||
|
const dnn::RnnStateTensorDescriptor &probs_desc,
|
||||||
|
const DeviceMemory<float> &probs_data,
|
||||||
|
const absl::Span<const int32> &labels_data,
|
||||||
|
const absl::Span<const int32> &labels_lengths_data,
|
||||||
|
const absl::Span<const int32> &input_lengths_data,
|
||||||
|
DeviceMemory<float> *costs_data,
|
||||||
|
const dnn::RnnStateTensorDescriptor &grads_desc,
|
||||||
|
DeviceMemory<float> *grads_data,
|
||||||
|
const dnn::CtcLossDescriptor &ctc_loss_desc,
|
||||||
|
ScratchAllocator *workspace_allocator) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Transforms a tensor into another tensor with a different layout and/or data
|
// Transforms a tensor into another tensor with a different layout and/or data
|
||||||
// type.
|
// type.
|
||||||
//
|
//
|
||||||
|
@ -5230,6 +5230,33 @@ Stream &Stream::ThenRnnBackward(
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enqueues a CTC loss computation on this stream via the platform's DNN
// support; see DnnSupport::DoCtcLoss for the meaning of each argument.
// Always returns *this so that calls can be chained.
Stream& Stream::ThenCtcLoss(
    const dnn::RnnStateTensorDescriptor& probs_desc,
    const DeviceMemory<float>& probs_data,
    const absl::Span<const int32>& labels_data,
    const absl::Span<const int32>& labels_lengths_data,
    const absl::Span<const int32>& input_lengths_data,
    DeviceMemory<float>* costs_data,
    const dnn::RnnStateTensorDescriptor& grads_desc,
    DeviceMemory<float>* grads_data,
    const dnn::CtcLossDescriptor& ctc_loss_desc,
    ScratchAllocator* workspace_allocator) {
  // Nothing is enqueued when ok() is false.
  if (!ok()) {
    return *this;
  }

  dnn::DnnSupport* dnn = parent_->AsDnn();
  if (dnn == nullptr) {
    SetErrorAndLogNoDnnSupport();
    return *this;
  }

  const bool enqueued = dnn->DoCtcLoss(
      this, probs_desc, probs_data, labels_data, labels_lengths_data,
      input_lengths_data, costs_data, grads_desc, grads_data, ctc_loss_desc,
      workspace_allocator);
  if (!enqueued) {
    SetError();
  }
  return *this;
}
|
||||||
|
|
||||||
|
|
||||||
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
|
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
|
||||||
dnn::DataType input_type,
|
dnn::DataType input_type,
|
||||||
const DeviceMemoryBase &input_data,
|
const DeviceMemoryBase &input_data,
|
||||||
|
@ -1912,6 +1912,20 @@ class Stream {
|
|||||||
ScratchAllocator *workspace_allocator,
|
ScratchAllocator *workspace_allocator,
|
||||||
dnn::ProfileResult *output_profile_result);
|
dnn::ProfileResult *output_profile_result);
|
||||||
|
|
||||||
|
// Enqueue a CTCLoss operation onto the stream.
|
||||||
|
// See DnnSupport::DoCtcLoss for more details.
// Returns *this for chaining. If the platform has no DNN support, or the DNN
// implementation reports failure, an error is set on the stream.
|
||||||
|
Stream &ThenCtcLoss(
|
||||||
|
const dnn::RnnStateTensorDescriptor &probs_desc,
|
||||||
|
const DeviceMemory<float> &probs_data,
|
||||||
|
const absl::Span<const int32> &labels_data,
|
||||||
|
const absl::Span<const int32> &labels_lengths_data,
|
||||||
|
const absl::Span<const int32> &input_lengths_data,
|
||||||
|
DeviceMemory<float> *costs_data,
|
||||||
|
const dnn::RnnStateTensorDescriptor &grads_desc,
|
||||||
|
DeviceMemory<float> *grads_data,
|
||||||
|
const dnn::CtcLossDescriptor &ctc_loss_desc,
|
||||||
|
ScratchAllocator *workspace_allocator);
|
||||||
|
|
||||||
// Enqueue onto the stream a operation that transforms a tensor.
|
// Enqueue onto the stream a operation that transforms a tensor.
|
||||||
// See DnnSupport::DoTransformTensor for more details.
|
// See DnnSupport::DoTransformTensor for more details.
|
||||||
Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
|
Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
|
||||||
|
@ -353,6 +353,16 @@ StreamExecutor::createRnnDescriptor(
|
|||||||
state_allocator, use_padded_io);
|
state_allocator, use_padded_io);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Forwards CTC loss descriptor creation to the platform's DNN support.
// Returns an UNKNOWN error status when AsDnn() yields no DNN implementation;
// otherwise returns whatever the implementation produces for `data_type`.
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
StreamExecutor::createCtcLossDescriptor(dnn::DataType data_type) {
  if (dnn::DnnSupport* dnn_support = AsDnn()) {
    return dnn_support->createCtcLossDescriptor(data_type);
  }
  return port::Status(port::error::UNKNOWN,
                      "Fail to find the dnn implementation.");
}
|
||||||
|
|
||||||
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||||
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
|
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
|
||||||
int batch_size, int data_size,
|
int batch_size, int data_size,
|
||||||
|
@ -399,6 +399,11 @@ class StreamExecutor {
|
|||||||
float dropout, uint64 seed, ScratchAllocator *state_allocator,
|
float dropout, uint64 seed, ScratchAllocator *state_allocator,
|
||||||
bool use_padded_io);
|
bool use_padded_io);
|
||||||
|
|
||||||
|
// Create a CTC loss descriptor. The caller retains the ownership of the
|
||||||
|
// descriptor.
// Returns an error status if no DNN implementation is available.
|
||||||
|
port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
|
||||||
|
createCtcLossDescriptor(dnn::DataType data_type);
|
||||||
|
|
||||||
// Create a RNN sequence descriptor that specifies either the input or output
|
// Create a RNN sequence descriptor that specifies either the input or output
|
||||||
// sequence. The caller retains the ownership of the returned descriptor.
|
// sequence. The caller retains the ownership of the returned descriptor.
|
||||||
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user