Prevent CHECK-fail in LSTM/GRU with zero-length input.

PiperOrigin-RevId: 346239181
Change-Id: I5f233dbc076aab7bb4e31ba24f5abd4eaf99ea4f
This commit is contained in:
Mihai Maruseac 2020-12-07 20:31:31 -08:00 committed by TensorFlower Gardener
parent 042a6923d7
commit 14755416e3

View File

@ -1468,7 +1468,9 @@ class CudnnRnnSequenceTensorDescriptor
static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
cudnnDataType_t data_type) {
CHECK_GT(max_seq_length, 0);
if (max_seq_length <= 0) {
return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
}
int dims[] = {batch_size, data_size, 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
TensorDescriptor tensor_desc = CreateTensorDescriptor();
@ -1486,7 +1488,9 @@ class CudnnRnnSequenceTensorDescriptor
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
const absl::Span<const int>& seq_lengths, bool time_major,
cudnnDataType_t data_type) {
CHECK_GT(max_seq_length, 0);
if (max_seq_length <= 0) {
return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
}
int dims[] = {batch_size, data_size, 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
TensorDescriptor tensor_desc = CreateTensorDescriptor();