diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 2768777426b..ea26fb064e7 100755
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1449,7 +1449,9 @@ class CudnnRnnSequenceTensorDescriptor
   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
       cudnnDataType_t data_type) {
-    CHECK_GT(max_seq_length, 0);
+    if (max_seq_length <= 0) {
+      return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
+    }
     int dims[] = {batch_size, data_size, 1};
     int strides[] = {dims[1] * dims[2], dims[2], 1};
     TensorDescriptor tensor_desc = CreateTensorDescriptor();
@@ -1470,7 +1472,9 @@ class CudnnRnnSequenceTensorDescriptor
       const absl::Span<const int>& seq_lengths, bool time_major,
       cudnnDataType_t data_type) {
 #if CUDNN_VERSION >= 7201
-    CHECK_GT(max_seq_length, 0);
+    if (max_seq_length <= 0) {
+      return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
+    }
     int dims[] = {batch_size, data_size, 1};
     int strides[] = {dims[1] * dims[2], dims[2], 1};
     TensorDescriptor tensor_desc = CreateTensorDescriptor();
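
The diff replaces a process-aborting CHECK_GT with an INVALID_ARGUMENT status that the existing port::StatusOr factory signature can already propagate to callers. Below is a minimal standalone sketch of that pattern, not TensorFlow code: it uses absl::StatusOr in place of stream_executor's port::StatusOr, and the SequenceDescriptor type, CreateDescriptor function, and field names are hypothetical stand-ins for the real descriptor.

// Sketch of the CHECK-to-StatusOr pattern the diff applies (assumed names).
#include <iostream>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

struct SequenceDescriptor {
  int max_seq_length;
  int batch_size;
  int data_size;
};

// Before the change an invalid max_seq_length would CHECK-fail and abort the
// process; after it, the factory reports INVALID_ARGUMENT instead.
absl::StatusOr<SequenceDescriptor> CreateDescriptor(int max_seq_length,
                                                    int batch_size,
                                                    int data_size) {
  if (max_seq_length <= 0) {
    return absl::InvalidArgumentError("max_seq_length <= 0");
  }
  return SequenceDescriptor{max_seq_length, batch_size, data_size};
}

int main() {
  // The caller now sees a recoverable error rather than a crash.
  absl::StatusOr<SequenceDescriptor> desc = CreateDescriptor(0, 32, 128);
  if (!desc.ok()) {
    std::cerr << "descriptor creation failed: " << desc.status() << "\n";
    return 1;
  }
  std::cout << "max_seq_length = " << desc->max_seq_length << "\n";
  return 0;
}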