Don't use the cuDNN GPU kernel for LSTMs when inputs are RaggedTensors.

PiperOrigin-RevId: 313809073
Change-Id: I368307d458f8b23c320b9c4df31c81d18e7ab43d
This commit is contained in:
Edward Loper 2020-05-29 10:45:00 -07:00 committed by TensorFlower Gardener
parent 55987dbb42
commit e3e7bd4bf3
2 changed files with 18 additions and 2 deletions

View File

@ -413,7 +413,9 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
input_shape = K.int_shape(inputs) input_shape = K.int_shape(inputs)
timesteps = input_shape[0] if self.time_major else input_shape[1] timesteps = input_shape[0] if self.time_major else input_shape[1]
if not self._could_use_gpu_kernel: # TODO(b/156447398) Investigate why the cuDNN kernel kernel fails with
# ragged inputs.
if is_ragged_input or not self._could_use_gpu_kernel:
kwargs = {'training': training} kwargs = {'training': training}
self._maybe_reset_cell_dropout_mask(self.cell) self._maybe_reset_cell_dropout_mask(self.cell)
@ -1109,7 +1111,9 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
input_shape = K.int_shape(inputs) input_shape = K.int_shape(inputs)
timesteps = input_shape[0] if self.time_major else input_shape[1] timesteps = input_shape[0] if self.time_major else input_shape[1]
if not self._could_use_gpu_kernel: # TODO(b/156447398) Investigate why the cuDNN kernel kernel fails with
# ragged inputs.
if is_ragged_input or not self._could_use_gpu_kernel:
# Fall back to use the normal LSTM. # Fall back to use the normal LSTM.
kwargs = {'training': training} kwargs = {'training': training}
self._maybe_reset_cell_dropout_mask(self.cell) self._maybe_reset_cell_dropout_mask(self.cell)

View File

@ -30,7 +30,9 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2 from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -113,6 +115,16 @@ class RNNV2Test(keras_parameterized.TestCase):
model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer') model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer')
model.save(os.path.join(self.get_temp_dir(), 'model'), save_format='tf') model.save(os.path.join(self.get_temp_dir(), 'model'), save_format='tf')
@parameterized.parameters([rnn_v2.LSTM, rnn_v2.GRU])
def test_ragged(self, layer):
vocab_size = 100
inputs = ragged_factory_ops.constant(
np.random.RandomState(0).randint(0, vocab_size, [128, 25]))
embedder = embeddings.Embedding(input_dim=vocab_size, output_dim=16)
embedded_inputs = embedder(inputs)
lstm = layer(32)
lstm(embedded_inputs)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()