diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index a7ed5cedb4f..174442887f8 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1465,7 +1465,9 @@ class CudnnRnnSequenceTensorDescriptor static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create( GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, cudnnDataType_t data_type) { - CHECK_GT(max_seq_length, 0); + if (max_seq_length <= 0) { + return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; TensorDescriptor tensor_desc = CreateTensorDescriptor(); @@ -1483,7 +1485,9 @@ class CudnnRnnSequenceTensorDescriptor GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, const absl::Span<const int>& seq_lengths, bool time_major, cudnnDataType_t data_type) { - CHECK_GT(max_seq_length, 0); + if (max_seq_length <= 0) { + return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; TensorDescriptor tensor_desc = CreateTensorDescriptor();