Prevent CHECK-fail in LSTM/GRU with zero-length input.
PiperOrigin-RevId: 346239181 Change-Id: I5f233dbc076aab7bb4e31ba24f5abd4eaf99ea4f
This commit is contained in:
parent
042a6923d7
commit
14755416e3
@ -1468,7 +1468,9 @@ class CudnnRnnSequenceTensorDescriptor
|
||||
static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
|
||||
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
|
||||
cudnnDataType_t data_type) {
|
||||
CHECK_GT(max_seq_length, 0);
|
||||
if (max_seq_length <= 0) {
|
||||
return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
|
||||
}
|
||||
int dims[] = {batch_size, data_size, 1};
|
||||
int strides[] = {dims[1] * dims[2], dims[2], 1};
|
||||
TensorDescriptor tensor_desc = CreateTensorDescriptor();
|
||||
@ -1486,7 +1488,9 @@ class CudnnRnnSequenceTensorDescriptor
|
||||
GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
|
||||
const absl::Span<const int>& seq_lengths, bool time_major,
|
||||
cudnnDataType_t data_type) {
|
||||
CHECK_GT(max_seq_length, 0);
|
||||
if (max_seq_length <= 0) {
|
||||
return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
|
||||
}
|
||||
int dims[] = {batch_size, data_size, 1};
|
||||
int strides[] = {dims[1] * dims[2], dims[2], 1};
|
||||
TensorDescriptor tensor_desc = CreateTensorDescriptor();
|
||||
|
Loading…
Reference in New Issue
Block a user