Don't use the cuDNN GPU kernel for LSTMs when inputs are RaggedTensors.
PiperOrigin-RevId: 313809073 Change-Id: I368307d458f8b23c320b9c4df31c81d18e7ab43d
This commit is contained in:
parent
55987dbb42
commit
e3e7bd4bf3
|
@ -413,7 +413,9 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
|
|||
input_shape = K.int_shape(inputs)
|
||||
timesteps = input_shape[0] if self.time_major else input_shape[1]
|
||||
|
||||
if not self._could_use_gpu_kernel:
|
||||
# TODO(b/156447398) Investigate why the cuDNN kernel kernel fails with
|
||||
# ragged inputs.
|
||||
if is_ragged_input or not self._could_use_gpu_kernel:
|
||||
kwargs = {'training': training}
|
||||
self._maybe_reset_cell_dropout_mask(self.cell)
|
||||
|
||||
|
@ -1109,7 +1111,9 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
|
|||
input_shape = K.int_shape(inputs)
|
||||
timesteps = input_shape[0] if self.time_major else input_shape[1]
|
||||
|
||||
if not self._could_use_gpu_kernel:
|
||||
# TODO(b/156447398) Investigate why the cuDNN kernel kernel fails with
|
||||
# ragged inputs.
|
||||
if is_ragged_input or not self._could_use_gpu_kernel:
|
||||
# Fall back to use the normal LSTM.
|
||||
kwargs = {'training': training}
|
||||
self._maybe_reset_cell_dropout_mask(self.cell)
|
||||
|
|
|
@ -30,7 +30,9 @@ from tensorflow.python.eager import context
|
|||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.layers import embeddings
|
||||
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
|
@ -113,6 +115,16 @@ class RNNV2Test(keras_parameterized.TestCase):
|
|||
model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer')
|
||||
model.save(os.path.join(self.get_temp_dir(), 'model'), save_format='tf')
|
||||
|
||||
@parameterized.parameters([rnn_v2.LSTM, rnn_v2.GRU])
|
||||
def test_ragged(self, layer):
|
||||
vocab_size = 100
|
||||
inputs = ragged_factory_ops.constant(
|
||||
np.random.RandomState(0).randint(0, vocab_size, [128, 25]))
|
||||
embedder = embeddings.Embedding(input_dim=vocab_size, output_dim=16)
|
||||
embedded_inputs = embedder(inputs)
|
||||
lstm = layer(32)
|
||||
lstm(embedded_inputs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue