Upgrade to CUDNN RNN v8 APIs
This commit is contained in:
parent
d22d31bebd
commit
610f79e3d3
tensorflow
core/kernels
stream_executor
@ -826,14 +826,33 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
|
||||
}
|
||||
|
||||
Stream* stream = context->op_device_context()->stream();
|
||||
|
||||
Tensor seq_lengths_tensor;
|
||||
DeviceMemory<int> seq_lengths_ptr;
|
||||
if (sequence_lengths != nullptr) {
|
||||
auto seq_lengths_vec = sequence_lengths->template flat<int>();
|
||||
TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32,
|
||||
{seq_lengths_vec.size()},
|
||||
&seq_lengths_tensor));
|
||||
seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
|
||||
if (!stream->ThenMemcpy(&seq_lengths_ptr,
|
||||
seq_lengths_vec.data(),
|
||||
seq_lengths_vec.size() * sizeof(int)).ok()) {
|
||||
return errors::InvalidArgument("Failed to copy memory from host to "
|
||||
"device for sequence_lengths in "
|
||||
"CudnnRNNV3");
|
||||
}
|
||||
}
|
||||
|
||||
bool launch_success =
|
||||
stream
|
||||
->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc,
|
||||
input_h_data, *c_state_desc, input_c_data,
|
||||
params_data, *output_desc, &output_data,
|
||||
*h_state_desc, &output_h_data, *c_state_desc,
|
||||
&output_c_data, is_training, reserve_space_allocator,
|
||||
workspace_allocator, output_profile_result)
|
||||
->ThenRnnForward(rnn_desc, *input_desc, input_data, seq_lengths_ptr,
|
||||
*h_state_desc, input_h_data, *c_state_desc,
|
||||
input_c_data, params_data, *output_desc,
|
||||
&output_data, *h_state_desc, &output_h_data,
|
||||
*c_state_desc, &output_c_data, is_training,
|
||||
reserve_space_allocator, workspace_allocator,
|
||||
output_profile_result)
|
||||
.ok();
|
||||
return launch_success
|
||||
? Status::OK()
|
||||
@ -905,17 +924,35 @@ Status DoBackward(
|
||||
// Creates a memory callback for the workspace. The memory lives to the end
|
||||
// of this kernel calls.
|
||||
Stream* stream = context->op_device_context()->stream();
|
||||
|
||||
Tensor seq_lengths_tensor;
|
||||
DeviceMemory<int> seq_lengths_ptr;
|
||||
if (sequence_lengths != nullptr) {
|
||||
auto seq_lengths_vec = sequence_lengths->template flat<int>();
|
||||
TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32,
|
||||
{seq_lengths_vec.size()},
|
||||
&seq_lengths_tensor));
|
||||
seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
|
||||
if (!stream->ThenMemcpy(&seq_lengths_ptr,
|
||||
seq_lengths_vec.data(),
|
||||
seq_lengths_vec.size() * sizeof(int)).ok()) {
|
||||
return errors::InvalidArgument("Failed to copy memory from host to "
|
||||
"device for sequence_lengths in "
|
||||
"CudnnRNNBackwardOpV3");
|
||||
}
|
||||
}
|
||||
|
||||
bool launch_success =
|
||||
stream
|
||||
->ThenRnnBackward(
|
||||
rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data,
|
||||
*c_state_desc, input_c_data, params_data, *output_desc,
|
||||
output_data, *h_state_desc, output_h_data, *c_state_desc,
|
||||
output_c_data, output_backprop_data, output_h_backprop_data,
|
||||
output_c_backprop_data, &input_backprop_data,
|
||||
&input_h_backprop_data, &input_c_backprop_data,
|
||||
¶ms_backprop_data, &reserve_space_uint8, workspace_allocator,
|
||||
output_profile_result)
|
||||
rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc,
|
||||
input_h_data, *c_state_desc, input_c_data, params_data,
|
||||
*output_desc, output_data, *h_state_desc, output_h_data,
|
||||
*c_state_desc, output_c_data, output_backprop_data,
|
||||
output_h_backprop_data, output_c_backprop_data,
|
||||
&input_backprop_data, &input_h_backprop_data,
|
||||
&input_c_backprop_data, ¶ms_backprop_data,
|
||||
&reserve_space_uint8, workspace_allocator, output_profile_result)
|
||||
.ok();
|
||||
return launch_success
|
||||
? Status::OK()
|
||||
|
@ -1181,6 +1181,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
}
|
||||
|
||||
// Create the params handle.
|
||||
// TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and
|
||||
// cudnnRNNForward***Ex are removed from the codebase, since the new API
|
||||
// doesn't need param descriptors any more.
|
||||
SE_ASSIGN_OR_RETURN(auto params_desc,
|
||||
CudnnRnnParamsDescriptor::Create(
|
||||
cudnn, input_size, data_type, rnn_desc.get(),
|
||||
@ -1653,10 +1656,16 @@ port::Status CheckRNNParameterSize(
|
||||
const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
|
||||
const CudnnRnnSequenceTensorDescriptor& input_desc) {
|
||||
size_t params_size_in_bytes = 0;
|
||||
#if CUDNN_VERSION >= 8000
|
||||
RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
|
||||
/*sizeInBytes=*/¶ms_size_in_bytes));
|
||||
#else
|
||||
RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
|
||||
/*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes,
|
||||
/*dataType=*/rnn_desc.data_type()));
|
||||
#endif
|
||||
if (static_cast<int64>(params_size_in_bytes) !=
|
||||
rnn_desc.ParamsSizeInBytes()) {
|
||||
return port::Status(port::error::INVALID_ARGUMENT,
|
||||
@ -1741,6 +1750,7 @@ port::Status CudnnSupport::DoRnnForwardImpl(
|
||||
Stream* stream, const CudnnRnnDescriptor& rnn_desc,
|
||||
const CudnnRnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<T>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const CudnnRnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<T>& input_h_data,
|
||||
const CudnnRnnStateTensorDescriptor& input_c_desc,
|
||||
@ -1764,6 +1774,78 @@ port::Status CudnnSupport::DoRnnForwardImpl(
|
||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||
|
||||
SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
|
||||
|
||||
// In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been
|
||||
// deprecated. Instead, we use the cudnnRNNForward which requires the
|
||||
// sequence_lengths parameter.
|
||||
#if CUDNN_VERSION >= 8000
|
||||
if (input_desc.is_var_seq_lengths()) {
|
||||
DeviceMemory<uint8> workspace;
|
||||
DeviceMemory<uint8> reserve_space;
|
||||
cudnnForwardMode_t rnn_fwd_mode;
|
||||
if (is_training) {
|
||||
rnn_fwd_mode = CUDNN_FWD_MODE_TRAINING;
|
||||
} else {
|
||||
rnn_fwd_mode = CUDNN_FWD_MODE_INFERENCE;
|
||||
}
|
||||
size_t reserve_space_size_in_bytes = 0;
|
||||
size_t workspace_size_in_bytes = 0;
|
||||
RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
|
||||
/*fMode=*/rnn_fwd_mode, /*xDesc=*/input_desc.data_handle(),
|
||||
/*workSpaceSize=*/&workspace_size_in_bytes,
|
||||
/*reserveSpaceSize=*/&reserve_space_size_in_bytes));
|
||||
|
||||
if (workspace_size_in_bytes > 0) {
|
||||
SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
|
||||
workspace_size_in_bytes));
|
||||
}
|
||||
if (reserve_space_size_in_bytes > 0) {
|
||||
SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
|
||||
reserve_space_size_in_bytes));
|
||||
}
|
||||
|
||||
std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
|
||||
const bool is_profiling = output_profile_result != nullptr;
|
||||
if (is_profiling) {
|
||||
timer.reset(new GpuTimer(parent_));
|
||||
// The start and stop of the timer should be as close to the Cudnn call as
|
||||
// possible. It is still possible for other threads to issue workload on
|
||||
// to this stream. So it could take multiple profiling measurements.
|
||||
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
|
||||
return port::Status(port::error::INTERNAL, "Failed to start timer");
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_CUDNN_ERROR(cudnnRNNForward(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
|
||||
/*fwdMode=*/rnn_fwd_mode,
|
||||
/*devSeqLengths=*/reinterpret_cast<const int*>(
|
||||
seq_lengths_data.opaque()),
|
||||
/*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
|
||||
/*yDesc=*/output_desc.data_handle(), /*y=*/output_data->opaque(),
|
||||
/*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
|
||||
/*hy=*/output_h_data->opaque(),
|
||||
/*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
|
||||
/*cy=*/output_c_data->opaque(),
|
||||
/*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
|
||||
/*weightSpace=*/params.opaque(),
|
||||
/*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(),
|
||||
/*reserveSpaceSizeInBytes=*/reserve_space.size(),
|
||||
/*reserveSpace=*/reserve_space.opaque()));
|
||||
|
||||
if (is_profiling) {
|
||||
if (!timer->Stop(AsGpuStream(stream))) {
|
||||
return port::Status(port::error::INTERNAL, "Failed to stop timer");
|
||||
}
|
||||
auto algo_desc = *rnn_desc.algorithm_config().algorithm();
|
||||
output_profile_result->set_algorithm(algo_desc);
|
||||
output_profile_result->set_elapsed_time_in_ms(
|
||||
timer->GetElapsedMilliseconds());
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
#endif
|
||||
SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
|
||||
CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
|
||||
workspace_allocator))
|
||||
@ -1828,7 +1910,6 @@ port::Status CudnnSupport::DoRnnForwardImpl(
|
||||
}
|
||||
} else {
|
||||
if (input_desc.is_var_seq_lengths()) {
|
||||
// cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED);
|
||||
RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
|
||||
/*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
|
||||
@ -1881,6 +1962,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
|
||||
Stream* stream, const CudnnRnnDescriptor& rnn_desc,
|
||||
const CudnnRnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<T>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const CudnnRnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<T>& input_h_data,
|
||||
const CudnnRnnStateTensorDescriptor& input_c_desc,
|
||||
@ -1911,6 +1993,91 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
|
||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||
|
||||
SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
|
||||
|
||||
// In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been
|
||||
// deprecated. Instead, we use the cudnnRNNForward which requires the
|
||||
// sequence_lengths parameter.
|
||||
#if CUDNN_VERSION >= 8000
|
||||
if (input_desc.is_var_seq_lengths()) {
|
||||
DeviceMemory<uint8> workspace;
|
||||
size_t workspace_size_in_bytes = 0;
|
||||
RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
|
||||
/*fMode=*/CUDNN_FWD_MODE_TRAINING, /*xDesc=*/input_desc.data_handle(),
|
||||
/*workSpaceSize=*/&workspace_size_in_bytes,
|
||||
/*reserveSpaceSize=*/NULL));
|
||||
if (workspace_size_in_bytes > 0) {
|
||||
SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
|
||||
workspace_size_in_bytes));
|
||||
}
|
||||
|
||||
std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
|
||||
const bool is_profiling = output_profile_result != nullptr;
|
||||
if (is_profiling) {
|
||||
timer.reset(new GpuTimer(parent_));
|
||||
// The start and stop of the timer should be as close to the Cudnn call as
|
||||
// possible. It is still possible for other threads to issue workload on
|
||||
// to this stream. So it could take multiple profiling measurements.
|
||||
if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
|
||||
return port::Status(port::error::INTERNAL, "Failed to start timer");
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
|
||||
/*devSeqLengths=*/reinterpret_cast<const int*>(
|
||||
seq_lengths_data.opaque()),
|
||||
/*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
|
||||
/*dy=*/output_backprop_data.opaque(),
|
||||
/*xDesc=*/input_desc.data_handle(),
|
||||
/*dx=*/input_backprop_data->opaque(),
|
||||
/*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
|
||||
/*dhy=*/output_h_backprop_data.opaque(),
|
||||
/*dhx=*/input_h_backprop_data->opaque(),
|
||||
/*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
|
||||
/*dcy=*/output_c_backprop_data.opaque(),
|
||||
/*dcx=*/input_c_backprop_data->opaque(),
|
||||
/*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
|
||||
/*weightSpace=*/params.opaque(),
|
||||
/*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(),
|
||||
/*reserveSpaceSize=*/reserve_space_data->size(),
|
||||
/*reserveSpace=*/reserve_space_data->opaque()));
|
||||
|
||||
if (params_backprop_data != nullptr) {
|
||||
// Clear the dw to zeros.
|
||||
stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
|
||||
RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8(
|
||||
/*handle=*/cudnn.handle(),
|
||||
/*rnnDesc=*/rnn_desc.handle(),
|
||||
/*addGrad=*/CUDNN_WGRAD_MODE_ADD,
|
||||
/*devSeqLengths=*/reinterpret_cast<const int*>(
|
||||
seq_lengths_data.opaque()),
|
||||
/*xDesc=*/input_desc.data_handle(),
|
||||
/*x=*/input_data.opaque(),
|
||||
/*hDesc=*/input_h_desc.handle(),
|
||||
/*hx=*/input_h_data.opaque(),
|
||||
/*yDesc=*/output_desc.data_handle(),
|
||||
/*y=*/output_data.opaque(),
|
||||
/*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(),
|
||||
/*dweightSpace=*/params_backprop_data->opaque(),
|
||||
/*workSpaceSize=*/workspace.size(),
|
||||
/*workSpace=*/workspace.opaque(),
|
||||
/*reserveSpaceSize=*/reserve_space_data->size(),
|
||||
/*reserveSpace=*/reserve_space_data->opaque()));
|
||||
}
|
||||
|
||||
if (is_profiling) {
|
||||
if (!timer->Stop(AsGpuStream(stream))) {
|
||||
return port::Status(port::error::INTERNAL, "Failed to stop timer");
|
||||
}
|
||||
auto algo_desc = *rnn_desc.algorithm_config().algorithm();
|
||||
output_profile_result->set_algorithm(algo_desc);
|
||||
output_profile_result->set_elapsed_time_in_ms(
|
||||
timer->GetElapsedMilliseconds());
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
#endif
|
||||
SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
|
||||
CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
|
||||
workspace_allocator));
|
||||
@ -2121,6 +2288,7 @@ bool CudnnSupport::DoRnnForward(
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<Eigen::half>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2152,10 +2320,11 @@ bool CudnnSupport::DoRnnForward(
|
||||
return IsStatusOk(
|
||||
DoRnnForwardImpl<Eigen::half>(
|
||||
stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
|
||||
cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
|
||||
params, cudnn_output_desc, output_data, cudnn_output_h_desc,
|
||||
output_h_data, cudnn_output_c_desc, output_c_data, is_training,
|
||||
reserve_space_allocator, workspace_allocator, output_profile_result),
|
||||
seq_lengths_data, cudnn_input_h_desc, input_h_data,
|
||||
cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
|
||||
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
|
||||
output_c_data, is_training, reserve_space_allocator,
|
||||
workspace_allocator, output_profile_result),
|
||||
/*report_error=*/!output_profile_result);
|
||||
}
|
||||
|
||||
@ -2163,6 +2332,7 @@ bool CudnnSupport::DoRnnForward(
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<float>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2193,10 +2363,11 @@ bool CudnnSupport::DoRnnForward(
|
||||
return IsStatusOk(
|
||||
DoRnnForwardImpl<float>(
|
||||
stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
|
||||
cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
|
||||
params, cudnn_output_desc, output_data, cudnn_output_h_desc,
|
||||
output_h_data, cudnn_output_c_desc, output_c_data, is_training,
|
||||
reserve_space_allocator, workspace_allocator, output_profile_result),
|
||||
seq_lengths_data, cudnn_input_h_desc, input_h_data,
|
||||
cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
|
||||
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
|
||||
output_c_data, is_training, reserve_space_allocator,
|
||||
workspace_allocator, output_profile_result),
|
||||
/*report_error=*/!output_profile_result);
|
||||
}
|
||||
|
||||
@ -2204,6 +2375,7 @@ bool CudnnSupport::DoRnnForward(
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<double>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2235,10 +2407,11 @@ bool CudnnSupport::DoRnnForward(
|
||||
return IsStatusOk(
|
||||
DoRnnForwardImpl<double>(
|
||||
stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
|
||||
cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
|
||||
params, cudnn_output_desc, output_data, cudnn_output_h_desc,
|
||||
output_h_data, cudnn_output_c_desc, output_c_data, is_training,
|
||||
reserve_space_allocator, workspace_allocator, output_profile_result),
|
||||
seq_lengths_data, cudnn_input_h_desc, input_h_data,
|
||||
cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
|
||||
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
|
||||
output_c_data, is_training, reserve_space_allocator,
|
||||
workspace_allocator, output_profile_result),
|
||||
/*report_error=*/!output_profile_result);
|
||||
}
|
||||
|
||||
@ -2246,6 +2419,7 @@ bool CudnnSupport::DoRnnBackward(
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<Eigen::half>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2284,13 +2458,13 @@ bool CudnnSupport::DoRnnBackward(
|
||||
return IsStatusOk(
|
||||
DoRnnBackwardImpl<Eigen::half>(
|
||||
stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
|
||||
cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
|
||||
params, cudnn_output_desc, output_data, cudnn_output_h_desc,
|
||||
output_h_data, cudnn_output_c_desc, output_c_data,
|
||||
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
|
||||
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
|
||||
params_backprop_data, reserve_space_data, workspace_allocator,
|
||||
output_profile_result),
|
||||
seq_lengths_data, cudnn_input_h_desc, input_h_data,
|
||||
cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
|
||||
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
|
||||
output_c_data, output_backprop_data, output_h_backprop_data,
|
||||
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
|
||||
input_c_backprop_data, params_backprop_data, reserve_space_data,
|
||||
workspace_allocator, output_profile_result),
|
||||
/*report_error=*/!output_profile_result);
|
||||
}
|
||||
|
||||
@ -2298,6 +2472,7 @@ bool CudnnSupport::DoRnnBackward(
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<float>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2335,13 +2510,13 @@ bool CudnnSupport::DoRnnBackward(
|
||||
return IsStatusOk(
|
||||
DoRnnBackwardImpl<float>(
|
||||
stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
|
||||
cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
|
||||
params, cudnn_output_desc, output_data, cudnn_output_h_desc,
|
||||
output_h_data, cudnn_output_c_desc, output_c_data,
|
||||
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
|
||||
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
|
||||
params_backprop_data, reserve_space_data, workspace_allocator,
|
||||
output_profile_result),
|
||||
seq_lengths_data, cudnn_input_h_desc, input_h_data,
|
||||
cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
|
||||
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
|
||||
output_c_data, output_backprop_data, output_h_backprop_data,
|
||||
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
|
||||
input_c_backprop_data, params_backprop_data, reserve_space_data,
|
||||
workspace_allocator, output_profile_result),
|
||||
/*report_error=*/!output_profile_result);
|
||||
}
|
||||
|
||||
@ -2349,6 +2524,7 @@ bool CudnnSupport::DoRnnBackward(
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<double>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2387,13 +2563,13 @@ bool CudnnSupport::DoRnnBackward(
|
||||
return IsStatusOk(
|
||||
DoRnnBackwardImpl<double>(
|
||||
stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
|
||||
cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
|
||||
params, cudnn_output_desc, output_data, cudnn_output_h_desc,
|
||||
output_h_data, cudnn_output_c_desc, output_c_data,
|
||||
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
|
||||
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
|
||||
params_backprop_data, reserve_space_data, workspace_allocator,
|
||||
output_profile_result),
|
||||
seq_lengths_data, cudnn_input_h_desc, input_h_data,
|
||||
cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
|
||||
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
|
||||
output_c_data, output_backprop_data, output_h_backprop_data,
|
||||
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
|
||||
input_c_backprop_data, params_backprop_data, reserve_space_data,
|
||||
workspace_allocator, output_profile_result),
|
||||
/*report_error=*/!output_profile_result);
|
||||
}
|
||||
|
||||
|
@ -74,6 +74,7 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<Eigen::half>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -92,6 +93,7 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<float>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -110,6 +112,7 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<double>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -128,6 +131,7 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<Eigen::half>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -153,6 +157,7 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<float>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -178,6 +183,7 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<double>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -641,6 +647,7 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
Stream* stream, const CudnnRnnDescriptor& rnn_desc,
|
||||
const CudnnRnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<T>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const CudnnRnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<T>& input_h_data,
|
||||
const CudnnRnnStateTensorDescriptor& input_c_desc,
|
||||
@ -660,6 +667,7 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
Stream* stream, const CudnnRnnDescriptor& rnn_desc,
|
||||
const CudnnRnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<T>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const CudnnRnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<T>& input_h_data,
|
||||
const CudnnRnnStateTensorDescriptor& input_c_desc,
|
||||
|
@ -1786,6 +1786,16 @@ cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan(
|
||||
return func_ptr(rnnDesc, plan);
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnGetRNNWeightSpaceSize(
|
||||
cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc,
|
||||
size_t *weightSpaceSize) {
|
||||
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(
|
||||
cudnnHandle_t, cudnnRNNDescriptor_t, size_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNWeightSpaceSize");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(handle, rnnDesc, weightSpaceSize);
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize(
|
||||
cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc,
|
||||
const int seqLength, const cudnnTensorDescriptor_t *xDesc,
|
||||
@ -1798,6 +1808,19 @@ cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize(
|
||||
return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes);
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnGetRNNTempSpaceSizes(
|
||||
cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc,
|
||||
cudnnForwardMode_t fMode, cudnnRNNDataDescriptor_t xDesc,
|
||||
size_t *workSpaceSize, size_t *reserveSpaceSize) {
|
||||
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(
|
||||
cudnnHandle_t, cudnnRNNDescriptor_t, cudnnForwardMode_t,
|
||||
cudnnRNNDataDescriptor_t, size_t *, size_t *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNTempSpaceSizes");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(handle, rnnDesc, fMode, xDesc, workSpaceSize,
|
||||
reserveSpaceSize);
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI
|
||||
cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc,
|
||||
const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes,
|
||||
@ -2748,6 +2771,28 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTrainingEx(
|
||||
reserveSpace, reserveSpaceSizeInBytes);
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnRNNForward(
|
||||
cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc,
|
||||
cudnnForwardMode_t fwdMode, const int32_t devSeqLengths[],
|
||||
cudnnRNNDataDescriptor_t xDesc, const void *x,
|
||||
cudnnRNNDataDescriptor_t yDesc, void *y, cudnnTensorDescriptor_t hDesc,
|
||||
const void *hx, void *hy, cudnnTensorDescriptor_t cDesc, const void *cx,
|
||||
void *cy, size_t weightSpaceSize, const void *weightSpace,
|
||||
size_t workSpaceSize, void *workSpace, size_t reserveSpaceSize,
|
||||
void *reserveSpace) {
|
||||
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(
|
||||
cudnnHandle_t, cudnnRNNDescriptor_t, cudnnForwardMode_t, const int32_t [],
|
||||
cudnnRNNDataDescriptor_t, const void *, cudnnRNNDataDescriptor_t, void *,
|
||||
cudnnTensorDescriptor_t, const void *, void *, cudnnTensorDescriptor_t,
|
||||
const void *, void *, size_t, const void *, size_t, void *, size_t,
|
||||
void *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForward");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(handle, rnnDesc, fwdMode, devSeqLengths, xDesc, x, yDesc, y,
|
||||
hDesc, hx, hy, cDesc, cx, cy, weightSpaceSize, weightSpace,
|
||||
workSpaceSize, workSpace, reserveSpaceSize, reserveSpace);
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx(
|
||||
cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc,
|
||||
const cudnnRNNDataDescriptor_t yDesc, const void *y,
|
||||
@ -2787,6 +2832,28 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx(
|
||||
reserveSpaceSizeInBytes);
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardData_v8(
|
||||
cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc,
|
||||
const int32_t devSeqLengths[], cudnnRNNDataDescriptor_t yDesc,
|
||||
const void *y, const void *dy, cudnnRNNDataDescriptor_t xDesc, void *dx,
|
||||
cudnnTensorDescriptor_t hDesc, const void *hx, const void *dhy, void *dhx,
|
||||
cudnnTensorDescriptor_t cDesc, const void *cx, const void *dcy, void *dcx,
|
||||
size_t weightSpaceSize, const void *weightSpace, size_t workSpaceSize,
|
||||
void *workSpace, size_t reserveSpaceSize, void *reserveSpace) {
|
||||
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(
|
||||
cudnnHandle_t, cudnnRNNDescriptor_t, const int32_t [],
|
||||
cudnnRNNDataDescriptor_t, const void *, const void *,
|
||||
cudnnRNNDataDescriptor_t, void *, cudnnTensorDescriptor_t, const void *,
|
||||
const void *, void *, cudnnTensorDescriptor_t, const void *, const void *,
|
||||
void *, size_t, const void *, size_t, void *, size_t, void *);
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardData_v8");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr(handle, rnnDesc, devSeqLengths, yDesc, y, dy, xDesc, dx,
|
||||
hDesc, hx, dhy, dhx, cDesc, cx, dcy, dcx, weightSpaceSize,
|
||||
weightSpace, workSpaceSize, workSpace, reserveSpaceSize,
|
||||
reserveSpace);
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx(
|
||||
cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc,
|
||||
const cudnnRNNDataDescriptor_t xDesc, const void *x,
|
||||
@ -2806,6 +2873,26 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx(
|
||||
reserveSpaceSizeInBytes);
|
||||
}
|
||||
|
||||
// Stub for the cuDNN v8 RNN weight-gradient entry point. Resolves the real
// symbol from the loaded cuDNN library exactly once and forwards every
// argument unchanged; reports a symbol-not-found status when the runtime
// cuDNN predates the v8 API.
cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights_v8(
    cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc,
    cudnnWgradMode_t addGrad, const int32_t devSeqLengths[],
    cudnnRNNDataDescriptor_t xDesc, const void *x,
    cudnnTensorDescriptor_t hDesc, const void *hx,
    cudnnRNNDataDescriptor_t yDesc, const void *y, size_t weightSpaceSize,
    void *dweightSpace, size_t workSpaceSize, void *workSpace,
    size_t reserveSpaceSize, void *reserveSpace) {
  // Function-pointer type mirroring the cuDNN v8 signature exactly.
  using BackwardWeightsV8Fn = cudnnStatus_t(CUDNNWINAPI *)(
      cudnnHandle_t, cudnnRNNDescriptor_t, cudnnWgradMode_t, const int32_t [],
      cudnnRNNDataDescriptor_t, const void *, cudnnTensorDescriptor_t,
      const void *, cudnnRNNDataDescriptor_t, const void *, size_t, void *,
      size_t, void *, size_t, void *);
  // One-time, thread-safe lookup of the symbol in the cuDNN shared library.
  static auto fn = LoadSymbol<BackwardWeightsV8Fn>("cudnnRNNBackwardWeights_v8");
  if (fn == nullptr) {
    return GetSymbolNotFoundError();
  }
  return fn(handle, rnnDesc, addGrad, devSeqLengths, xDesc, x, hDesc, hx,
            yDesc, y, weightSpaceSize, dweightSpace, workSpaceSize, workSpace,
            reserveSpaceSize, reserveSpace);
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnMultiHeadAttnBackwardData(
|
||||
cudnnHandle_t handle, const cudnnAttnDescriptor_t attnDesc,
|
||||
const int loWinIdx[], const int hiWinIdx[], const int devSeqLengthsDQDO[],
|
||||
|
@ -2185,6 +2185,7 @@ class DnnSupport {
|
||||
virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<Eigen::half>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2206,6 +2207,7 @@ class DnnSupport {
|
||||
virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<float>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2227,6 +2229,7 @@ class DnnSupport {
|
||||
virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<double>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2289,6 +2292,7 @@ class DnnSupport {
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<Eigen::half>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<Eigen::half>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2317,6 +2321,7 @@ class DnnSupport {
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<float>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<float>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
@ -2345,6 +2350,7 @@ class DnnSupport {
|
||||
Stream* stream, const dnn::RnnDescriptor& rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor& input_desc,
|
||||
const DeviceMemory<double>& input_data,
|
||||
const DeviceMemory<int>& seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_h_desc,
|
||||
const DeviceMemory<double>& input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor& input_c_desc,
|
||||
|
@ -4539,6 +4539,7 @@ Stream &Stream::ThenRnnForward(
|
||||
const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<Eigen::half> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<Eigen::half> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -4556,10 +4557,11 @@ Stream &Stream::ThenRnnForward(
|
||||
// TODO(zhengxq): add VLOG PARAM calls.
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoRnnForward(
|
||||
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
|
||||
input_c_desc, input_c_data, params, output_desc, output_data,
|
||||
output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
|
||||
reserve_space_allocator, workspace_allocator, output_profile_result);
|
||||
this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
|
||||
input_h_data, input_c_desc, input_c_data, params, output_desc,
|
||||
output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
is_training, reserve_space_allocator, workspace_allocator,
|
||||
output_profile_result);
|
||||
if (!status && !output_profile_result) {
|
||||
SetError();
|
||||
}
|
||||
@ -4573,6 +4575,7 @@ Stream &Stream::ThenRnnForward(
|
||||
const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<float> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<float> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -4589,10 +4592,11 @@ Stream &Stream::ThenRnnForward(
|
||||
// TODO(zhengxq): add VLOG PARAM calls.
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoRnnForward(
|
||||
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
|
||||
input_c_desc, input_c_data, params, output_desc, output_data,
|
||||
output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
|
||||
reserve_space_allocator, workspace_allocator, output_profile_result);
|
||||
this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
|
||||
input_h_data, input_c_desc, input_c_data, params, output_desc,
|
||||
output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
is_training, reserve_space_allocator, workspace_allocator,
|
||||
output_profile_result);
|
||||
if (!status && !output_profile_result) {
|
||||
SetError();
|
||||
}
|
||||
@ -4606,6 +4610,7 @@ Stream &Stream::ThenRnnForward(
|
||||
const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<double> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<double> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -4623,10 +4628,11 @@ Stream &Stream::ThenRnnForward(
|
||||
// TODO(zhengxq): add VLOG PARAM calls.
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoRnnForward(
|
||||
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
|
||||
input_c_desc, input_c_data, params, output_desc, output_data,
|
||||
output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
|
||||
reserve_space_allocator, workspace_allocator, output_profile_result);
|
||||
this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
|
||||
input_h_data, input_c_desc, input_c_data, params, output_desc,
|
||||
output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
is_training, reserve_space_allocator, workspace_allocator,
|
||||
output_profile_result);
|
||||
if (!status && !output_profile_result) {
|
||||
SetError();
|
||||
}
|
||||
@ -4640,6 +4646,7 @@ Stream &Stream::ThenRnnBackward(
|
||||
const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<Eigen::half> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<Eigen::half> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -4664,9 +4671,9 @@ Stream &Stream::ThenRnnBackward(
|
||||
// TODO(zhengxq): add VLOG PARAM calls.
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoRnnBackward(
|
||||
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
|
||||
input_c_desc, input_c_data, params, output_desc, output_data,
|
||||
output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
|
||||
input_h_data, input_c_desc, input_c_data, params, output_desc,
|
||||
output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
|
||||
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
|
||||
params_backprop_data, reserve_space_data, workspace_allocator,
|
||||
@ -4685,6 +4692,7 @@ Stream &Stream::ThenRnnBackward(
|
||||
const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<float> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<float> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -4708,9 +4716,9 @@ Stream &Stream::ThenRnnBackward(
|
||||
// TODO(zhengxq): add VLOG PARAM calls.
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoRnnBackward(
|
||||
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
|
||||
input_c_desc, input_c_data, params, output_desc, output_data,
|
||||
output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
|
||||
input_h_data, input_c_desc, input_c_data, params, output_desc,
|
||||
output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
|
||||
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
|
||||
params_backprop_data, reserve_space_data, workspace_allocator,
|
||||
@ -4729,6 +4737,7 @@ Stream &Stream::ThenRnnBackward(
|
||||
const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<double> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<double> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -4753,9 +4762,9 @@ Stream &Stream::ThenRnnBackward(
|
||||
// TODO(zhengxq): add VLOG PARAM calls.
|
||||
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
|
||||
auto status = dnn->DoRnnBackward(
|
||||
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
|
||||
input_c_desc, input_c_data, params, output_desc, output_data,
|
||||
output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
|
||||
input_h_data, input_c_desc, input_c_data, params, output_desc,
|
||||
output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
|
||||
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
|
||||
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
|
||||
params_backprop_data, reserve_space_data, workspace_allocator,
|
||||
|
@ -1779,6 +1779,7 @@ class Stream {
|
||||
Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<Eigen::half> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<Eigen::half> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -1798,6 +1799,7 @@ class Stream {
|
||||
Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<float> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<float> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -1816,6 +1818,7 @@ class Stream {
|
||||
Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<double> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<double> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -1837,6 +1840,7 @@ class Stream {
|
||||
const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<Eigen::half> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<Eigen::half> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -1862,6 +1866,7 @@ class Stream {
|
||||
Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<float> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<float> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
@ -1887,6 +1892,7 @@ class Stream {
|
||||
Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
|
||||
const dnn::RnnSequenceTensorDescriptor &input_desc,
|
||||
const DeviceMemory<double> &input_data,
|
||||
const DeviceMemory<int> &seq_lengths_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_h_desc,
|
||||
const DeviceMemory<double> &input_h_data,
|
||||
const dnn::RnnStateTensorDescriptor &input_c_desc,
|
||||
|
Loading…
Reference in New Issue
Block a user