[ROCm] Adding ROCm support for CTC Loss

Deven Desai 2020-01-27 18:50:24 +00:00
parent e1a66cff30
commit d7dbb77bd6
4 changed files with 218 additions and 19 deletions

View File: tensorflow/core/kernels/BUILD

@ -2364,9 +2364,10 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
-] + if_cuda([
+] + if_cuda_or_rocm([
":gpu_utils",
":conv_ops_gpu_hdrs",
+]) + if_cuda([
"@local_config_cuda//cuda:cudnn_header",
]),
)

View File: tensorflow/core/kernels/ctc_loss_op.cc

@ -15,9 +15,9 @@ limitations under the License.
// See docs in ../ops/ctc_ops.cc.
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
@ -31,15 +31,18 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
+#endif // GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/core/util/tensor_format.h"
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
using GPUDevice = Eigen::GpuDevice;
namespace {
@ -60,7 +63,7 @@ void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
}
} // end namespace
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
class CTCLossOp : public OpKernel {
@ -221,7 +224,7 @@ REGISTER_CPU(double);
#undef REGISTER_CPU
-#if GOOGLE_CUDA && CUDNN_VERSION >= 7603
+#if ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
class CTCLossOpGPU : public OpKernel {
public:
explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -366,5 +369,5 @@ REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
.HostMemory("labels_values")
.HostMemory("sequence_length"),
CTCLossOpGPU);
-#endif // GOOGLE_CUDA && CUDNN_VERSION >= 7603
+#endif // ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
} // end namespace tensorflow
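
Both backends' CTC entry points consume the label values and sequence lengths as host-side int arrays (the rocm_dnn.cc code below passes labels_data.data() straight through), which is why this registration pins those inputs to host memory. A sketch of the complete builder; the .Device(DEVICE_GPU) line is assumed from surrounding context, not shown in the hunk:

REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
                            .Device(DEVICE_GPU)  // assumed, not in the hunk
                            .HostMemory("labels_indices")
                            .HostMemory("labels_values")
                            .HostMemory("sequence_length"),
                        CTCLossOpGPU);
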

View File: tensorflow/stream_executor/rocm/rocm_dnn.cc

@ -280,7 +280,12 @@ namespace wrap {
__macro(miopenConvolutionBackwardWeightsGetSolution) \
__macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \
__macro(miopenConvolutionBackwardWeightsCompileSolution) \
-__macro(miopenConvolutionBackwardWeightsImmediate)
+__macro(miopenConvolutionBackwardWeightsImmediate) \
+__macro(miopenCreateCTCLossDescriptor) \
+__macro(miopenSetCTCLossDescriptor) \
+__macro(miopenGetCTCLossWorkspaceSize) \
+__macro(miopenCTCLoss) \
+__macro(miopenDestroyCTCLossDescriptor)
// clang-format on
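
Each __macro(name) entry expands through the file's wrapper macro into a shim that forwards to the corresponding MIOpen symbol, which is what makes the new entry points callable below as wrap::miopenCTCLoss and friends. Roughly, as a simplified sketch (the macro name here is illustrative, not the exact one in rocm_dnn.cc):

#define MIOPEN_WRAP_SKETCH(__name)               \
  struct Wrapper_##__name {                      \
    template <typename... Args>                  \
    auto operator()(Args... args) {              \
      return ::__name(args...);  /* forward */   \
    }                                            \
  } __name;
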
@ -2319,6 +2324,164 @@ MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor(
}
}
+class MIOpenCTCLossDescriptor {
+public:
+explicit MIOpenCTCLossDescriptor(miopenDataType_t data_type) {
+auto status = wrap::miopenCreateCTCLossDescriptor(&handle_);
+if (status != miopenStatusSuccess) {
+LOG(FATAL) << "call to miopenCreateCTCLossDescriptor failed: "
+<< ToString(status);
+}
+bool apply_softmax_layer = true;
+status = wrap::miopenSetCTCLossDescriptor(handle_, data_type, 0,
+apply_softmax_layer);
+if (status != miopenStatusSuccess) {
+LOG(FATAL) << "call to miopenSetCTCLossDescriptor failed: "
+<< ToString(status);
+}
+}
+~MIOpenCTCLossDescriptor() {
+auto status = wrap::miopenDestroyCTCLossDescriptor(handle_);
+if (status != miopenStatusSuccess) {
+LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: "
+<< ToString(status);
+}
+}
+miopenCTCLossDescriptor_t handle() const { return handle_; }
+private:
+miopenCTCLossDescriptor_t handle_; // Owned
+SE_DISALLOW_COPY_AND_ASSIGN(MIOpenCTCLossDescriptor);
+};
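
Like the other descriptor wrappers in this file, this is an RAII class: the constructor creates and configures the handle (blank label index 0, MIOpen's built-in softmax layer enabled), and the destructor releases it on every exit path. A minimal usage sketch, assuming a float-typed loss:

{
  MIOpenCTCLossDescriptor ctc_desc(miopenFloat);
  // ... pass ctc_desc.handle() to miopenGetCTCLossWorkspaceSize and
  // miopenCTCLoss ...
}  // handle destroyed here, even if an early return fires above
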
+port::Status MIOpenSupport::DoPrepareForCtcLoss(
+Stream* stream, dnn::DataType element_type,
+const dnn::RnnStateTensorDescriptor& probs_desc,
+const dnn::RnnStateTensorDescriptor& grads_desc,
+absl::Span<const int> labels_data,
+absl::Span<const int> labels_lengths_data,
+absl::Span<const int> input_lengths_data,
+ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
+auto miopen = miopen_->GetHandle(parent_, stream);
+MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
+// Query the workspace size.
+size_t workspace_size_in_bytes = 0;
+const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
+static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
+const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
+static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
+auto status = wrap::miopenGetCTCLossWorkspaceSize(
+miopen.handle(), miopen_probs_desc.handle(), miopen_grads_desc.handle(),
+labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
+MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, miopen_ctc_loss_desc.handle(),
+&workspace_size_in_bytes);
+if (status != miopenStatusSuccess) {
+LOG(ERROR) << "call to miopenGetCTCLossWorkspaceSize failed: "
+<< ToString(status);
+return port::InternalError(
+"Failed to determine scratch memory size for MIOpen CTC Loss");
+}
+*scratch_memory = DeviceMemory<uint8>();
+// Allocate the workspace.
+if (workspace_size_in_bytes != 0) {
+if (scratch_allocator == nullptr) {
+return port::InternalError(
+absl::StrCat("An allocator must be specified when scratch memory is "
+"needed"));
+}
+auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes);
+if (scratch_or.ok()) {
+*scratch_memory = scratch_or.ValueOrDie();
+} else {
+LOG(ERROR)
+<< "Failed to allocate scratch memory - "
+<< scratch_or.status().error_message() << "\n"
+<< "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
+"larger number (e.g. 8192) to increase the max memory limit.\n"
+<< "\tIncreasing the max memory limit might help resolve this "
+"error";
+return port::InternalError(absl::StrCat(
+"Failed to allocate scratch memory for MIOpen CTC Loss, of size: ",
+workspace_size_in_bytes));
+}
+}
+return port::Status::OK();
+}
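
DoPrepareForCtcLoss is the first half of a two-phase protocol: the prepare step queries miopenGetCTCLossWorkspaceSize and allocates scratch space, and the execute step (DoCtcLoss, further below) consumes it. A hypothetical call sequence; the local names and the direct calls are illustrative, since in the real codebase these run through the framework's Stream wrappers:

DeviceMemory<uint8> scratch;
auto prep_status = dnn->DoPrepareForCtcLoss(
    stream, dnn::DataType::kFloat, probs_desc, grads_desc, labels,
    label_lengths, input_lengths, &scratch_allocator, &scratch);
if (!prep_status.ok()) return prep_status;

auto run_status = dnn->DoCtcLoss(
    stream, dnn::DataType::kFloat, probs_desc, probs, labels, label_lengths,
    input_lengths, costs, grads_desc, grads, scratch);
if (!run_status.ok()) return run_status;

Note also that DoCtcLossImpl below describes the probs tensor with a reused RNN state tensor descriptor: num_layers() carries the time dimension, batch_size() the batch, and data_size() the label-alphabet size; total_size is computed only for the (void) cast that silences the unused-variable warning.
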
+port::Status MIOpenSupport::DoCtcLossImpl(
+Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
+const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
+absl::Span<const int> labels_lengths_data,
+absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+const MIOpenRnnStateTensorDescriptor& grads_desc,
+DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
+DeviceMemory<uint8> scratch_memory) {
+auto miopen = miopen_->GetHandle(parent_, stream);
+int kNumTimestamps = probs_desc.num_layers();
+int kBatchSize = probs_desc.batch_size();
+int kNumLabels = probs_desc.data_size();
+int total_size = kNumLabels * kNumTimestamps * kBatchSize;
+(void)total_size;
+auto status = wrap::miopenCTCLoss(
+miopen.handle(), probs_desc.handle(), probs_data.opaque(),
+labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
+costs_data.opaque(), grads_desc.handle(), grads_data.opaque(),
+MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_loss_desc.handle(),
+scratch_memory.opaque(), scratch_memory.size());
+if (status != miopenStatusSuccess) {
+LOG(ERROR) << "call to miopenCTCLoss failed: " << ToString(status);
+return port::InternalError("Failure during MIOpen CTC Loss");
+}
+return port::Status::OK();
+}
+port::Status MIOpenSupport::DoCtcLoss(
+Stream* stream, dnn::DataType element_type,
+const dnn::RnnStateTensorDescriptor& probs_desc,
+const DeviceMemoryBase probs_data,
+absl::Span<const int> labels_data,
+absl::Span<const int> labels_lengths_data,
+absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+const dnn::RnnStateTensorDescriptor& grads_desc,
+DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
+// MIOpen CTC loss currently supports only the float data type.
+if (element_type != dnn::DataType::kFloat) {
+return port::Status(port::error::INVALID_ARGUMENT,
+"MIOpenCTCLossDescriptor is supported only when the "
+"DataType is float");
+}
+MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
+const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
+static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
+const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
+static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
+return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
+labels_lengths_data, input_lengths_data, costs_data,
+miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
+scratch_memory);
+}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
MIOpenSupport::createRnnDescriptor(
int num_layers, int hidden_size, int input_size, int cell_size,
@ -2633,8 +2796,8 @@ void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) {
}
void MIOpenDeallocatorCallback(void* ctx, void* mem) {
-// Don't need deallocator since the TensorFlow heap will automatically reclaim
-// the memory
+// Don't need deallocator since the TensorFlow heap will automatically
+// reclaim the memory
}
port::Status MIOpenSupport::DoPrepareForConvolution(
@ -2835,9 +2998,9 @@ port::Status MIOpenSupport::DoConvolve(
if (!timer->Init()) {
return port::Status(port::error::INTERNAL, "Failed to init timer");
}
-// The start and stop of the timer should be as close to the MIOpen call as
-// possible. It is still possible for other threads to issue workload on
-// to this stream. So it could take multiple profiling measurements.
+// The start and stop of the timer should be as close to the MIOpen call
+// as possible. It is still possible for other threads to issue workload
+// on to this stream. So it could take multiple profiling measurements.
if (!timer->Start(AsGpuStream(stream))) {
timer->Destroy();
return port::Status(port::error::INTERNAL, "Failed to start timer");
@ -2971,9 +3134,9 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithms(
miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
input_nd.handle(), &maxSolutionCount);
if (status != miopenStatusSuccess) {
-LOG(FATAL)
-<< "call to miopenConvolutionBackwardDataGetSolutionCount failed: "
-<< ToString(status);
+LOG(FATAL) << "call to miopenConvolutionBackwardDataGetSolutionCount "
+"failed: "
+<< ToString(status);
return false;
}
break;
@ -4143,8 +4306,8 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions(
return false;
}
} else {
-LOG(ERROR)
-<< "Failed to calculate tensor size to chain forward and backward LRN";
+LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
+"backward LRN";
}
status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha,

View File: tensorflow/stream_executor/rocm/rocm_dnn.h

@ -33,6 +33,8 @@ class GpuExecutor;
class MIOpenRnnDescriptor;
class MIOpenRnnSequenceTensorDescriptor;
class MIOpenRnnStateTensorDescriptor;
+class MIOpenCTCLossDescriptor;
// Opaque and unique identifier for the MIOpen plugin.
extern const PluginId kMIOpenPlugin;
@ -636,6 +638,17 @@ class MIOpenSupport : public dnn::DnnSupport {
GpuExecutor* GetParentExecutor() { return parent_; }
+port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
+const dnn::RnnStateTensorDescriptor& probs_desc,
+const DeviceMemoryBase probs_data,
+absl::Span<const int> labels_data,
+absl::Span<const int> labels_lengths_data,
+absl::Span<const int> input_lengths_data,
+DeviceMemoryBase costs_data,
+const dnn::RnnStateTensorDescriptor& grads_desc,
+DeviceMemoryBase grads_data,
+DeviceMemory<uint8> scratch_memory) override;
private:
GpuExecutor* parent_; // Parent executor object. Not owned.
@ -784,6 +797,25 @@ class MIOpenSupport : public dnn::DnnSupport {
ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
DeviceMemory<uint8>* scratch_memory) override;
+port::Status DoCtcLossImpl(
+Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
+const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
+absl::Span<const int> labels_lengths_data,
+absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+const MIOpenRnnStateTensorDescriptor& grads_desc,
+DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
+DeviceMemory<uint8> scratch_memory);
+port::Status DoPrepareForCtcLoss(
+Stream* stream, dnn::DataType element_type,
+const dnn::RnnStateTensorDescriptor& probs_desc,
+const dnn::RnnStateTensorDescriptor& grads_desc,
+absl::Span<const int> labels_data,
+absl::Span<const int> labels_lengths_data,
+absl::Span<const int> input_lengths_data,
+ScratchAllocator* scratch_allocator,
+DeviceMemory<uint8>* scratch_memory) override;
SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
};