From e3e7bd4bf36c0d352695b4dd2d901225d5e9358b Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Fri, 29 May 2020 10:45:00 -0700 Subject: [PATCH] Don't use the cuDNN GPU kernel for LSTMs when inputs are RaggedTensors. PiperOrigin-RevId: 313809073 Change-Id: I368307d458f8b23c320b9c4df31c81d18e7ab43d --- tensorflow/python/keras/layers/recurrent_v2.py | 8 ++++++-- tensorflow/python/keras/layers/recurrent_v2_test.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index a9d5ef8587c..adefb689a1f 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -413,7 +413,9 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU): input_shape = K.int_shape(inputs) timesteps = input_shape[0] if self.time_major else input_shape[1] - if not self._could_use_gpu_kernel: + # TODO(b/156447398) Investigate why the cuDNN kernel fails with + # ragged inputs. + if is_ragged_input or not self._could_use_gpu_kernel: kwargs = {'training': training} self._maybe_reset_cell_dropout_mask(self.cell) @@ -1109,7 +1111,9 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): input_shape = K.int_shape(inputs) timesteps = input_shape[0] if self.time_major else input_shape[1] - if not self._could_use_gpu_kernel: + # TODO(b/156447398) Investigate why the cuDNN kernel fails with + # ragged inputs. + if is_ragged_input or not self._could_use_gpu_kernel: # Fall back to use the normal LSTM. 
kwargs = {'training': training} self._maybe_reset_cell_dropout_mask(self.cell) diff --git a/tensorflow/python/keras/layers/recurrent_v2_test.py b/tensorflow/python/keras/layers/recurrent_v2_test.py index 4cb964b4bc4..ec70761c8a8 100644 --- a/tensorflow/python/keras/layers/recurrent_v2_test.py +++ b/tensorflow/python/keras/layers/recurrent_v2_test.py @@ -30,7 +30,9 @@ from tensorflow.python.eager import context from tensorflow.python.framework import test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.layers import embeddings from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2 +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test @@ -113,6 +115,16 @@ class RNNV2Test(keras_parameterized.TestCase): model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer') model.save(os.path.join(self.get_temp_dir(), 'model'), save_format='tf') + @parameterized.parameters([rnn_v2.LSTM, rnn_v2.GRU]) + def test_ragged(self, layer): + vocab_size = 100 + inputs = ragged_factory_ops.constant( + np.random.RandomState(0).randint(0, vocab_size, [128, 25])) + embedder = embeddings.Embedding(input_dim=vocab_size, output_dim=16) + embedded_inputs = embedder(inputs) + lstm = layer(32) + lstm(embedded_inputs) + if __name__ == '__main__': test.main()