diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc index bd62f39e8ae..8330aa81cb3 100644 --- a/tensorflow/core/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc @@ -826,14 +826,34 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc, } Stream* stream = context->op_device_context()->stream(); + + Tensor seq_lengths_tensor; + DeviceMemory seq_lengths_ptr; + if (sequence_lengths != nullptr) { + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT32, {static_cast(seq_lengths.size())}, + &seq_lengths_tensor)); + seq_lengths_ptr = AsDeviceMemory(&seq_lengths_tensor); + if (!stream + ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(), + seq_lengths.size() * sizeof(int)) + .ok()) { + return errors::InvalidArgument( + "Failed to copy memory from host to " + "device for sequence_lengths in " + "CudnnRNNV3"); + } + } + bool launch_success = stream - ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc, - input_h_data, *c_state_desc, input_c_data, - params_data, *output_desc, &output_data, - *h_state_desc, &output_h_data, *c_state_desc, - &output_c_data, is_training, reserve_space_allocator, - workspace_allocator, output_profile_result) + ->ThenRnnForward(rnn_desc, *input_desc, input_data, seq_lengths_ptr, + *h_state_desc, input_h_data, *c_state_desc, + input_c_data, params_data, *output_desc, + &output_data, *h_state_desc, &output_h_data, + *c_state_desc, &output_c_data, is_training, + reserve_space_allocator, workspace_allocator, + output_profile_result) .ok(); return launch_success ? Status::OK() @@ -905,17 +925,36 @@ Status DoBackward( // Creates a memory callback for the workspace. The memory lives to the end // of this kernel calls. 
Stream* stream = context->op_device_context()->stream(); + + Tensor seq_lengths_tensor; + DeviceMemory seq_lengths_ptr; + if (sequence_lengths != nullptr) { + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT32, {static_cast(seq_lengths.size())}, + &seq_lengths_tensor)); + seq_lengths_ptr = AsDeviceMemory(&seq_lengths_tensor); + if (!stream + ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(), + seq_lengths.size() * sizeof(int)) + .ok()) { + return errors::InvalidArgument( + "Failed to copy memory from host to " + "device for sequence_lengths in " + "CudnnRNNBackwardOpV3"); + } + } + bool launch_success = stream ->ThenRnnBackward( - rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data, - *c_state_desc, input_c_data, params_data, *output_desc, - output_data, *h_state_desc, output_h_data, *c_state_desc, - output_c_data, output_backprop_data, output_h_backprop_data, - output_c_backprop_data, &input_backprop_data, - &input_h_backprop_data, &input_c_backprop_data, - ¶ms_backprop_data, &reserve_space_uint8, workspace_allocator, - output_profile_result) + rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc, + input_h_data, *c_state_desc, input_c_data, params_data, + *output_desc, output_data, *h_state_desc, output_h_data, + *c_state_desc, output_c_data, output_backprop_data, + output_h_backprop_data, output_c_backprop_data, + &input_backprop_data, &input_h_backprop_data, + &input_c_backprop_data, ¶ms_backprop_data, + &reserve_space_uint8, workspace_allocator, output_profile_result) .ok(); return launch_success ? Status::OK() diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 01113f89f5e..593619ff084 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1186,6 +1186,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { } // Create the params handle. 
+ // TODO(kaixih@nvidia.com): Should be removed when cudnnRNNForward*** and + // cudnnRNNForward***Ex are removed from the codebase, since the new API + // doesn't need param descriptors any more. SE_ASSIGN_OR_RETURN(auto params_desc, CudnnRnnParamsDescriptor::Create( cudnn, input_size, data_type, rnn_desc.get(), @@ -1659,10 +1662,16 @@ port::Status CheckRNNParameterSize( const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; +#if CUDNN_VERSION >= 8000 + RETURN_IF_CUDNN_ERROR(cudnnGetRNNWeightSpaceSize( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*sizeInBytes=*/¶ms_size_in_bytes)); +#else RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes, /*dataType=*/rnn_desc.data_type())); +#endif if (static_cast(params_size_in_bytes) != rnn_desc.ParamsSizeInBytes()) { return port::Status(port::error::INVALID_ARGUMENT, @@ -1747,6 +1756,7 @@ port::Status CudnnSupport::DoRnnForwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const CudnnRnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const CudnnRnnStateTensorDescriptor& input_c_desc, @@ -1770,6 +1780,78 @@ port::Status CudnnSupport::DoRnnForwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + + // In CUDNN v8.0, the cudnnRNNForward*** and cudnnRNNForward***Ex have been + // deprecated. Instead, we use the cudnnRNNForward which requires the + // sequence_lengths parameter. 
+#if CUDNN_VERSION >= 8000 + if (input_desc.is_var_seq_lengths()) { + DeviceMemory workspace; + DeviceMemory reserve_space; + cudnnForwardMode_t rnn_fwd_mode; + if (is_training) { + rnn_fwd_mode = CUDNN_FWD_MODE_TRAINING; + } else { + rnn_fwd_mode = CUDNN_FWD_MODE_INFERENCE; + } + size_t reserve_space_size_in_bytes = 0; + size_t workspace_size_in_bytes = 0; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*fMode=*/rnn_fwd_mode, /*xDesc=*/input_desc.data_handle(), + /*workSpaceSize=*/&workspace_size_in_bytes, + /*reserveSpaceSize=*/&reserve_space_size_in_bytes)); + + if (workspace_size_in_bytes > 0) { + SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes( + workspace_size_in_bytes)); + } + if (reserve_space_size_in_bytes > 0) { + SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes( + reserve_space_size_in_bytes)); + } + + std::unique_ptr timer; + const bool is_profiling = output_profile_result != nullptr; + if (is_profiling) { + timer.reset(new GpuTimer(parent_)); + // The start and stop of the timer should be as close to the Cudnn call as + // possible. It is still possible for other threads to issue workload on + // to this stream. So it could take multiple profiling measurements. 
+ if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); + } + } + + RETURN_IF_CUDNN_ERROR(cudnnRNNForward( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*fwdMode=*/rnn_fwd_mode, + /*devSeqLengths=*/ + reinterpret_cast(seq_lengths_data.opaque()), + /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), + /*yDesc=*/output_desc.data_handle(), /*y=*/output_data->opaque(), + /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), + /*hy=*/output_h_data->opaque(), + /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(), + /*cy=*/output_c_data->opaque(), + /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(), + /*weightSpace=*/params.opaque(), + /*workSpaceSize=*/workspace.size(), /*workspace=*/workspace.opaque(), + /*reserveSpaceSizeInBytes=*/reserve_space.size(), + /*reserveSpace=*/reserve_space.opaque())); + + if (is_profiling) { + if (!timer->Stop(AsGpuStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to stop timer"); + } + auto algo_desc = *rnn_desc.algorithm_config().algorithm(); + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); + } + return port::Status::OK(); + } +#endif SE_ASSIGN_OR_RETURN(DeviceMemory workspace, CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, workspace_allocator)) @@ -1834,7 +1916,6 @@ port::Status CudnnSupport::DoRnnForwardImpl( } } else { if (input_desc.is_var_seq_lengths()) { - // cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED); RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(), @@ -1887,6 +1968,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& 
input_data,
+    const DeviceMemory<int>& seq_lengths_data,
     const CudnnRnnStateTensorDescriptor& input_h_desc,
     const DeviceMemory<T>& input_h_data,
     const CudnnRnnStateTensorDescriptor& input_c_desc,
@@ -1917,6 +1999,91 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
   auto cudnn = cudnn_->GetHandle(parent_, stream);
   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
+
+  // In CUDNN v8.0, the cudnnRNNBackward*** and cudnnRNNBackward***Ex have been
+  // deprecated. Instead, we use cudnnRNNBackwardData_v8 and
+  // cudnnRNNBackwardWeights_v8, which require the sequence_lengths parameter.
+#if CUDNN_VERSION >= 8000
+  if (input_desc.is_var_seq_lengths()) {
+    DeviceMemory<uint8> workspace;
+    size_t workspace_size_in_bytes = 0;
+    RETURN_IF_CUDNN_ERROR(cudnnGetRNNTempSpaceSizes(
+        /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
+        /*fMode=*/CUDNN_FWD_MODE_TRAINING, /*xDesc=*/input_desc.data_handle(),
+        /*workSpaceSize=*/&workspace_size_in_bytes,
+        /*reserveSpaceSize=*/NULL));
+    if (workspace_size_in_bytes > 0) {
+      SE_ASSIGN_OR_RETURN(workspace, workspace_allocator->AllocateBytes(
+                                         workspace_size_in_bytes));
+    }
+
+    std::unique_ptr<GpuTimer> timer;
+    const bool is_profiling = output_profile_result != nullptr;
+    if (is_profiling) {
+      timer.reset(new GpuTimer(parent_));
+      // The start and stop of the timer should be as close to the Cudnn call as
+      // possible. It is still possible for other threads to issue workload on
+      // to this stream. So it could take multiple profiling measurements.
+ if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); + } + } + + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData_v8( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*devSeqLengths=*/ + reinterpret_cast(seq_lengths_data.opaque()), + /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(), + /*dy=*/output_backprop_data.opaque(), + /*xDesc=*/input_desc.data_handle(), + /*dx=*/input_backprop_data->opaque(), + /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(), + /*dhy=*/output_h_backprop_data.opaque(), + /*dhx=*/input_h_backprop_data->opaque(), + /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(), + /*dcy=*/output_c_backprop_data.opaque(), + /*dcx=*/input_c_backprop_data->opaque(), + /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(), + /*weightSpace=*/params.opaque(), + /*workSpaceSize=*/workspace.size(), /*workSpace=*/workspace.opaque(), + /*reserveSpaceSize=*/reserve_space_data->size(), + /*reserveSpace=*/reserve_space_data->opaque())); + + if (params_backprop_data != nullptr) { + // Clear the dw to zeros. 
+ stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights_v8( + /*handle=*/cudnn.handle(), + /*rnnDesc=*/rnn_desc.handle(), + /*addGrad=*/CUDNN_WGRAD_MODE_ADD, + /*devSeqLengths=*/ + reinterpret_cast(seq_lengths_data.opaque()), + /*xDesc=*/input_desc.data_handle(), + /*x=*/input_data.opaque(), + /*hDesc=*/input_h_desc.handle(), + /*hx=*/input_h_data.opaque(), + /*yDesc=*/output_desc.data_handle(), + /*y=*/output_data.opaque(), + /*weightSpaceSize=*/rnn_desc.ParamsSizeInBytes(), + /*dweightSpace=*/params_backprop_data->opaque(), + /*workSpaceSize=*/workspace.size(), + /*workSpace=*/workspace.opaque(), + /*reserveSpaceSize=*/reserve_space_data->size(), + /*reserveSpace=*/reserve_space_data->opaque())); + } + + if (is_profiling) { + if (!timer->Stop(AsGpuStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to stop timer"); + } + auto algo_desc = *rnn_desc.algorithm_config().algorithm(); + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); + } + return port::Status::OK(); + } +#endif SE_ASSIGN_OR_RETURN(DeviceMemory workspace, CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, workspace_allocator)); @@ -2127,6 +2294,7 @@ bool CudnnSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2158,10 +2326,11 @@ bool CudnnSupport::DoRnnForward( return IsStatusOk( DoRnnForwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, is_training, 
- reserve_space_allocator, workspace_allocator, output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, is_training, reserve_space_allocator, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2169,6 +2338,7 @@ bool CudnnSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2199,10 +2369,11 @@ bool CudnnSupport::DoRnnForward( return IsStatusOk( DoRnnForwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, is_training, reserve_space_allocator, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2210,6 +2381,7 @@ bool CudnnSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2241,10 +2413,11 @@ bool CudnnSupport::DoRnnForward( return IsStatusOk( DoRnnForwardImpl( stream, 
cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, is_training, reserve_space_allocator, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2252,6 +2425,7 @@ bool CudnnSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2290,13 +2464,13 @@ bool CudnnSupport::DoRnnBackward( return IsStatusOk( DoRnnBackwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, output_backprop_data, output_h_backprop_data, + output_c_backprop_data, input_backprop_data, input_h_backprop_data, + input_c_backprop_data, params_backprop_data, reserve_space_data, + workspace_allocator, 
output_profile_result), /*report_error=*/!output_profile_result); } @@ -2304,6 +2478,7 @@ bool CudnnSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2341,13 +2516,13 @@ bool CudnnSupport::DoRnnBackward( return IsStatusOk( DoRnnBackwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, output_backprop_data, output_h_backprop_data, + output_c_backprop_data, input_backprop_data, input_h_backprop_data, + input_c_backprop_data, params_backprop_data, reserve_space_data, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } @@ -2355,6 +2530,7 @@ bool CudnnSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2393,13 +2569,13 @@ bool CudnnSupport::DoRnnBackward( return IsStatusOk( DoRnnBackwardImpl( stream, cudnn_rnn_desc, cudnn_input_desc, 
input_data, - cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, - params, cudnn_output_desc, output_data, cudnn_output_h_desc, - output_h_data, cudnn_output_c_desc, output_c_data, - output_backprop_data, output_h_backprop_data, output_c_backprop_data, - input_backprop_data, input_h_backprop_data, input_c_backprop_data, - params_backprop_data, reserve_space_data, workspace_allocator, - output_profile_result), + seq_lengths_data, cudnn_input_h_desc, input_h_data, + cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, + output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, + output_c_data, output_backprop_data, output_h_backprop_data, + output_c_backprop_data, input_backprop_data, input_h_backprop_data, + input_c_backprop_data, params_backprop_data, reserve_space_data, + workspace_allocator, output_profile_result), /*report_error=*/!output_profile_result); } diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 9cab982c9a1..941260e460c 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -74,6 +74,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -92,6 +93,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -110,6 +112,7 @@ class CudnnSupport : public dnn::DnnSupport 
{ bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -128,6 +131,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -153,6 +157,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -178,6 +183,7 @@ class CudnnSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -641,6 +647,7 @@ class CudnnSupport : public dnn::DnnSupport { Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const CudnnRnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const CudnnRnnStateTensorDescriptor& input_c_desc, @@ -660,6 +667,7 @@ class CudnnSupport : public dnn::DnnSupport { Stream* stream, const 
CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const CudnnRnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const CudnnRnnStateTensorDescriptor& input_c_desc, diff --git a/tensorflow/stream_executor/cuda/cudnn_8_0.inc b/tensorflow/stream_executor/cuda/cudnn_8_0.inc index 9161dbc8cf9..52a9d7cd2bd 100644 --- a/tensorflow/stream_executor/cuda/cudnn_8_0.inc +++ b/tensorflow/stream_executor/cuda/cudnn_8_0.inc @@ -1786,6 +1786,16 @@ cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan( return func_ptr(rnnDesc, plan); } +cudnnStatus_t CUDNNWINAPI +cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + size_t *weightSpaceSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, + cudnnRNNDescriptor_t, size_t *); + static auto func_ptr = LoadSymbol("cudnnGetRNNWeightSpaceSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, weightSpaceSize); +} + cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const int seqLength, const cudnnTensorDescriptor_t *xDesc, @@ -1798,6 +1808,19 @@ cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } +cudnnStatus_t CUDNNWINAPI cudnnGetRNNTempSpaceSizes( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnForwardMode_t fMode, cudnnRNNDataDescriptor_t xDesc, + size_t *workSpaceSize, size_t *reserveSpaceSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnForwardMode_t, + cudnnRNNDataDescriptor_t, size_t *, size_t *); + static auto func_ptr = LoadSymbol("cudnnGetRNNTempSpaceSizes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, fMode, xDesc, workSpaceSize, + reserveSpaceSize); +} + cudnnStatus_t CUDNNWINAPI 
cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes, @@ -2748,6 +2771,28 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTrainingEx( reserveSpace, reserveSpaceSizeInBytes); } +cudnnStatus_t CUDNNWINAPI cudnnRNNForward( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnForwardMode_t fwdMode, const int32_t devSeqLengths[], + cudnnRNNDataDescriptor_t xDesc, const void *x, + cudnnRNNDataDescriptor_t yDesc, void *y, cudnnTensorDescriptor_t hDesc, + const void *hx, void *hy, cudnnTensorDescriptor_t cDesc, const void *cx, + void *cy, size_t weightSpaceSize, const void *weightSpace, + size_t workSpaceSize, void *workSpace, size_t reserveSpaceSize, + void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnForwardMode_t, const int32_t[], + cudnnRNNDataDescriptor_t, const void *, cudnnRNNDataDescriptor_t, void *, + cudnnTensorDescriptor_t, const void *, void *, cudnnTensorDescriptor_t, + const void *, void *, size_t, const void *, size_t, void *, size_t, + void *); + static auto func_ptr = LoadSymbol("cudnnRNNForward"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, fwdMode, devSeqLengths, xDesc, x, yDesc, y, + hDesc, hx, hy, cDesc, cx, cy, weightSpaceSize, weightSpace, + workSpaceSize, workSpace, reserveSpaceSize, reserveSpace); +} + cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx( cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const cudnnRNNDataDescriptor_t yDesc, const void *y, @@ -2787,6 +2832,28 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx( reserveSpaceSizeInBytes); } +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardData_v8( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + const int32_t devSeqLengths[], cudnnRNNDataDescriptor_t yDesc, + const void *y, const void *dy, cudnnRNNDataDescriptor_t xDesc, void *dx, + cudnnTensorDescriptor_t hDesc, const void *hx, const void 
*dhy, void *dhx, + cudnnTensorDescriptor_t cDesc, const void *cx, const void *dcy, void *dcx, + size_t weightSpaceSize, const void *weightSpace, size_t workSpaceSize, + void *workSpace, size_t reserveSpaceSize, void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int32_t[], + cudnnRNNDataDescriptor_t, const void *, const void *, + cudnnRNNDataDescriptor_t, void *, cudnnTensorDescriptor_t, const void *, + const void *, void *, cudnnTensorDescriptor_t, const void *, const void *, + void *, size_t, const void *, size_t, void *, size_t, void *); + static auto func_ptr = LoadSymbol("cudnnRNNBackwardData_v8"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, devSeqLengths, yDesc, y, dy, xDesc, dx, + hDesc, hx, dhy, dhx, cDesc, cx, dcy, dcx, weightSpaceSize, + weightSpace, workSpaceSize, workSpace, reserveSpaceSize, + reserveSpace); +} + cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx( cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const cudnnRNNDataDescriptor_t xDesc, const void *x, @@ -2806,6 +2873,26 @@ cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx( reserveSpaceSizeInBytes); } +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights_v8( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnWgradMode_t addGrad, const int32_t devSeqLengths[], + cudnnRNNDataDescriptor_t xDesc, const void *x, + cudnnTensorDescriptor_t hDesc, const void *hx, + cudnnRNNDataDescriptor_t yDesc, const void *y, size_t weightSpaceSize, + void *dweightSpace, size_t workSpaceSize, void *workSpace, + size_t reserveSpaceSize, void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnWgradMode_t, const int32_t[], + cudnnRNNDataDescriptor_t, const void *, cudnnTensorDescriptor_t, + const void *, cudnnRNNDataDescriptor_t, const void *, size_t, void *, + size_t, void *, size_t, void *); + static auto func_ptr = 
LoadSymbol("cudnnRNNBackwardWeights_v8"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, rnnDesc, addGrad, devSeqLengths, xDesc, x, hDesc, hx, + yDesc, y, weightSpaceSize, dweightSpace, workSpaceSize, + workSpace, reserveSpaceSize, reserveSpace); +} + cudnnStatus_t CUDNNWINAPI cudnnMultiHeadAttnBackwardData( cudnnHandle_t handle, const cudnnAttnDescriptor_t attnDesc, const int loWinIdx[], const int hiWinIdx[], const int devSeqLengthsDQDO[], diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 920f5fe246c..6ca42340d5b 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -2185,6 +2185,7 @@ class DnnSupport { virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2206,6 +2207,7 @@ class DnnSupport { virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2227,6 +2229,7 @@ class DnnSupport { virtual bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2289,6 +2292,7 @@ class DnnSupport { Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const 
DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2317,6 +2321,7 @@ class DnnSupport { Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2345,6 +2350,7 @@ class DnnSupport { Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc index 8c1596331f3..2e0a865e41e 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc @@ -2578,6 +2578,7 @@ bool MIOpenSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2621,6 +2622,7 @@ bool MIOpenSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, + const DeviceMemory& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2663,6 +2665,7 @@ bool MIOpenSupport::DoRnnForward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& 
input_desc, const DeviceMemory<double>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<double>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2685,6 +2688,7 @@ bool MIOpenSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<Eigen::half>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<Eigen::half>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2737,6 +2741,7 @@ bool MIOpenSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<float>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<float>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -2788,6 +2793,7 @@ bool MIOpenSupport::DoRnnBackward( Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<double>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<double>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.h b/tensorflow/stream_executor/rocm/rocm_dnn.h index 654a1bf8f3a..11f1a1dd86d 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.h +++ b/tensorflow/stream_executor/rocm/rocm_dnn.h @@ -101,6 +101,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<Eigen::half>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<Eigen::half>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -119,6 +120,7 @@ class
MIOpenSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<float>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<float>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -137,6 +139,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<double>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<double>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -155,6 +158,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<Eigen::half>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<Eigen::half>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -180,6 +184,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<float>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<float>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, @@ -205,6 +210,7 @@ class MIOpenSupport : public dnn::DnnSupport { bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc, const dnn::RnnSequenceTensorDescriptor& input_desc, const DeviceMemory<double>& input_data, + const DeviceMemory<int>& seq_lengths_data, const dnn::RnnStateTensorDescriptor& input_h_desc, const DeviceMemory<double>& input_h_data, const dnn::RnnStateTensorDescriptor& input_c_desc, diff --git
a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 4ad9fc128cc..ccdb467a03d 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -4539,6 +4539,7 @@ Stream &Stream::ThenRnnForward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<Eigen::half> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<Eigen::half> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4556,10 +4557,11 @@ Stream &Stream::ThenRnnForward( // TODO(zhengxq): add VLOG PARAM calls. if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result); + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, + is_training, reserve_space_allocator, workspace_allocator, + output_profile_result); if (!status && !output_profile_result) { SetError(); } @@ -4573,6 +4575,7 @@ Stream &Stream::ThenRnnForward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<float> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<float> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4589,10 +4592,11 @@ Stream &Stream::ThenRnnForward( // TODO(zhengxq): add VLOG PARAM calls.
if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result); + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, + is_training, reserve_space_allocator, workspace_allocator, + output_profile_result); if (!status && !output_profile_result) { SetError(); } @@ -4606,6 +4610,7 @@ Stream &Stream::ThenRnnForward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<double> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<double> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4623,10 +4628,11 @@ Stream &Stream::ThenRnnForward( // TODO(zhengxq): add VLOG PARAM calls.
if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnForward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, is_training, - reserve_space_allocator, workspace_allocator, output_profile_result); + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, + is_training, reserve_space_allocator, workspace_allocator, + output_profile_result); if (!status && !output_profile_result) { SetError(); } @@ -4640,6 +4646,7 @@ Stream &Stream::ThenRnnBackward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<Eigen::half> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<Eigen::half> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4664,9 +4671,9 @@ Stream &Stream::ThenRnnBackward( // TODO(zhengxq): add VLOG PARAM calls.
if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, output_backprop_data, output_h_backprop_data, output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, @@ -4685,6 +4692,7 @@ Stream &Stream::ThenRnnBackward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<float> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<float> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4708,9 +4716,9 @@ Stream &Stream::ThenRnnBackward( // TODO(zhengxq): add VLOG PARAM calls.
if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, output_backprop_data, output_h_backprop_data, output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, @@ -4729,6 +4737,7 @@ Stream &Stream::ThenRnnBackward( const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<double> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<double> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -4753,9 +4762,9 @@ Stream &Stream::ThenRnnBackward( // TODO(zhengxq): add VLOG PARAM calls.
if (dnn::DnnSupport *dnn = parent_->AsDnn()) { auto status = dnn->DoRnnBackward( - this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, + this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, params, output_desc, + output_data, output_h_desc, output_h_data, output_c_desc, output_c_data, output_backprop_data, output_h_backprop_data, output_c_backprop_data, input_backprop_data, input_h_backprop_data, input_c_backprop_data, params_backprop_data, reserve_space_data, workspace_allocator, diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index cb038c9ee67..e214ee47513 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1779,6 +1779,7 @@ class Stream { Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<Eigen::half> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<Eigen::half> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1798,6 +1799,7 @@ class Stream { Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<float> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<float> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1816,6 +1818,7 @@ class Stream { Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<double> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<double> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1837,6 +1840,7 @@ class
Stream { const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<Eigen::half> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<Eigen::half> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1862,6 +1866,7 @@ class Stream { Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<float> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<float> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc, @@ -1887,6 +1892,7 @@ class Stream { Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, const dnn::RnnSequenceTensorDescriptor &input_desc, const DeviceMemory<double> &input_data, + const DeviceMemory<int> &seq_lengths_data, const dnn::RnnStateTensorDescriptor &input_h_desc, const DeviceMemory<double> &input_h_data, const dnn::RnnStateTensorDescriptor &input_c_desc,