[ROCm] Adding ROCm support for CTC Loss
parent e1a66cff30
commit d7dbb77bd6
tensorflow/core/kernels/BUILD
@@ -2364,9 +2364,10 @@ tf_kernel_library(
         "//tensorflow/core:lib",
         "//tensorflow/core/util/ctc:ctc_beam_search_lib",
         "//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
-    ] + if_cuda([
+    ] + if_cuda_or_rocm([
         ":gpu_utils",
         ":conv_ops_gpu_hdrs",
+    ]) + if_cuda([
         "@local_config_cuda//cuda:cudnn_header",
     ]),
 )
tensorflow/core/kernels/ctc_loss_op.cc
@@ -15,9 +15,9 @@ limitations under the License.
 
 // See docs in ../ops/ctc_ops.cc.
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op.h"
@@ -31,15 +31,18 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 #include "third_party/gpus/cudnn/cudnn.h"
+#endif  // GOOGLE_CUDA
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
 #include "tensorflow/core/util/stream_executor_util.h"
 #include "tensorflow/core/util/tensor_format.h"
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 using GPUDevice = Eigen::GpuDevice;
 
 namespace {
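The hunks above follow the standard TensorFlow convention for multi-backend GPU guards: shared GPU code sits under the combined macro, while cuDNN-only pieces keep the narrower CUDA guard. A minimal, self-contained sketch of how the three build configurations select code (the demo program itself is illustrative, not from the commit; GOOGLE_CUDA and TENSORFLOW_USE_ROCM are the real macro names):

// guard_demo.cc - build with one of:
//   g++ -DGOOGLE_CUDA=1 guard_demo.cc            (CUDA-style build)
//   g++ -DTENSORFLOW_USE_ROCM=1 guard_demo.cc    (ROCm-style build)
//   g++ guard_demo.cc                            (CPU-only build)
#include <cstdio>

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Code shared by both GPU backends goes under the combined guard.
static const char* kDevice = "GPU (CUDA or ROCm)";
#else
static const char* kDevice = "CPU";
#endif

#if GOOGLE_CUDA
// CUDA-only pieces (e.g. cuDNN includes) keep the narrower guard.
static const char* kDnnLib = "cuDNN";
#elif TENSORFLOW_USE_ROCM
static const char* kDnnLib = "MIOpen";
#else
static const char* kDnnLib = "none";
#endif

int main() {
  std::printf("device: %s, dnn library: %s\n", kDevice, kDnnLib);
  return 0;
}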
@@ -60,7 +63,7 @@ void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
 }
 
 }  // end namespace
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 template <typename T>
 class CTCLossOp : public OpKernel {
@@ -221,7 +224,7 @@ REGISTER_CPU(double);
 
 #undef REGISTER_CPU
 
-#if GOOGLE_CUDA && CUDNN_VERSION >= 7603
+#if ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
 class CTCLossOpGPU : public OpKernel {
  public:
   explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -366,5 +369,5 @@ REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
                             .HostMemory("labels_values")
                             .HostMemory("sequence_length"),
                         CTCLossOpGPU);
-#endif  // GOOGLE_CUDA && CUDNN_VERSION >= 7603
+#endif  // ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
 }  // end namespace tensorflow
tensorflow/stream_executor/rocm/rocm_dnn.cc
@@ -280,7 +280,12 @@ namespace wrap {
   __macro(miopenConvolutionBackwardWeightsGetSolution)              \
   __macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \
   __macro(miopenConvolutionBackwardWeightsCompileSolution)          \
-  __macro(miopenConvolutionBackwardWeightsImmediate)
+  __macro(miopenConvolutionBackwardWeightsImmediate)                \
+  __macro(miopenCreateCTCLossDescriptor)                            \
+  __macro(miopenSetCTCLossDescriptor)                               \
+  __macro(miopenGetCTCLossWorkspaceSize)                            \
+  __macro(miopenCTCLoss)                                            \
+  __macro(miopenDestroyCTCLossDescriptor)
 
 // clang-format on
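The `wrap` namespace resolves each listed MIOpen entry point from the shared library at runtime, so adding the five CTC symbols to the macro list is all that is needed to call them below. A rough, self-contained sketch of that lazy-resolution pattern (it wraps libc's strlen purely for illustration; TensorFlow's real wrapper additionally caches the library handle and reports load failures):

// wrap_demo.cc - sketch of a dlsym-based wrapper like the __macro list
// generates per symbol. Build with: g++ -std=c++14 wrap_demo.cc -ldl
#include <dlfcn.h>
#include <string.h>
#include <cstdio>

#define WRAP_SYMBOL(name)                                         \
  struct Wrap_##name {                                            \
    template <typename... Args>                                   \
    auto operator()(Args... args) {                               \
      using FuncPtr = decltype(&::name);                          \
      /* Resolve once, on first call, from the loaded library. */ \
      static FuncPtr fn =                                         \
          reinterpret_cast<FuncPtr>(dlsym(RTLD_DEFAULT, #name));  \
      return fn(args...);                                         \
    }                                                             \
  } wrapped_##name;

WRAP_SYMBOL(strlen)  // generates a callable `wrapped_strlen`

int main() {
  std::printf("strlen via dlsym: %zu\n", wrapped_strlen("miopenCTCLoss"));
  return 0;
}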
@@ -2319,6 +2324,164 @@ MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor(
   }
 }
 
+class MIOpenCTCLossDescriptor {
+ public:
+  explicit MIOpenCTCLossDescriptor(miopenDataType_t data_type) {
+    auto status = wrap::miopenCreateCTCLossDescriptor(&handle_);
+    if (status != miopenStatusSuccess) {
+      LOG(FATAL) << "call to miopenCreateCTCLossDescriptor failed: "
+                 << ToString(status);
+    }
+
+    bool apply_softmax_layer = true;
+    status = wrap::miopenSetCTCLossDescriptor(handle_, data_type, 0,
+                                              apply_softmax_layer);
+    if (status != miopenStatusSuccess) {
+      LOG(FATAL) << "call to miopenSetCTCLossDescriptor failed: "
+                 << ToString(status);
+    }
+  }
+
+  ~MIOpenCTCLossDescriptor() {
+    auto status = wrap::miopenDestroyCTCLossDescriptor(handle_);
+    if (status != miopenStatusSuccess) {
+      LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: "
+                 << ToString(status);
+    }
+  }
+
+  miopenCTCLossDescriptor_t handle() const { return handle_; }
+
+ private:
+  miopenCTCLossDescriptor_t handle_;  // Owned
+
+  SE_DISALLOW_COPY_AND_ASSIGN(MIOpenCTCLossDescriptor);
+};
+
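A design note on the class above: it is the usual RAII wrapper over a create/destroy C API pair, mirroring the neighboring MIOpenRnn* descriptors. For comparison, the same lifetime guarantee can be expressed with std::unique_ptr and a custom deleter, as in this self-contained sketch (the Fake* API is a stand-in, not MIOpen):

// raii_demo.cc - alternative shape for a create/destroy handle wrapper.
#include <cstdio>
#include <memory>

// Stand-in C-style API mirroring miopenCreate/DestroyCTCLossDescriptor.
struct FakeDesc { int id; };
int FakeCreate(FakeDesc** d) { *d = new FakeDesc{42}; return 0; }
int FakeDestroy(FakeDesc* d) { delete d; return 0; }

// The custom deleter gives the same "destroy exactly once" guarantee as the
// hand-written destructor in MIOpenCTCLossDescriptor.
using DescPtr = std::unique_ptr<FakeDesc, int (*)(FakeDesc*)>;

DescPtr MakeDesc() {
  FakeDesc* raw = nullptr;
  if (FakeCreate(&raw) != 0) return DescPtr(nullptr, &FakeDestroy);
  return DescPtr(raw, &FakeDestroy);
}

int main() {
  DescPtr desc = MakeDesc();
  std::printf("descriptor id = %d\n", desc ? desc->id : -1);
  return 0;  // FakeDestroy runs automatically when desc goes out of scope
}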
+port::Status MIOpenSupport::DoPrepareForCtcLoss(
+    Stream* stream, dnn::DataType element_type,
+    const dnn::RnnStateTensorDescriptor& probs_desc,
+    const dnn::RnnStateTensorDescriptor& grads_desc,
+    absl::Span<const int> labels_data,
+    absl::Span<const int> labels_lengths_data,
+    absl::Span<const int> input_lengths_data,
+    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
+  auto miopen = miopen_->GetHandle(parent_, stream);
+
+  MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
+
+  // Query the workspace size.
+  size_t workspace_size_in_bytes = 0;
+
+  const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
+
+  const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
+
+  auto status = wrap::miopenGetCTCLossWorkspaceSize(
+      miopen.handle(), miopen_probs_desc.handle(), miopen_grads_desc.handle(),
+      labels_data.data(), labels_lengths_data.data(),
+      input_lengths_data.data(), MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC,
+      miopen_ctc_loss_desc.handle(), &workspace_size_in_bytes);
+
+  if (status != miopenStatusSuccess) {
+    LOG(ERROR) << "call to miopenGetCTCLossWorkspaceSize failed: "
+               << ToString(status);
+    return port::InternalError(
+        "Failed to determine scratch memory size for MIOpen CTC Loss");
+  }
+
+  *scratch_memory = DeviceMemory<uint8>();
+
+  // Allocate the workspace.
+  if (workspace_size_in_bytes != 0) {
+    if (scratch_allocator == nullptr) {
+      return port::InternalError(
+          "An allocator must be specified when scratch memory is needed");
+    }
+    auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes);
+    if (scratch_or.ok()) {
+      *scratch_memory = scratch_or.ValueOrDie();
+    } else {
+      LOG(ERROR)
+          << "Failed to allocate scratch memory - "
+          << scratch_or.status().error_message() << "\n"
+          << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
+             "larger number (e.g. 8192) to increase the max memory limit.\n"
+          << "\tIncreasing the max memory limit might help resolve this "
+             "error";
+      return port::InternalError(absl::StrCat(
+          "Failed to allocate scratch memory for MIOpen CTC Loss, of size: ",
+          workspace_size_in_bytes));
+    }
+  }
+
+  return port::Status::OK();
+}
+
+port::Status MIOpenSupport::DoCtcLossImpl(
+    Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
+    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
+    absl::Span<const int> labels_lengths_data,
+    absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+    const MIOpenRnnStateTensorDescriptor& grads_desc,
+    DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
+    DeviceMemory<uint8> scratch_memory) {
+  auto miopen = miopen_->GetHandle(parent_, stream);
+
+  int kNumTimestamps = probs_desc.num_layers();
+  int kBatchSize = probs_desc.batch_size();
+  int kNumLabels = probs_desc.data_size();
+  int total_size = kNumLabels * kNumTimestamps * kBatchSize;
+  (void)total_size;
+
+  auto status = wrap::miopenCTCLoss(
+      miopen.handle(), probs_desc.handle(), probs_data.opaque(),
+      labels_data.data(), labels_lengths_data.data(),
+      input_lengths_data.data(), costs_data.opaque(), grads_desc.handle(),
+      grads_data.opaque(), MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC,
+      ctc_loss_desc.handle(), scratch_memory.opaque(), scratch_memory.size());
+  if (status != miopenStatusSuccess) {
+    LOG(ERROR) << "call to miopenCTCLoss failed: " << ToString(status);
+    return port::InternalError("Failure during MIOpen CTC Loss");
+  }
+
+  return port::Status::OK();
+}
+
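For reference (standard definition from Graves et al., 2006; not part of the commit): the quantity miopenCTCLoss writes into costs_data is the CTC negative log-likelihood of the label sequence over all blank-augmented alignments,

\mathcal{L}(\mathbf{x}, \mathbf{l}) = -\log p(\mathbf{l} \mid \mathbf{x})
  = -\log \sum_{\pi \in \mathcal{B}^{-1}(\mathbf{l})} \prod_{t=1}^{T} y^{t}_{\pi_t}

where y^t_k is the post-softmax probability of symbol k at time t (MIOpen applies the softmax itself here, per apply_softmax_layer = true above) and \mathcal{B} collapses repeated symbols and removes blanks. The gradient of this loss with respect to the input activations is what the call deposits in grads_data.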
+port::Status MIOpenSupport::DoCtcLoss(
+    Stream* stream, dnn::DataType element_type,
+    const dnn::RnnStateTensorDescriptor& probs_desc,
+    const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
+    absl::Span<const int> labels_lengths_data,
+    absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+    const dnn::RnnStateTensorDescriptor& grads_desc,
+    DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
+  // Current MIOpen CTC Loss only supports the float datatype.
+  if (element_type != dnn::DataType::kFloat) {
+    return port::Status(port::error::INVALID_ARGUMENT,
+                        "MIOpenCTCLossDescriptor is supported only when the "
+                        "DataType is float");
+  }
+
+  MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
+
+  const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
+
+  const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
+      static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
+
+  return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
+                       labels_lengths_data, input_lengths_data, costs_data,
+                       miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
+                       scratch_memory);
+}
+
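To make the computation concrete, here is a minimal CPU reference for the loss the GPU path computes: the CTC forward (alpha) recursion. This is an illustrative sketch, not TensorFlow or MIOpen code; it assumes inputs are already softmaxed, uses blank index 0, and skips the log-space scaling a production implementation needs:

// ctc_reference.cc - toy CTC forward pass; build with: g++ ctc_reference.cc
#include <cstdio>
#include <cmath>
#include <vector>

// probs[t][k]: probability of symbol k at time t. label: target, no blanks.
double CtcLoss(const std::vector<std::vector<double>>& probs,
               const std::vector<int>& label, int blank = 0) {
  const int T = static_cast<int>(probs.size());
  // Extended label l' = {blank, l1, blank, l2, ..., blank}.
  std::vector<int> ext;
  ext.push_back(blank);
  for (int c : label) { ext.push_back(c); ext.push_back(blank); }
  const int S = static_cast<int>(ext.size());

  std::vector<std::vector<double>> alpha(T, std::vector<double>(S, 0.0));
  alpha[0][0] = probs[0][ext[0]];               // start with blank
  if (S > 1) alpha[0][1] = probs[0][ext[1]];    // or with the first label

  for (int t = 1; t < T; ++t) {
    for (int s = 0; s < S; ++s) {
      double a = alpha[t - 1][s];
      if (s > 0) a += alpha[t - 1][s - 1];
      // Skipping over a blank is allowed only between distinct labels.
      if (s > 1 && ext[s] != blank && ext[s] != ext[s - 2])
        a += alpha[t - 1][s - 2];
      alpha[t][s] = a * probs[t][ext[s]];
    }
  }
  double p = alpha[T - 1][S - 1] + (S > 1 ? alpha[T - 1][S - 2] : 0.0);
  return -std::log(p);
}

int main() {
  // Two timesteps, alphabet {blank, 'a'}, target "a": p = 0.8, loss ~ 0.223.
  std::vector<std::vector<double>> probs = {{0.4, 0.6}, {0.5, 0.5}};
  std::printf("CTC loss = %f\n", CtcLoss(probs, {1}));
  return 0;
}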
 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
 MIOpenSupport::createRnnDescriptor(
     int num_layers, int hidden_size, int input_size, int cell_size,
@@ -2633,8 +2796,8 @@ void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) {
 }
 
 void MIOpenDeallocatorCallback(void* ctx, void* mem) {
-  // Don't need deallocator since the TensorFlow heap will automatically reclaim
-  // the memory
+  // Don't need deallocator since the TensorFlow heap will automatically
+  // reclaim the memory
 }
 
 port::Status MIOpenSupport::DoPrepareForConvolution(
@@ -2835,9 +2998,9 @@ port::Status MIOpenSupport::DoConvolve(
     if (!timer->Init()) {
       return port::Status(port::error::INTERNAL, "Failed to init timer");
     }
-    // The start and stop of the timer should be as close to the MIOpen call as
-    // possible. It is still possible for other threads to issue workload on
-    // to this stream. So it could take multiple profiling measurements.
+    // The start and stop of the timer should be as close to the MIOpen call
+    // as possible. It is still possible for other threads to issue workload
+    // on to this stream. So it could take multiple profiling measurements.
     if (!timer->Start(AsGpuStream(stream))) {
       timer->Destroy();
       return port::Status(port::error::INTERNAL, "Failed to start timer");
@@ -2971,9 +3134,9 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithms(
         miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
         input_nd.handle(), &maxSolutionCount);
     if (status != miopenStatusSuccess) {
-      LOG(FATAL)
-          << "call to miopenConvolutionBackwardDataGetSolutionCount failed: "
-          << ToString(status);
+      LOG(FATAL) << "call to miopenConvolutionBackwardDataGetSolutionCount "
+                    "failed: "
+                 << ToString(status);
       return false;
     }
     break;
@@ -4143,8 +4306,8 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions(
       return false;
     }
   } else {
-    LOG(ERROR)
-        << "Failed to calculate tensor size to chain forward and backward LRN";
+    LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
+                  "backward LRN";
   }
 
   status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha,
tensorflow/stream_executor/rocm/rocm_dnn.h
@@ -33,6 +33,8 @@ class GpuExecutor;
 class MIOpenRnnDescriptor;
 class MIOpenRnnSequenceTensorDescriptor;
 class MIOpenRnnStateTensorDescriptor;
+class MIOpenCTCLossDescriptor;
+
 // Opaque and unique identifier for the MIOpen plugin.
 extern const PluginId kMIOpenPlugin;
 
@@ -636,6 +638,17 @@ class MIOpenSupport : public dnn::DnnSupport {
 
   GpuExecutor* GetParentExecutor() { return parent_; }
 
+  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
+                         const dnn::RnnStateTensorDescriptor& probs_desc,
+                         const DeviceMemoryBase probs_data,
+                         absl::Span<const int> labels_data,
+                         absl::Span<const int> labels_lengths_data,
+                         absl::Span<const int> input_lengths_data,
+                         DeviceMemoryBase costs_data,
+                         const dnn::RnnStateTensorDescriptor& grads_desc,
+                         DeviceMemoryBase grads_data,
+                         DeviceMemory<uint8> scratch_memory) override;
+
  private:
   GpuExecutor* parent_;  // Parent executor object. Not owned.
 
@@ -784,6 +797,25 @@ class MIOpenSupport : public dnn::DnnSupport {
       ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
       DeviceMemory<uint8>* scratch_memory) override;
 
+  port::Status DoCtcLossImpl(
+      Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
+      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
+      const MIOpenRnnStateTensorDescriptor& grads_desc,
+      DeviceMemoryBase grads_data,
+      const MIOpenCTCLossDescriptor& ctc_loss_desc,
+      DeviceMemory<uint8> scratch_memory);
+
+  port::Status DoPrepareForCtcLoss(
+      Stream* stream, dnn::DataType element_type,
+      const dnn::RnnStateTensorDescriptor& probs_desc,
+      const dnn::RnnStateTensorDescriptor& grads_desc,
+      absl::Span<const int> labels_data,
+      absl::Span<const int> labels_lengths_data,
+      absl::Span<const int> input_lengths_data,
+      ScratchAllocator* scratch_allocator,
+      DeviceMemory<uint8>* scratch_memory) override;
+
   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
 };