diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0109597e9cc..87017da1e6e 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2364,9 +2364,10 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core/util/ctc:ctc_beam_search_lib", "//tensorflow/core/util/ctc:ctc_loss_calculator_lib", - ] + if_cuda([ + ] + if_cuda_or_rocm([ ":gpu_utils", ":conv_ops_gpu_hdrs", + ]) + if_cuda([ "@local_config_cuda//cuda:cudnn_header", ]), ) diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index 224a15b1936..6358e82fdda 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -15,9 +15,9 @@ limitations under the License. // See docs in ../ops/ctc_ops.cc. -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op.h" @@ -31,15 +31,18 @@ limitations under the License. #if GOOGLE_CUDA #include "third_party/gpus/cudnn/cudnn.h" +#endif // GOOGLE_CUDA + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/util/stream_executor_util.h" #include "tensorflow/core/util/tensor_format.h" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM using GPUDevice = Eigen::GpuDevice; namespace { @@ -60,7 +63,7 @@ void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices, } } // end namespace -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM template <typename T> class CTCLossOp : public OpKernel { @@ -221,7 +224,7 @@ REGISTER_CPU(double); #undef REGISTER_CPU -#if GOOGLE_CUDA && CUDNN_VERSION >= 7603 +#if ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM) class CTCLossOpGPU : public OpKernel { public: explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -366,5 +369,5 @@ REGISTER_KERNEL_BUILDER(Name("CTCLossV2") .HostMemory("labels_values") .HostMemory("sequence_length"), CTCLossOpGPU); -#endif // GOOGLE_CUDA && CUDNN_VERSION >= 7603 +#endif // ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM) } // end namespace tensorflow diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc index 8df92357e9b..f70a16fed87 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc @@ -280,7 +280,12 @@ namespace wrap { __macro(miopenConvolutionBackwardWeightsGetSolution) \ __macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \ __macro(miopenConvolutionBackwardWeightsCompileSolution) \ - __macro(miopenConvolutionBackwardWeightsImmediate) + __macro(miopenConvolutionBackwardWeightsImmediate) \ + __macro(miopenCreateCTCLossDescriptor) \ + __macro(miopenSetCTCLossDescriptor) \ + __macro(miopenGetCTCLossWorkspaceSize) \ + __macro(miopenCTCLoss) \ + __macro(miopenDestroyCTCLossDescriptor) // clang-format on @@ -2319,6 +2324,164 @@ MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor( } } +class MIOpenCTCLossDescriptor { + public: + explicit MIOpenCTCLossDescriptor(miopenDataType_t data_type) { + auto status = wrap::miopenCreateCTCLossDescriptor(&handle_); + if (status != miopenStatusSuccess) { + LOG(FATAL) << "call to miopenCreateCTCLossDescriptor failed: " + << ToString(status); + } + + bool apply_softmax_layer = true; + status = wrap::miopenSetCTCLossDescriptor(handle_, data_type, 0, + apply_softmax_layer); + if (status != miopenStatusSuccess) { + LOG(FATAL) << "call to miopenSetCTCLossDescriptor failed: " + << ToString(status); + } + } + + ~MIOpenCTCLossDescriptor() { + auto status = wrap::miopenDestroyCTCLossDescriptor(handle_); + if (status != miopenStatusSuccess) { + LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: " + << ToString(status); + } + } + + miopenCTCLossDescriptor_t handle() const { return handle_; } + + private: + miopenCTCLossDescriptor_t handle_; // Owned + + SE_DISALLOW_COPY_AND_ASSIGN(MIOpenCTCLossDescriptor); +}; + +port::Status MIOpenSupport::DoPrepareForCtcLoss( + Stream* stream, dnn::DataType element_type, + const dnn::RnnStateTensorDescriptor& probs_desc, + const dnn::RnnStateTensorDescriptor& grads_desc, + absl::Span<const int> labels_data, + absl::Span<const int> labels_lengths_data, + absl::Span<const int> input_lengths_data, + ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) { + auto miopen = miopen_->GetHandle(parent_, stream); + + MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type)); + + // Query the workspace size. + size_t workspace_size_in_bytes = 0; + + const MIOpenRnnStateTensorDescriptor& miopen_probs_desc = + static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc); + + const MIOpenRnnStateTensorDescriptor& miopen_grads_desc = + static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc); + + auto status = wrap::miopenGetCTCLossWorkspaceSize( + miopen.handle(), miopen_probs_desc.handle(), miopen_grads_desc.handle(), + labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(), + MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, miopen_ctc_loss_desc.handle(), + &workspace_size_in_bytes); + + if (status != miopenStatusSuccess) { + LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: " + << ToString(status); + return port::InternalError( + "Failed to determine scratch memory size for MIOpen CTC Loss"); + } + + *scratch_memory = DeviceMemory<uint8>(); + + // Allocate the workspace. + if (workspace_size_in_bytes != 0) { + if (scratch_allocator == nullptr) { + return port::InternalError( + absl::StrCat("An allocator must be specified when scratch memory is " + "needed")); + } + auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes); + if (scratch_or.ok()) { + *scratch_memory = scratch_or.ValueOrDie(); + } else { + LOG(ERROR) + << "Failed to allocate scratch memory - " + << scratch_or.status().error_message() << "\n" + << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a " + "larger number (e.g. 8192) to increase the max memory limit.\n" + << "\tIncreasing the max memory limit might help resolve this " + "error"; + return port::InternalError(absl::StrCat( + "Failed to allocate scratch memory for MIOpen CTC Loss, of size: ", + workspace_size_in_bytes)); + } + } + + return port::Status::OK(); +} + +port::Status MIOpenSupport::DoCtcLossImpl( + Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc, + const DeviceMemoryBase probs_data, absl::Span<const int> labels_data, + absl::Span<const int> labels_lengths_data, + absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data, + const MIOpenRnnStateTensorDescriptor& grads_desc, + DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc, + DeviceMemory<uint8> scratch_memory) { + auto miopen = miopen_->GetHandle(parent_, stream); + + int kNumTimestamps = probs_desc.num_layers(); + int kBatchSize = probs_desc.batch_size(); + int kNumLabels = probs_desc.data_size(); + int total_size = kNumLabels * kNumTimestamps * kBatchSize; + (void)total_size; + + auto status = wrap::miopenCTCLoss( + miopen.handle(), probs_desc.handle(), probs_data.opaque(), + labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(), + costs_data.opaque(), grads_desc.handle(), grads_data.opaque(), + MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_loss_desc.handle(), + scratch_memory.opaque(), scratch_memory.size()); + if (status != miopenStatusSuccess) { + LOG(FATAL) << "call to miopenCTCLoss failed: " << ToString(status); + return port::InternalError("Failure during MIOpen CTC Loss"); + } + + return port::Status::OK(); +} + +port::Status MIOpenSupport::DoCtcLoss( + Stream* stream, dnn::DataType element_type, + const dnn::RnnStateTensorDescriptor& probs_desc, + const DeviceMemoryBase probs_data, + + absl::Span<const int> labels_data, + absl::Span<const int> labels_lengths_data, + absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data, + const dnn::RnnStateTensorDescriptor& grads_desc, + DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) { + // Current MIOPen CTC Loss only supports the float datatype + if (element_type != dnn::DataType::kFloat) { + return port::Status(port::error::INVALID_ARGUMENT, + "MIOpenCTCLossDescriptor is supported only when the " + "DataType is float"); + } + + MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type)); + + const MIOpenRnnStateTensorDescriptor& miopen_probs_desc = + static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc); + + const MIOpenRnnStateTensorDescriptor& miopen_grads_desc = + static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc); + + return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data, + labels_lengths_data, input_lengths_data, costs_data, + miopen_grads_desc, grads_data, miopen_ctc_loss_desc, + scratch_memory); +} + port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> MIOpenSupport::createRnnDescriptor( int num_layers, int hidden_size, int input_size, int cell_size, @@ -2633,8 +2796,8 @@ void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) { } void MIOpenDeallocatorCallback(void* ctx, void* mem) { - // Don't need deallocator since the TensorFlow heap will automatically reclaim - // the memory + // Don't need deallocator since the TensorFlow heap will automatically + // reclaim the memory } port::Status MIOpenSupport::DoPrepareForConvolution( @@ -2835,9 +2998,9 @@ port::Status MIOpenSupport::DoConvolve( if (!timer->Init()) { return port::Status(port::error::INTERNAL, "Failed to init timer"); } - // The start and stop of the timer should be as close to the MIOpen call as - // possible. It is still possible for other threads to issue workload on - // to this stream. So it could take multiple profiling measurements. + // The start and stop of the timer should be as close to the MIOpen call + // as possible. It is still possible for other threads to issue workload + // on to this stream. So it could take multiple profiling measurements. if (!timer->Start(AsGpuStream(stream))) { timer->Destroy(); return port::Status(port::error::INTERNAL, "Failed to start timer"); @@ -2971,9 +3134,9 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithms( miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(), input_nd.handle(), &maxSolutionCount); if (status != miopenStatusSuccess) { - LOG(FATAL) - << "call to miopenConvolutionBackwardDataGetSolutionCount failed: " - << ToString(status); + LOG(FATAL) << "call to miopenConvolutionBackwardDataGetSolutionCount " + "failed: " + << ToString(status); return false; } break; @@ -4143,8 +4306,8 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions( return false; } } else { - LOG(ERROR) - << "Failed to calculate tensor size to chain forward and backward LRN"; + LOG(ERROR) << "Failed to calculate tensor size to chain forward and " + "backward LRN"; } status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha, diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.h b/tensorflow/stream_executor/rocm/rocm_dnn.h index 9c2f1bcf1c6..5dc192c5e86 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.h +++ b/tensorflow/stream_executor/rocm/rocm_dnn.h @@ -33,6 +33,8 @@ class GpuExecutor; class MIOpenRnnDescriptor; class MIOpenRnnSequenceTensorDescriptor; class MIOpenRnnStateTensorDescriptor; +class MIOpenCTCLossDescriptor; + // Opaque and unique identifier for the MIOpen plugin. extern const PluginId kMIOpenPlugin; @@ -636,6 +638,17 @@ class MIOpenSupport : public dnn::DnnSupport { GpuExecutor* GetParentExecutor() { return parent_; } + port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type, + const dnn::RnnStateTensorDescriptor& probs_desc, + const DeviceMemoryBase probs_data, + absl::Span<const int> labels_data, + absl::Span<const int> labels_lengths_data, + absl::Span<const int> input_lengths_data, + DeviceMemoryBase costs_data, + const dnn::RnnStateTensorDescriptor& grads_desc, + DeviceMemoryBase grads_data, + DeviceMemory<uint8> scratch_memory) override; + private: GpuExecutor* parent_; // Parent executor object. Not owned. @@ -784,6 +797,25 @@ class MIOpenSupport : public dnn::DnnSupport { ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc, DeviceMemory<uint8>* scratch_memory) override; + port::Status DoCtcLossImpl( + Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc, + const DeviceMemoryBase probs_data, absl::Span<const int> labels_data, + absl::Span<const int> labels_lengths_data, + absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data, + const MIOpenRnnStateTensorDescriptor& grads_desc, + DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc, + DeviceMemory<uint8> scratch_memory); + + port::Status DoPrepareForCtcLoss( + Stream* stream, dnn::DataType element_type, + const dnn::RnnStateTensorDescriptor& probs_desc, + const dnn::RnnStateTensorDescriptor& grads_desc, + absl::Span<const int> labels_data, + absl::Span<const int> labels_lengths_data, + absl::Span<const int> input_lengths_data, + ScratchAllocator* scratch_allocator, + DeviceMemory<uint8>* scratch_memory) override; + SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport); };