diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index fd0d60103e8..8c634df061a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2296,7 +2296,9 @@ tf_kernel_library(
         "//tensorflow/core:lib",
         "//tensorflow/core/util/ctc:ctc_beam_search_lib",
         "//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
-    ],
+    ] + if_cuda([
+        "//tensorflow/core:stream_executor",
+    ]),
 )
 
 tf_cc_test(
diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc
index 995d28a158c..b1379cfb9e8 100644
--- a/tensorflow/core/kernels/ctc_loss_op.cc
+++ b/tensorflow/core/kernels/ctc_loss_op.cc
@@ -15,6 +15,10 @@ limitations under the License.
 
 // See docs in ../ops/ctc_ops.cc.
 
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA
+
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -25,8 +29,89 @@ limitations under the License.
 #include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
 #include "tensorflow/core/util/sparse/sparse_tensor.h"
 
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 
+typedef Eigen::ThreadPoolDevice CPUDevice;
+#if GOOGLE_CUDA
+using GPUDevice = Eigen::GpuDevice;
+
+namespace {
+using se::DeviceMemory;
+using se::Stream;
+using se::StreamExecutor;
+using se::ScratchAllocator;
+using se::dnn::CtcLossDescriptor;
+using se::dnn::RnnStateTensorDescriptor;
+using se::dnn::ToDataType;
+using se::port::StatusOr;
+
+// Computes each batch entry's label length: labels_indices is a COO matrix
+// whose rows are (batch, position) pairs, so bucketing the rows by their
+// first coordinate yields the per-sequence label counts.
+template <typename T>
+void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
+                 int num_indices, int batch_size,
+                 std::vector<int>* labels_lengths) {
+  const T* h_in = labels_indices->flat<T>().data();
+  for (int i = 0; i < num_indices; i++) {
+    T key = h_in[i * 2];
+    (*labels_lengths)[key]++;
+    OP_REQUIRES(ctx, (*labels_lengths)[key] < 256,
+                errors::InvalidArgument("Label lengths cannot exceed 256 "
+                                        "for the GPU implementation"));
+  }
+}
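+
+// For illustration (hypothetical values, not part of this change): with
+// labels_indices = [[0, 0], [0, 1], [1, 0]] and batch_size = 2, the first
+// coordinates {0, 0, 1} bucket into labels_lengths = {2, 1}.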
+
+// A helper to allocate temporary scratch memory for cudnnCTCLoss ops. It
+// takes ownership of the underlying memory. The expectation is that the
+// memory should be alive for the span of the cudnnCTCLoss itself.
+template <typename T>
+class CudnnCtcLossAllocatorInTemp : public ScratchAllocator {
+ public:
+  ~CudnnCtcLossAllocatorInTemp() override = default;
+
+  explicit CudnnCtcLossAllocatorInTemp(OpKernelContext* context)
+      : context_(context) {}
+
+  int64 GetMemoryLimitInBytes() override {
+    return std::numeric_limits<int64>::max();
+  }
+
+  StatusOr<DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override {
+    Tensor temporary_memory;
+    const DataType tf_data_type = DataTypeToEnum<T>::v();
+    int64 allocate_count =
+        Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
+    Status allocation_status(context_->allocate_temp(
+        tf_data_type, TensorShape({allocate_count}), &temporary_memory));
+    if (!allocation_status.ok()) {
+      return allocation_status;
+    }
+    // Hold the reference of the allocated tensors until the end of the
+    // allocator.
+    allocated_tensors_.push_back(temporary_memory);
+    total_byte_size_ += byte_size;
+    return DeviceMemory<uint8>::MakeFromByteSize(
+        temporary_memory.template flat<T>().data(),
+        temporary_memory.template flat<T>().size() * sizeof(T));
+  }
+
+  int64 TotalByteSize() const { return total_byte_size_; }
+
+  Tensor get_allocated_tensor(int index) const {
+    return allocated_tensors_[index];
+  }
+
+ private:
+  int64 total_byte_size_ = 0;
+  OpKernelContext* context_;  // not owned
+  std::vector<Tensor> allocated_tensors_;
+};
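+
+// Usage note (illustrative): CTCLossOpGPU below creates one
+// CudnnCtcLossAllocatorInTemp<uint8> per Compute() call and hands it to
+// Stream::ThenCtcLoss, so the scratch tensors live exactly as long as the
+// enqueued cuDNN op needs them.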
", + labels_values->shape().DebugString())); + auto num_indices = labels_indices->dim_size(0); + + OP_REQUIRES(ctx, batch_size != 0, + errors::InvalidArgument("batch_size must not be 0")); + + + Tensor* loss = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss)); + + Tensor* gradient = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("gradient", inputs_shape, &gradient)); + + OP_REQUIRES(ctx, preprocess_collapse_repeated_ == false, + errors::InvalidArgument("GPU CTCLossOp requires " + "preprocess_collapse_repeated to be " + "false")); + OP_REQUIRES(ctx, ctc_merge_repeated_ == true, + errors::InvalidArgument("GPU CTCLossOp requires " + "ctc_merge_repeated_ to be " + "true")); + OP_REQUIRES(ctx, ignore_longer_outputs_than_inputs_ == false, + errors::InvalidArgument("GPU CTCLossOp requires " + "ignore_longer_outputs_than_inputs_ to" + "be false")); + + // Convert the labels_indices to labels_lengths + std::vector labels_lengths(batch_size, 0); + DoHistogram(ctx, labels_indices, num_indices, batch_size, + &labels_lengths); + + StreamExecutor* executor = ctx->op_device_context()->stream()->parent(); + se::dnn::DataType data_type = ToDataType::value; + + std::unique_ptr ctc_loss_desc; + std::unique_ptr probs_desc; + std::unique_ptr grads_desc; + + auto ctc_loss_desc_s = executor->createCtcLossDescriptor(data_type); + OP_REQUIRES_OK(ctx, ctc_loss_desc_s.status()); + ctc_loss_desc = ctc_loss_desc_s.ConsumeValueOrDie(); + + auto probs_desc_s = executor->createRnnStateTensorDescriptor( + max_time, batch_size, num_classes, data_type); + OP_REQUIRES_OK(ctx, probs_desc_s.status()); + probs_desc = probs_desc_s.ConsumeValueOrDie(); + + auto grads_desc_s = executor->createRnnStateTensorDescriptor( + max_time, batch_size, num_classes, data_type); + OP_REQUIRES_OK(ctx, grads_desc_s.status()); + grads_desc = grads_desc_s.ConsumeValueOrDie(); + + absl::Span labels_data; + absl::Span labels_lengths_data; + absl::Span input_lengths_data; + labels_data = absl::Span( + labels_values->flat().data(), num_indices); + labels_lengths_data = absl::Span( + labels_lengths.data(), batch_size); + input_lengths_data = absl::Span( + seq_len->flat().data(), batch_size); + + auto probs_data = StreamExecutorUtil::AsDeviceMemory(*inputs); + auto costs_data = StreamExecutorUtil::AsDeviceMemory(*loss); + auto grads_data = StreamExecutorUtil::AsDeviceMemory(*gradient); + + CudnnCtcLossAllocatorInTemp workspace_allocator(ctx); + + Stream* stream = ctx->op_device_context()->stream(); + bool cudnn_launch_status = + stream + ->ThenCtcLoss( + *probs_desc, probs_data, labels_data, labels_lengths_data, + input_lengths_data, &costs_data, *grads_desc, &grads_data, + *ctc_loss_desc, &workspace_allocator) + .ok(); + + if (!cudnn_launch_status) { + ctx->SetStatus( + errors::Internal("cuDNN CTCLoss launch failure")); + } + } + + private: + bool preprocess_collapse_repeated_; + bool ctc_merge_repeated_; + bool ignore_longer_outputs_than_inputs_; + + TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU); +}; + +REGISTER_KERNEL_BUILDER(Name("CTCLossV2").Device(DEVICE_GPU) + .HostMemory("labels_indices") + .HostMemory("labels_values") + .HostMemory("sequence_length"), + CTCLossOpGPU); +#endif // GOOGLE_CUDA } // end namespace tensorflow diff --git a/tensorflow/core/ops/ctc_ops.cc b/tensorflow/core/ops/ctc_ops.cc index f82ebb77001..77b1b576295 100644 --- a/tensorflow/core/ops/ctc_ops.cc +++ b/tensorflow/core/ops/ctc_ops.cc @@ -62,6 +62,43 @@ REGISTER_OP("CTCLoss") return Status::OK(); }); +REGISTER_OP("CTCLossV2") + 
.Input("inputs: float") + .Input("labels_indices: int64") + .Input("labels_values: int32") + .Input("sequence_length: int32") + .Attr("preprocess_collapse_repeated: bool = false") + .Attr("ctc_merge_repeated: bool = true") + .Attr("ignore_longer_outputs_than_inputs: bool = false") + .Output("loss: float") + .Output("gradient: float") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle inputs; + ShapeHandle labels_indices; + ShapeHandle labels_values; + ShapeHandle sequence_length; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &labels_indices)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &labels_values)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &sequence_length)); + + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(labels_indices, 0), + c->Dim(labels_values, 0), &unused)); + + // Get batch size from inputs and sequence_length, and update inputs + // with the merged batch_size since it is returned. + DimensionHandle batch_size; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); + TF_RETURN_IF_ERROR(c->ReplaceDim(inputs, 1, batch_size, &inputs)); + + c->set_output(0, c->Vector(batch_size)); + c->set_output(1, inputs); + return Status::OK(); + }); + REGISTER_OP("CTCGreedyDecoder") .Input("inputs: T") .Input("sequence_length: int32") diff --git a/tensorflow/python/kernel_tests/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/ctc_loss_op_test.py index 85a121e2d9f..0d9f7e6b53d 100644 --- a/tensorflow/python/kernel_tests/ctc_loss_op_test.py +++ b/tensorflow/python/kernel_tests/ctc_loss_op_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import numpy as np +import os from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -840,4 +841,5 @@ class CTCLossTestV2(test.TestCase): [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out) if __name__ == "__main__": + os.environ['TF_CUDNN_CTC_LOSS'] = '1' test.main() diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index a1d75f61fa2..5ad687bc251 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -42,6 +42,7 @@ from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export +import os # pylint: disable=protected-access, invalid-name @tf_export(v1=["nn.ctc_loss"]) @@ -155,6 +156,24 @@ def ctc_loss(labels, Raises: TypeError: if labels is not a `SparseTensor`. """ + return _ctc_loss_impl(labels, inputs, sequence_length, + preprocess_collapse_repeated, ctc_merge_repeated, + ignore_longer_outputs_than_inputs, time_major, logits, + use_cudnn=False) + +def _ctc_loss_impl(labels, + inputs=None, + sequence_length=None, + preprocess_collapse_repeated=False, + ctc_merge_repeated=True, + ignore_longer_outputs_than_inputs=False, + time_major=True, + logits=None, + use_cudnn=False): + # Helper function of ctc_loss with one additional param: + # use_cudnn: A bool to enable cuDNN CTC loss operation. If true, the blank + # index has to be 0. + # The second, third, etc output tensors contain the gradients. We use it in # _CTCLossGrad() below. if not isinstance(labels, sparse_tensor.SparseTensor): @@ -166,7 +185,14 @@ def ctc_loss(labels, if not time_major: inputs = array_ops.transpose(inputs, [1, 0, 2]) # (B,T,N) => (T,B,N) - loss, _ = gen_ctc_ops.ctc_loss( + # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss. 
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index a1d75f61fa2..5ad687bc251 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -42,6 +42,7 @@ from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
+import os
 
 # pylint: disable=protected-access, invalid-name
 @tf_export(v1=["nn.ctc_loss"])
@@ -155,6 +156,24 @@ def ctc_loss(labels,
   Raises:
     TypeError: if labels is not a `SparseTensor`.
   """
+  return _ctc_loss_impl(labels, inputs, sequence_length,
+                        preprocess_collapse_repeated, ctc_merge_repeated,
+                        ignore_longer_outputs_than_inputs, time_major, logits,
+                        use_cudnn=False)
+
+
+def _ctc_loss_impl(labels,
+                   inputs=None,
+                   sequence_length=None,
+                   preprocess_collapse_repeated=False,
+                   ctc_merge_repeated=True,
+                   ignore_longer_outputs_than_inputs=False,
+                   time_major=True,
+                   logits=None,
+                   use_cudnn=False):
+  # Helper for ctc_loss with one additional parameter:
+  # use_cudnn: A bool to enable the cuDNN CTC loss operation. If true, the
+  #   blank index has to be 0.
+
   # The second, third, etc output tensors contain the gradients.  We use it in
   # _CTCLossGrad() below.
   if not isinstance(labels, sparse_tensor.SparseTensor):
@@ -166,7 +185,14 @@ def ctc_loss(labels,
   if not time_major:
     inputs = array_ops.transpose(inputs, [1, 0, 2])  # (B,T,N) => (T,B,N)
 
-  loss, _ = gen_ctc_ops.ctc_loss(
+  # gen_ctc_ops.ctc_loss_v2 differs from gen_ctc_ops.ctc_loss: v2 assumes the
+  # blank index to be 0, while v1 views it as the last index.
+  if use_cudnn:
+    ctc_loss_func = gen_ctc_ops.ctc_loss_v2
+  else:
+    ctc_loss_func = gen_ctc_ops.ctc_loss
+
+  loss, _ = ctc_loss_func(
       inputs,
       labels.indices,
       labels.values,
@@ -177,19 +203,8 @@ def ctc_loss(labels,
 
   return loss
 
-# pylint: disable=unused-argument
-@ops.RegisterGradient("CTCLoss")
-def _CTCLossGrad(op, grad_loss, _):
-  """The derivative provided by CTC Loss.
-
-  Args:
-     op: the CTCLoss op.
-     grad_loss: The backprop for cost.
-
-  Returns:
-     The CTC Loss gradient.
-  """
+
+def _CTCLossGradImpl(op, grad_loss, _):
   # Outputs are: loss, grad
   #
   # Currently there is no way to take the second derivative of this op
@@ -205,7 +220,34 @@ def _CTCLossGrad(op, grad_loss, _):
   # labels_indices, labels_values and sequence_length
   return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
 
+
+# pylint: disable=unused-argument
+@ops.RegisterGradient("CTCLoss")
+def _CTCLossGrad(op, grad_loss, _):
+  """The derivative provided by CTC Loss.
+
+  Args:
+     op: the CTCLoss op.
+     grad_loss: The backprop for cost.
+
+  Returns:
+     The CTC Loss gradient.
+  """
+  return _CTCLossGradImpl(op, grad_loss, _)
+
+
+# pylint: disable=unused-argument
+@ops.RegisterGradient("CTCLossV2")
+def _CTCLossV2Grad(op, grad_loss, _):
+  """The derivative provided by CTC Loss V2.
+
+  Args:
+     op: the CTCLossV2 op.
+     grad_loss: The backprop for cost.
+
+  Returns:
+     The CTC Loss V2 gradient.
+  """
+  return _CTCLossGradImpl(op, grad_loss, _)
+
+
 @tf_export("nn.ctc_greedy_decoder")
 def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
   """Performs greedy decoding on the logits given in input (best path).
@@ -654,26 +696,36 @@ def ctc_loss_v2(labels,
       raise ValueError(
           "blank_index must be given when using SparseTensor labels.")
 
+    use_cudnn = os.environ.get("TF_CUDNN_CTC_LOSS", "0") == "1"
+
     if blank_index < 0:
       blank_index += _get_dim(logits, 2)
 
-    if blank_index != _get_dim(logits, 2) - 1:
-      logits = array_ops.concat([
-          logits[:, :, :blank_index],
-          logits[:, :, blank_index + 1:],
-          logits[:, :, blank_index:blank_index + 1],
-      ],
-                                axis=2)
+    part_before = logits[:, :, :blank_index]
+    part_after = logits[:, :, blank_index + 1:]
+    part_blank = logits[:, :, blank_index:blank_index + 1]
+    if use_cudnn:
+      # cuDNN expects the blank label to be class 0: rotate it to the front
+      # and shift the labels below the old blank index up by one.
+      logits = array_ops.concat([part_blank, part_before, part_after], axis=2)
+      labels = sparse_tensor.SparseTensor(
+          labels.indices,
+          array_ops.where(labels.values < blank_index, labels.values + 1,
+                          labels.values), labels.dense_shape)
+    else:
+      # The default kernel expects the blank label to be the last class:
+      # rotate it to the back and shift the labels above it down by one.
+      logits = array_ops.concat([part_before, part_after, part_blank], axis=2)
       labels = sparse_tensor.SparseTensor(
           labels.indices,
          array_ops.where(labels.values < blank_index, labels.values,
                           labels.values - 1), labels.dense_shape)
-
-    return ctc_loss(
+    return _ctc_loss_impl(
         labels=labels,
         inputs=logits,
         sequence_length=logit_length,
-        time_major=logits_time_major)
+        time_major=logits_time_major,
+        use_cudnn=use_cudnn)
 
   if blank_index is None:
     blank_index = 0
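A worked illustration of the blank-index shuffling above (hypothetical values; numpy stands in for the array_ops calls). With num_classes = 4 and blank_index = 3, the cuDNN path rotates the blank class to position 0 and shifts every label below the old blank index up by one:

    import numpy as np

    logits = np.array([0., 1., 2., 3.])  # one frame's scores [c0, c1, c2, c3]
    blank_index = 3
    before = logits[:blank_index]                # [0., 1., 2.]
    after = logits[blank_index + 1:]             # []
    blank = logits[blank_index:blank_index + 1]  # [3.]
    cudnn_logits = np.concatenate([blank, before, after])  # [3., 0., 1., 2.]

    labels = np.array([0, 2, 1])
    cudnn_labels = np.where(labels < blank_index, labels + 1, labels)  # [1, 3, 2]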
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index d15fdd06556..94e795a9aeb 100755
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -408,6 +408,13 @@ struct PersistentRnnPlanDeleter {
     CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
   }
 };
+#if CUDNN_VERSION >= 7601
+struct CtcLossDescriptorDeleter {
+  void operator()(cudnnCTCLossDescriptor_t descriptor) const {
+    CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
+  }
+};
+#endif
 
 // RAII wrappers for cuDNN types.
 using TensorDescriptor =
@@ -430,6 +437,10 @@ using DropoutDescriptor =
 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
 using PersistentRnnPlan =
     std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
+#if CUDNN_VERSION >= 7601
+using CtcLossDescriptor =
+    std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
+#endif
 
 // Factory methods for cuDNN types.
 TensorDescriptor CreateTensorDescriptor() {
@@ -479,6 +490,13 @@ RnnDescriptor CreateRnnDescriptor() {
   CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
   return RnnDescriptor(result);
 }
+#if CUDNN_VERSION >= 7601
+CtcLossDescriptor CreateCtcLossDescriptor() {
+  cudnnCTCLossDescriptor_t result;
+  CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
+  return CtcLossDescriptor(result);
+}
+#endif
 
 port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
     cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
@@ -1189,6 +1207,53 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
 };
 
+class CudnnCtcLossDescriptor : public dnn::CtcLossDescriptor {
+  CudnnCtcLossDescriptor(gpu::CtcLossDescriptor ctc_loss_desc,
+                         cudnnDataType_t data_type,
+                         cudnnLossNormalizationMode_t norm_mode,
+                         cudnnNanPropagation_t grad_mode)
+      : ctc_loss_desc_(std::move(ctc_loss_desc)),
+        data_type_(data_type),
+        norm_mode_(norm_mode),
+        grad_mode_(grad_mode) {}
+
+ public:
+  CudnnCtcLossDescriptor(CudnnCtcLossDescriptor&& other) = default;
+
+  // CUDNN_LOSS_NORMALIZATION_SOFTMAX makes cudnnCTCLoss apply the softmax
+  // normalization itself, so callers hand it raw activations.
+  static port::StatusOr<CudnnCtcLossDescriptor> Create(
+      cudnnDataType_t data_type,
+      cudnnLossNormalizationMode_t norm_mode =
+          CUDNN_LOSS_NORMALIZATION_SOFTMAX,
+      cudnnNanPropagation_t grad_mode = CUDNN_NOT_PROPAGATE_NAN) {
+    gpu::CtcLossDescriptor ctc_loss_desc = CreateCtcLossDescriptor();
+#if CUDNN_VERSION >= 7601
+    RETURN_IF_CUDNN_ERROR(cudnnSetCTCLossDescriptorEx(
+        /*ctcLossDesc=*/ctc_loss_desc.get(),
+        /*compType=*/data_type,
+        /*normMode=*/norm_mode,
+        /*gradMode=*/grad_mode));
+#else
+    return port::Status(port::error::INVALID_ARGUMENT,
+                        "cudnnSetCTCLossDescriptorEx is not supported when "
+                        "CUDNN_VERSION < 7.6.1");
+#endif
+
+    return CudnnCtcLossDescriptor(std::move(ctc_loss_desc), data_type,
+                                  norm_mode, grad_mode);
+  }
+
+  cudnnCTCLossDescriptor_t handle() const { return ctc_loss_desc_.get(); }
+  cudnnDataType_t data_type() const { return data_type_; }
+  cudnnLossNormalizationMode_t lnorm_mode() const { return norm_mode_; }
+  cudnnNanPropagation_t grad_mode() const { return grad_mode_; }
+
+ private:
+  gpu::CtcLossDescriptor ctc_loss_desc_;
+  cudnnDataType_t data_type_;
+  cudnnLossNormalizationMode_t norm_mode_;
+  cudnnNanPropagation_t grad_mode_;
+
+  SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
+};
+
 namespace {
 
 // Check if the LSTM projection is used. If yes, an additional weigth matrix
@@ -1656,6 +1721,39 @@ port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
   }
   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
 }
+
+port::StatusOr<DeviceMemory<uint8>> CreateCtcLossWorkspace(
+    Stream* stream, const CudnnHandle& cudnn,
+    const CudnnCtcLossDescriptor& ctc_loss_desc,
+    const CudnnRnnStateTensorDescriptor& probs_desc,
+    const CudnnRnnStateTensorDescriptor& grads_desc,
+    const absl::Span<const int32>& labels_data,
+    const absl::Span<const int32>& labels_lengths_data,
+    const absl::Span<const int32>& input_lengths_data,
+    ScratchAllocator* workspace_allocator) {
+  // Query the workspace size.
+  size_t workspace_size_in_bytes = 0;
+#if CUDNN_VERSION >= 7601
+  RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
+      /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
+      /*gradientsDesc=*/grads_desc.handle(),
+      /*labels=*/labels_data.data(),
+      /*labelLengths=*/labels_lengths_data.data(),
+      /*inputLengths=*/input_lengths_data.data(),
+      /*algo=*/CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
+      /*ctcLossDesc=*/ctc_loss_desc.handle(),
+      /*sizeInBytes=*/&workspace_size_in_bytes));
+#else
+  return port::Status(port::error::INVALID_ARGUMENT,
+                      "cudnnGetCTCLossWorkspaceSize is not supported when "
+                      "CUDNN_VERSION < 7.6.1");
+#endif
+  // Allocate the workspace.
+  if (workspace_size_in_bytes == 0) {
+    return DeviceMemory<uint8>();
+  }
+  return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
+}
 #endif
 
 }  // namespace
@@ -1969,6 +2067,51 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
   return port::Status::OK();
 }
 
+port::Status CudnnSupport::DoCtcLossImpl(
+    Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
+    const DeviceMemory<float>& probs_data,
+    const absl::Span<const int32>& labels_data,
+    const absl::Span<const int32>& labels_lengths_data,
+    const absl::Span<const int32>& input_lengths_data,
+    DeviceMemory<float>* costs_data,
+    const CudnnRnnStateTensorDescriptor& grads_desc,
+    DeviceMemory<float>* grads_data,
+    const CudnnCtcLossDescriptor& ctc_loss_desc,
+    ScratchAllocator* workspace_allocator) {
+  auto cudnn = cudnn_->GetHandle(parent_, stream);
+
+  SE_ASSIGN_OR_RETURN(
+      DeviceMemory<uint8> workspace,
+      CreateCtcLossWorkspace(stream, cudnn, ctc_loss_desc, probs_desc,
+                             grads_desc, labels_data, labels_lengths_data,
+                             input_lengths_data, workspace_allocator));
+  int kNumTimestamps = probs_desc.num_layers();
+  int kBatchSize = probs_desc.batch_size();
+  int kNumLabels = probs_desc.data_size();
+  int total_size = kNumLabels * kNumTimestamps * kBatchSize;
+
+#if CUDNN_VERSION >= 7601
+  RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
+      /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
+      /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
+      /*labelsLengths=*/labels_lengths_data.data(),
+      /*inputLengths=*/input_lengths_data.data(),
+      /*costs=*/costs_data->opaque(), /*gradientsDesc=*/grads_desc.handle(),
+      /*gradients=*/grads_data->opaque(),
+      /*algo=*/CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
+      /*ctcLossDesc=*/ctc_loss_desc.handle(),
+      /*workspace=*/workspace.opaque(),
+      /*workSpaceSizeInBytes=*/workspace.size()));
+#else
+  return port::Status(port::error::INVALID_ARGUMENT,
+                      "cudnnCTCLoss is not supported when "
+                      "CUDNN_VERSION < 7.6.1");
+#endif
+
+  return port::Status::OK();
+}
+
 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
 CudnnSupport::createRnnDescriptor(
     int num_layers, int hidden_size, int input_size, int cell_size,
@@ -1992,6 +2135,16 @@ CudnnSupport::createRnnDescriptor(
       new CudnnRnnDescriptor(std::move(rnn_desc)));
 }
 
+port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
+CudnnSupport::createCtcLossDescriptor(dnn::DataType data_type) {
+  SE_ASSIGN_OR_RETURN(
+      CudnnCtcLossDescriptor ctc_loss_desc,
+      CudnnCtcLossDescriptor::Create(ToCudnnDataType(data_type)));
+  return std::unique_ptr<dnn::CtcLossDescriptor>(
+      new CudnnCtcLossDescriptor(std::move(ctc_loss_desc)));
+}
+
 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                 int batch_size, int data_size,
@@ -3828,6 +3981,31 @@ bool CudnnSupport::DoFusedConvolve(
       /*report_error=*/!output_profile_result);
 }
 
+bool CudnnSupport::DoCtcLoss(
+    Stream* stream, const dnn::RnnStateTensorDescriptor &probs_desc,
+    const DeviceMemory<float> &probs_data,
+    const absl::Span<const int32> &labels_data,
+    const absl::Span<const int32> &labels_lengths_data,
+    const absl::Span<const int32> &input_lengths_data,
+    DeviceMemory<float> *costs_data,
+    const dnn::RnnStateTensorDescriptor &grads_desc,
+    DeviceMemory<float> *grads_data,
+    const dnn::CtcLossDescriptor &ctc_loss_desc,
+    ScratchAllocator *workspace_allocator) {
+  const CudnnCtcLossDescriptor& cudnn_ctc_loss_desc =
+      static_cast<const CudnnCtcLossDescriptor&>(ctc_loss_desc);
+  const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
+      static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
+  const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
+      static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
+  return IsStatusOk(
+      DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
+                    labels_lengths_data, input_lengths_data, costs_data,
+                    cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
+                    workspace_allocator),
+      /*report_error=*/true);
+}
+
 bool CudnnSupport::DoTransformTensor(Stream* stream,
                                      const dnn::BatchDescriptor& input_desc,
                                      dnn::DataType input_type,
&input_lengths_data, + DeviceMemory *costs_data, + const dnn::RnnStateTensorDescriptor &grads_desc, + DeviceMemory *grads_data, + const dnn::CtcLossDescriptor &ctc_loss_desc, + ScratchAllocator *workspace_allocator) { + const CudnnCtcLossDescriptor& cudnn_ctc_loss_desc = + static_cast(ctc_loss_desc); + const CudnnRnnStateTensorDescriptor& cudnn_probs_desc = + static_cast(probs_desc); + const CudnnRnnStateTensorDescriptor& cudnn_grads_desc = + static_cast(grads_desc); + return IsStatusOk( + DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data, + labels_lengths_data, input_lengths_data, costs_data, + cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc, + workspace_allocator), + /*report_error=*/true); +} + bool CudnnSupport::DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc, dnn::DataType input_type, diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 93beee85a5a..e1a1f9c0674 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -33,6 +33,7 @@ class GpuExecutor; class CudnnRnnDescriptor; class CudnnRnnSequenceTensorDescriptor; class CudnnRnnStateTensorDescriptor; +class CudnnCtcLossDescriptor; // Opaque and unique identifier for the cuDNN plugin. extern const PluginId kCuDnnPlugin; @@ -54,6 +55,9 @@ class CudnnSupport : public dnn::DnnSupport { float dropout, uint64 seed, ScratchAllocator* state_allocator, bool use_padded_io) override; + port::StatusOr> + createCtcLossDescriptor(dnn::DataType data_type) override; + port::StatusOr> createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, int data_size, @@ -562,6 +566,18 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::ConvolutionDescriptor& convolution_descriptor, dnn::BatchDescriptor* output_batch_descriptor); + bool DoCtcLoss( + Stream* stream, const dnn::RnnStateTensorDescriptor &probs_desc, + const DeviceMemory &probs_data, + const absl::Span &labels_data, + const absl::Span &labels_lengths_data, + const absl::Span &input_lengths_data, + DeviceMemory *costs_data, + const dnn::RnnStateTensorDescriptor &grads_desc, + DeviceMemory *grads_data, + const dnn::CtcLossDescriptor &ctc_loss_desc, + ScratchAllocator *workspace_allocator); + bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc, dnn::DataType input_type, const DeviceMemoryBase& input_data, @@ -673,6 +689,18 @@ class CudnnSupport : public dnn::DnnSupport { ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result); + port::Status DoCtcLossImpl( + Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc, + const DeviceMemory& probs_data, + const absl::Span& labels_data, + const absl::Span& labels_lengths_data, + const absl::Span& input_lengths_data, + DeviceMemory* costs_data, + const CudnnRnnStateTensorDescriptor& grads_desc, + DeviceMemory* grads_data, + const CudnnCtcLossDescriptor& ctc_loss_desc, + ScratchAllocator* workspace_allocator); + private: port::Status DoPrepareForConvolution( dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 73e378a31ba..8aea669c801 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -190,6 +190,15 @@ class RnnDescriptor { virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); } }; +// Specifies the CTC Loss computation. 
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 73e378a31ba..8aea669c801 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -190,6 +190,15 @@ class RnnDescriptor {
   virtual ParamsRegions ParamsBiasRegions() const { return ParamsRegions(); }
 };
 
+// Specifies the CTC Loss computation.
+//
+// The user is responsible for releasing this descriptor when it is no longer
+// in use. The destructor releases the underlying descriptor.
+class CtcLossDescriptor {
+ public:
+  virtual ~CtcLossDescriptor() {}
+};
+
 // Specifies the sequence in a RNN model.
 //
 // The user is responsible for releasing this descriptor when it is no longer
@@ -2133,6 +2142,16 @@ class DnnSupport {
                       "createRnnDescriptor is unimplemented");
   }
 
+  // Create a CTC loss descriptor.
+  //
+  // Arguments:
+  //  data_type: an enum to specify the data type used in this model.
+  virtual port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
+  createCtcLossDescriptor(dnn::DataType data_type) {
+    return port::Status(port::error::UNIMPLEMENTED,
+                        "createCtcLossDescriptor is unimplemented");
+  }
+
   // Create a RNN sequence descriptor that specifies either the input or output
   // sequence. The caller retains the ownership of the returned descriptor.
   //
@@ -2383,6 +2402,40 @@ class DnnSupport {
     return false;
   }
 
+  // Enqueue a CTC loss operation onto the stream.
+  //
+  // Arguments:
+  //  stream: pointer to the stream where this operation should be enqueued to.
+  //  probs_desc: specifies the shape and the data layout of the input tensor.
+  //  probs_data: the device memory region that contains the input tensor.
+  //  labels_data: the host memory region that contains the labels_values
+  //    tensor.
+  //  labels_lengths_data: the host memory region that contains the
+  //    labels_lengths tensor.
+  //  input_lengths_data: the host memory region that contains the seq_lengths
+  //    tensor.
+  //  costs_data: the device memory region that contains the costs tensor.
+  //  grads_desc: specifies the shape and the data layout of the grads tensor.
+  //  grads_data: the device memory region that contains the grads tensor.
+  //  ctc_loss_desc: a CTC loss descriptor created by createCtcLossDescriptor.
+  //  workspace_allocator: a memory allocator that creates the temporary
+  //    workspace memory used by this operation. The caller is responsible for
+  //    keeping the memory alive long enough for this operation, and recycles
+  //    it afterwards.
+  virtual bool DoCtcLoss(Stream* stream,
+                         const dnn::RnnStateTensorDescriptor &probs_desc,
+                         const DeviceMemory<float> &probs_data,
+                         const absl::Span<const int32> &labels_data,
+                         const absl::Span<const int32> &labels_lengths_data,
+                         const absl::Span<const int32> &input_lengths_data,
+                         DeviceMemory<float> *costs_data,
+                         const dnn::RnnStateTensorDescriptor &grads_desc,
+                         DeviceMemory<float> *grads_data,
+                         const dnn::CtcLossDescriptor &ctc_loss_desc,
+                         ScratchAllocator *workspace_allocator) {
+    return false;
+  }
+
   // Transforms a tensor into another tensor with a different layout and/or data
   // type.
   //
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index c1dc49ff1be..ed119fbafa7 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -5230,6 +5230,33 @@ Stream &Stream::ThenRnnBackward(
   return *this;
 }
 
+Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
+                            const DeviceMemory<float> &probs_data,
+                            const absl::Span<const int32> &labels_data,
+                            const absl::Span<const int32> &labels_lengths_data,
+                            const absl::Span<const int32> &input_lengths_data,
+                            DeviceMemory<float> *costs_data,
+                            const dnn::RnnStateTensorDescriptor &grads_desc,
+                            DeviceMemory<float> *grads_data,
+                            const dnn::CtcLossDescriptor &ctc_loss_desc,
+                            ScratchAllocator *workspace_allocator) {
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      auto status = dnn->DoCtcLoss(
+          this, probs_desc, probs_data, labels_data, labels_lengths_data,
+          input_lengths_data, costs_data, grads_desc, grads_data,
+          ctc_loss_desc, workspace_allocator);
+      if (!status) {
+        SetError();
+      }
+    } else {
+      SetErrorAndLogNoDnnSupport();
+    }
+  }
+  return *this;
+}
+
 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                                     dnn::DataType input_type,
                                     const DeviceMemoryBase &input_data,
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index dddd0fa6441..fe1290822e4 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1912,6 +1912,20 @@ class Stream {
                           ScratchAllocator *workspace_allocator,
                           dnn::ProfileResult *output_profile_result);
 
+  // Enqueue a CTC loss operation onto the stream.
+  // See DnnSupport::DoCtcLoss for more details.
+  Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
+                      const DeviceMemory<float> &probs_data,
+                      const absl::Span<const int32> &labels_data,
+                      const absl::Span<const int32> &labels_lengths_data,
+                      const absl::Span<const int32> &input_lengths_data,
+                      DeviceMemory<float> *costs_data,
+                      const dnn::RnnStateTensorDescriptor &grads_desc,
+                      DeviceMemory<float> *grads_data,
+                      const dnn::CtcLossDescriptor &ctc_loss_desc,
+                      ScratchAllocator *workspace_allocator);
+
   // Enqueue onto the stream a operation that transforms a tensor.
   // See DnnSupport::DoTransformTensor for more details.
   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index f8b6655e586..e2ad3e15af1 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -353,6 +353,16 @@ StreamExecutor::createRnnDescriptor(
                                       state_allocator, use_padded_io);
 }
 
+port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
+StreamExecutor::createCtcLossDescriptor(dnn::DataType data_type) {
+  dnn::DnnSupport *dnn_support = AsDnn();
+  if (!dnn_support) {
+    return port::Status(port::error::UNKNOWN,
+                        "Fail to find the dnn implementation.");
+  }
+  return dnn_support->createCtcLossDescriptor(data_type);
+}
+
 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
 StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                   int batch_size, int data_size,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index efa4034c88a..98a3397ea16 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -399,6 +399,11 @@ class StreamExecutor {
       float dropout, uint64 seed, ScratchAllocator *state_allocator,
       bool use_padded_io);
 
+  // Create a CTC loss descriptor. The caller retains the ownership of the
+  // returned descriptor.
+  port::StatusOr<std::unique_ptr<dnn::CtcLossDescriptor>>
+  createCtcLossDescriptor(dnn::DataType data_type);
+
   // Create a RNN sequence descriptor that specifies either the input or output
   // sequence. The caller retains the ownership of the returned descriptor.
   port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
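Taken together, a hedged end-to-end sketch of the new code path (TF 2.x API names; shapes hypothetical, and a CUDA build with cuDNN >= 7.6.1 is assumed):

    import os
    os.environ["TF_CUDNN_CTC_LOSS"] = "1"  # opt in to the cuDNN kernel

    import tensorflow as tf

    labels = tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1]], values=[2, 1], dense_shape=[1, 2])
    logits = tf.random.normal([50, 1, 10])  # (max_time, batch, num_classes)

    # blank_index=0 already matches the cuDNN convention, so no class
    # permutation is needed; other values are rotated as in ctc_ops.py above.
    loss = tf.nn.ctc_loss(labels=labels, logits=logits, label_length=None,
                          logit_length=tf.constant([50], dtype=tf.int32),
                          logits_time_major=True, blank_index=0)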