[ROCm] Adding ROCm support for CTC Loss
This commit is contained in:
parent
e1a66cff30
commit
d7dbb77bd6
@ -2364,9 +2364,10 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
|
"//tensorflow/core/util/ctc:ctc_beam_search_lib",
|
||||||
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
|
"//tensorflow/core/util/ctc:ctc_loss_calculator_lib",
|
||||||
] + if_cuda([
|
] + if_cuda_or_rocm([
|
||||||
":gpu_utils",
|
":gpu_utils",
|
||||||
":conv_ops_gpu_hdrs",
|
":conv_ops_gpu_hdrs",
|
||||||
|
]) + if_cuda([
|
||||||
"@local_config_cuda//cuda:cudnn_header",
|
"@local_config_cuda//cuda:cudnn_header",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
@ -15,9 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
// See docs in ../ops/ctc_ops.cc.
|
// See docs in ../ops/ctc_ops.cc.
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
@ -31,15 +31,18 @@ limitations under the License.
|
|||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/gpus/cudnn/cudnn.h"
|
#include "third_party/gpus/cudnn/cudnn.h"
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||||
#include "tensorflow/core/util/stream_executor_util.h"
|
#include "tensorflow/core/util/stream_executor_util.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
using GPUDevice = Eigen::GpuDevice;
|
using GPUDevice = Eigen::GpuDevice;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -60,7 +63,7 @@ void DoHistogram(OpKernelContext* ctx, const Tensor* labels_indices,
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace
|
} // end namespace
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class CTCLossOp : public OpKernel {
|
class CTCLossOp : public OpKernel {
|
||||||
@ -221,7 +224,7 @@ REGISTER_CPU(double);
|
|||||||
|
|
||||||
#undef REGISTER_CPU
|
#undef REGISTER_CPU
|
||||||
|
|
||||||
#if GOOGLE_CUDA && CUDNN_VERSION >= 7603
|
#if ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
|
||||||
class CTCLossOpGPU : public OpKernel {
|
class CTCLossOpGPU : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
@ -366,5 +369,5 @@ REGISTER_KERNEL_BUILDER(Name("CTCLossV2")
|
|||||||
.HostMemory("labels_values")
|
.HostMemory("labels_values")
|
||||||
.HostMemory("sequence_length"),
|
.HostMemory("sequence_length"),
|
||||||
CTCLossOpGPU);
|
CTCLossOpGPU);
|
||||||
#endif // GOOGLE_CUDA && CUDNN_VERSION >= 7603
|
#endif // ((GOOGLE_CUDA && CUDNN_VERSION >= 7603) || TENSORFLOW_USE_ROCM)
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -280,7 +280,12 @@ namespace wrap {
|
|||||||
__macro(miopenConvolutionBackwardWeightsGetSolution) \
|
__macro(miopenConvolutionBackwardWeightsGetSolution) \
|
||||||
__macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \
|
__macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \
|
||||||
__macro(miopenConvolutionBackwardWeightsCompileSolution) \
|
__macro(miopenConvolutionBackwardWeightsCompileSolution) \
|
||||||
__macro(miopenConvolutionBackwardWeightsImmediate)
|
__macro(miopenConvolutionBackwardWeightsImmediate) \
|
||||||
|
__macro(miopenCreateCTCLossDescriptor) \
|
||||||
|
__macro(miopenSetCTCLossDescriptor) \
|
||||||
|
__macro(miopenGetCTCLossWorkspaceSize) \
|
||||||
|
__macro(miopenCTCLoss) \
|
||||||
|
__macro(miopenDestroyCTCLossDescriptor)
|
||||||
|
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
@ -2319,6 +2324,164 @@ MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class MIOpenCTCLossDescriptor {
|
||||||
|
public:
|
||||||
|
explicit MIOpenCTCLossDescriptor(miopenDataType_t data_type) {
|
||||||
|
auto status = wrap::miopenCreateCTCLossDescriptor(&handle_);
|
||||||
|
if (status != miopenStatusSuccess) {
|
||||||
|
LOG(FATAL) << "call to miopenCreateCTCLossDescriptor failed: "
|
||||||
|
<< ToString(status);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool apply_softmax_layer = true;
|
||||||
|
status = wrap::miopenSetCTCLossDescriptor(handle_, data_type, 0,
|
||||||
|
apply_softmax_layer);
|
||||||
|
if (status != miopenStatusSuccess) {
|
||||||
|
LOG(FATAL) << "call to miopenSetCTCLossDescriptor failed: "
|
||||||
|
<< ToString(status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~MIOpenCTCLossDescriptor() {
|
||||||
|
auto status = wrap::miopenDestroyCTCLossDescriptor(handle_);
|
||||||
|
if (status != miopenStatusSuccess) {
|
||||||
|
LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: "
|
||||||
|
<< ToString(status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
miopenCTCLossDescriptor_t handle() const { return handle_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
miopenCTCLossDescriptor_t handle_; // Owned
|
||||||
|
|
||||||
|
SE_DISALLOW_COPY_AND_ASSIGN(MIOpenCTCLossDescriptor);
|
||||||
|
};
|
||||||
|
|
||||||
|
port::Status MIOpenSupport::DoPrepareForCtcLoss(
|
||||||
|
Stream* stream, dnn::DataType element_type,
|
||||||
|
const dnn::RnnStateTensorDescriptor& probs_desc,
|
||||||
|
const dnn::RnnStateTensorDescriptor& grads_desc,
|
||||||
|
absl::Span<const int> labels_data,
|
||||||
|
absl::Span<const int> labels_lengths_data,
|
||||||
|
absl::Span<const int> input_lengths_data,
|
||||||
|
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory) {
|
||||||
|
auto miopen = miopen_->GetHandle(parent_, stream);
|
||||||
|
|
||||||
|
MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
|
||||||
|
|
||||||
|
// Query the workspace size.
|
||||||
|
size_t workspace_size_in_bytes = 0;
|
||||||
|
|
||||||
|
const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
|
||||||
|
static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
|
||||||
|
|
||||||
|
const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
|
||||||
|
static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
|
||||||
|
|
||||||
|
auto status = wrap::miopenGetCTCLossWorkspaceSize(
|
||||||
|
miopen.handle(), miopen_probs_desc.handle(), miopen_grads_desc.handle(),
|
||||||
|
labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
|
||||||
|
MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, miopen_ctc_loss_desc.handle(),
|
||||||
|
&workspace_size_in_bytes);
|
||||||
|
|
||||||
|
if (status != miopenStatusSuccess) {
|
||||||
|
LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: "
|
||||||
|
<< ToString(status);
|
||||||
|
return port::InternalError(
|
||||||
|
"Failed to determine scratch memory size for MIOpen CTC Loss");
|
||||||
|
}
|
||||||
|
|
||||||
|
*scratch_memory = DeviceMemory<uint8>();
|
||||||
|
|
||||||
|
// Allocate the workspace.
|
||||||
|
if (workspace_size_in_bytes != 0) {
|
||||||
|
if (scratch_allocator == nullptr) {
|
||||||
|
return port::InternalError(
|
||||||
|
absl::StrCat("An allocator must be specified when scratch memory is "
|
||||||
|
"needed"));
|
||||||
|
}
|
||||||
|
auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes);
|
||||||
|
if (scratch_or.ok()) {
|
||||||
|
*scratch_memory = scratch_or.ValueOrDie();
|
||||||
|
} else {
|
||||||
|
LOG(ERROR)
|
||||||
|
<< "Failed to allocate scratch memory - "
|
||||||
|
<< scratch_or.status().error_message() << "\n"
|
||||||
|
<< "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
|
||||||
|
"larger number (e.g. 8192) to increase the max memory limit.\n"
|
||||||
|
<< "\tIncreasing the max memory limit might help resolve this "
|
||||||
|
"error";
|
||||||
|
return port::InternalError(absl::StrCat(
|
||||||
|
"Failed to allocate scratch memory for MIOpen CTC Loss, of size: ",
|
||||||
|
workspace_size_in_bytes));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return port::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
port::Status MIOpenSupport::DoCtcLossImpl(
|
||||||
|
Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
|
||||||
|
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
|
||||||
|
absl::Span<const int> labels_lengths_data,
|
||||||
|
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
|
||||||
|
const MIOpenRnnStateTensorDescriptor& grads_desc,
|
||||||
|
DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
|
||||||
|
DeviceMemory<uint8> scratch_memory) {
|
||||||
|
auto miopen = miopen_->GetHandle(parent_, stream);
|
||||||
|
|
||||||
|
int kNumTimestamps = probs_desc.num_layers();
|
||||||
|
int kBatchSize = probs_desc.batch_size();
|
||||||
|
int kNumLabels = probs_desc.data_size();
|
||||||
|
int total_size = kNumLabels * kNumTimestamps * kBatchSize;
|
||||||
|
(void)total_size;
|
||||||
|
|
||||||
|
auto status = wrap::miopenCTCLoss(
|
||||||
|
miopen.handle(), probs_desc.handle(), probs_data.opaque(),
|
||||||
|
labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
|
||||||
|
costs_data.opaque(), grads_desc.handle(), grads_data.opaque(),
|
||||||
|
MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_loss_desc.handle(),
|
||||||
|
scratch_memory.opaque(), scratch_memory.size());
|
||||||
|
if (status != miopenStatusSuccess) {
|
||||||
|
LOG(FATAL) << "call to miopenCTCLoss failed: " << ToString(status);
|
||||||
|
return port::InternalError("Failure during MIOpen CTC Loss");
|
||||||
|
}
|
||||||
|
|
||||||
|
return port::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
port::Status MIOpenSupport::DoCtcLoss(
|
||||||
|
Stream* stream, dnn::DataType element_type,
|
||||||
|
const dnn::RnnStateTensorDescriptor& probs_desc,
|
||||||
|
const DeviceMemoryBase probs_data,
|
||||||
|
|
||||||
|
absl::Span<const int> labels_data,
|
||||||
|
absl::Span<const int> labels_lengths_data,
|
||||||
|
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
|
||||||
|
const dnn::RnnStateTensorDescriptor& grads_desc,
|
||||||
|
DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory) {
|
||||||
|
// Current MIOPen CTC Loss only supports the float datatype
|
||||||
|
if (element_type != dnn::DataType::kFloat) {
|
||||||
|
return port::Status(port::error::INVALID_ARGUMENT,
|
||||||
|
"MIOpenCTCLossDescriptor is supported only when the "
|
||||||
|
"DataType is float");
|
||||||
|
}
|
||||||
|
|
||||||
|
MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
|
||||||
|
|
||||||
|
const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
|
||||||
|
static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
|
||||||
|
|
||||||
|
const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
|
||||||
|
static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
|
||||||
|
|
||||||
|
return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
|
||||||
|
labels_lengths_data, input_lengths_data, costs_data,
|
||||||
|
miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
|
||||||
|
scratch_memory);
|
||||||
|
}
|
||||||
|
|
||||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
||||||
MIOpenSupport::createRnnDescriptor(
|
MIOpenSupport::createRnnDescriptor(
|
||||||
int num_layers, int hidden_size, int input_size, int cell_size,
|
int num_layers, int hidden_size, int input_size, int cell_size,
|
||||||
@ -2633,8 +2796,8 @@ void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void MIOpenDeallocatorCallback(void* ctx, void* mem) {
|
void MIOpenDeallocatorCallback(void* ctx, void* mem) {
|
||||||
// Don't need deallocator since the TensorFlow heap will automatically reclaim
|
// Don't need deallocator since the TensorFlow heap will automatically
|
||||||
// the memory
|
// reclaim the memory
|
||||||
}
|
}
|
||||||
|
|
||||||
port::Status MIOpenSupport::DoPrepareForConvolution(
|
port::Status MIOpenSupport::DoPrepareForConvolution(
|
||||||
@ -2835,9 +2998,9 @@ port::Status MIOpenSupport::DoConvolve(
|
|||||||
if (!timer->Init()) {
|
if (!timer->Init()) {
|
||||||
return port::Status(port::error::INTERNAL, "Failed to init timer");
|
return port::Status(port::error::INTERNAL, "Failed to init timer");
|
||||||
}
|
}
|
||||||
// The start and stop of the timer should be as close to the MIOpen call as
|
// The start and stop of the timer should be as close to the MIOpen call
|
||||||
// possible. It is still possible for other threads to issue workload on
|
// as possible. It is still possible for other threads to issue workload
|
||||||
// to this stream. So it could take multiple profiling measurements.
|
// on to this stream. So it could take multiple profiling measurements.
|
||||||
if (!timer->Start(AsGpuStream(stream))) {
|
if (!timer->Start(AsGpuStream(stream))) {
|
||||||
timer->Destroy();
|
timer->Destroy();
|
||||||
return port::Status(port::error::INTERNAL, "Failed to start timer");
|
return port::Status(port::error::INTERNAL, "Failed to start timer");
|
||||||
@ -2971,9 +3134,9 @@ bool MIOpenSupport::GetMIOpenConvolveAlgorithms(
|
|||||||
miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
|
miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
|
||||||
input_nd.handle(), &maxSolutionCount);
|
input_nd.handle(), &maxSolutionCount);
|
||||||
if (status != miopenStatusSuccess) {
|
if (status != miopenStatusSuccess) {
|
||||||
LOG(FATAL)
|
LOG(FATAL) << "call to miopenConvolutionBackwardDataGetSolutionCount "
|
||||||
<< "call to miopenConvolutionBackwardDataGetSolutionCount failed: "
|
"failed: "
|
||||||
<< ToString(status);
|
<< ToString(status);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@ -4143,8 +4306,8 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
LOG(ERROR)
|
LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
|
||||||
<< "Failed to calculate tensor size to chain forward and backward LRN";
|
"backward LRN";
|
||||||
}
|
}
|
||||||
|
|
||||||
status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha,
|
status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha,
|
||||||
|
@ -33,6 +33,8 @@ class GpuExecutor;
|
|||||||
class MIOpenRnnDescriptor;
|
class MIOpenRnnDescriptor;
|
||||||
class MIOpenRnnSequenceTensorDescriptor;
|
class MIOpenRnnSequenceTensorDescriptor;
|
||||||
class MIOpenRnnStateTensorDescriptor;
|
class MIOpenRnnStateTensorDescriptor;
|
||||||
|
class MIOpenCTCLossDescriptor;
|
||||||
|
|
||||||
// Opaque and unique identifier for the MIOpen plugin.
|
// Opaque and unique identifier for the MIOpen plugin.
|
||||||
extern const PluginId kMIOpenPlugin;
|
extern const PluginId kMIOpenPlugin;
|
||||||
|
|
||||||
@ -636,6 +638,17 @@ class MIOpenSupport : public dnn::DnnSupport {
|
|||||||
|
|
||||||
GpuExecutor* GetParentExecutor() { return parent_; }
|
GpuExecutor* GetParentExecutor() { return parent_; }
|
||||||
|
|
||||||
|
port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
|
||||||
|
const dnn::RnnStateTensorDescriptor& probs_desc,
|
||||||
|
const DeviceMemoryBase probs_data,
|
||||||
|
absl::Span<const int> labels_data,
|
||||||
|
absl::Span<const int> labels_lengths_data,
|
||||||
|
absl::Span<const int> input_lengths_data,
|
||||||
|
DeviceMemoryBase costs_data,
|
||||||
|
const dnn::RnnStateTensorDescriptor& grads_desc,
|
||||||
|
DeviceMemoryBase grads_data,
|
||||||
|
DeviceMemory<uint8> scratch_memory) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GpuExecutor* parent_; // Parent executor object. Not owned.
|
GpuExecutor* parent_; // Parent executor object. Not owned.
|
||||||
|
|
||||||
@ -784,6 +797,25 @@ class MIOpenSupport : public dnn::DnnSupport {
|
|||||||
ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
|
ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
|
||||||
DeviceMemory<uint8>* scratch_memory) override;
|
DeviceMemory<uint8>* scratch_memory) override;
|
||||||
|
|
||||||
|
port::Status DoCtcLossImpl(
|
||||||
|
Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
|
||||||
|
const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
|
||||||
|
absl::Span<const int> labels_lengths_data,
|
||||||
|
absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
|
||||||
|
const MIOpenRnnStateTensorDescriptor& grads_desc,
|
||||||
|
DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
|
||||||
|
DeviceMemory<uint8> scratch_memory);
|
||||||
|
|
||||||
|
port::Status DoPrepareForCtcLoss(
|
||||||
|
Stream* stream, dnn::DataType element_type,
|
||||||
|
const dnn::RnnStateTensorDescriptor& probs_desc,
|
||||||
|
const dnn::RnnStateTensorDescriptor& grads_desc,
|
||||||
|
absl::Span<const int> labels_data,
|
||||||
|
absl::Span<const int> labels_lengths_data,
|
||||||
|
absl::Span<const int> input_lengths_data,
|
||||||
|
ScratchAllocator* scratch_allocator,
|
||||||
|
DeviceMemory<uint8>* scratch_memory) override;
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
|
SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user