From b550171e78e0a085b208d6a3b8b29ed29faa97ae Mon Sep 17 00:00:00 2001 From: Mihai Maruseac <mihaimaruseac@google.com> Date: Mon, 7 Dec 2020 20:31:31 -0800 Subject: [PATCH] Prevent CHECK-fail in LSTM/GRU with zero-length input. PiperOrigin-RevId: 346239181 Change-Id: I5f233dbc076aab7bb4e31ba24f5abd4eaf99ea4f --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index a7ed5cedb4f..174442887f8 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1465,7 +1465,9 @@ class CudnnRnnSequenceTensorDescriptor static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create( GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, cudnnDataType_t data_type) { - CHECK_GT(max_seq_length, 0); + if (max_seq_length <= 0) { + return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; TensorDescriptor tensor_desc = CreateTensorDescriptor(); @@ -1483,7 +1485,9 @@ class CudnnRnnSequenceTensorDescriptor GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, const absl::Span<const int>& seq_lengths, bool time_major, cudnnDataType_t data_type) { - CHECK_GT(max_seq_length, 0); + if (max_seq_length <= 0) { + return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0"); + } int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; TensorDescriptor tensor_desc = CreateTensorDescriptor();