From ec544f8099981be897463a6b39b8a7a1d6f0f62d Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Mon, 7 Dec 2020 20:31:31 -0800 Subject: [PATCH] Prevent CHECK-fail in LSTM/GRU with zero-length input. PiperOrigin-RevId: 346239181 Change-Id: I5f233dbc076aab7bb4e31ba24f5abd4eaf99ea4f --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 70cc11a3e03..b53ad905991 100755 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1383,7 +1383,9 @@ class CudnnRnnSequenceTensorDescriptor static port::StatusOr Create( GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, cudnnDataType_t data_type) { - CHECK_GT(max_seq_length, 0); + if (max_seq_length <= 0) { + return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; TensorDescriptor tensor_desc = CreateTensorDescriptor(); @@ -1404,7 +1406,9 @@ class CudnnRnnSequenceTensorDescriptor const absl::Span& seq_lengths, bool time_major, cudnnDataType_t data_type) { #if CUDNN_VERSION >= 7201 - CHECK_GT(max_seq_length, 0); + if (max_seq_length <= 0) { + return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; TensorDescriptor tensor_desc = CreateTensorDescriptor();