From 9442cbb955147567792c37ace0090f7ae9df8360 Mon Sep 17 00:00:00 2001
From: Alexandre Lissy <alissy@mozilla.com>
Date: Tue, 28 Jul 2020 20:41:54 +0200
Subject: [PATCH] Fix #41630: include max_seq_length in cudnn descriptor cache
 key

---
 tensorflow/core/kernels/cudnn_rnn_ops.cc | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index b9b96d3fc70..1a3f05fdcd9 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -500,6 +500,9 @@ struct CudnnRnnModelShapes {
   int max_seq_length;
   int batch_size;
   int cell_num_units = 0;
+  // If you add new field to this structure, please take care of
+  // updating IsCompatibleWith() below as well as the hash function in
+  // CudnnRnnConfigHasher.
   TensorShape input_shape;
   TensorShape output_shape;
   TensorShape hidden_state_shape;
@@ -508,7 +511,7 @@ struct CudnnRnnModelShapes {
   bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
     return num_layers == rhs.num_layers && input_size == rhs.input_size &&
            num_units == rhs.num_units && dir_count == rhs.dir_count &&
-           cell_num_units == rhs.cell_num_units;
+           cell_num_units == rhs.cell_num_units && max_seq_length == rhs.max_seq_length;
   }
   string DebugString() const {
     return strings::Printf(
@@ -530,7 +533,7 @@ struct CudnnRnnConfigHasher {
 
     uint64 hash =
         HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
-                  shapes.dir_count, shapes.batch_size});
+                  shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
     if (algo_desc.has_value()) {
       hash = Hash64Combine(hash, algo_desc->hash());
     }