Don't use the cuDNN GPU kernel for LSTMs and GRUs when inputs are RaggedTensors.
PiperOrigin-RevId: 313809073 Change-Id: I368307d458f8b23c320b9c4df31c81d18e7ab43d
This commit is contained in:
parent
55987dbb42
commit
e3e7bd4bf3
|
@ -413,7 +413,9 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
|
||||||
input_shape = K.int_shape(inputs)
|
input_shape = K.int_shape(inputs)
|
||||||
timesteps = input_shape[0] if self.time_major else input_shape[1]
|
timesteps = input_shape[0] if self.time_major else input_shape[1]
|
||||||
|
|
||||||
if not self._could_use_gpu_kernel:
|
# TODO(b/156447398) Investigate why the cuDNN kernel fails with
|
||||||
|
# ragged inputs.
|
||||||
|
if is_ragged_input or not self._could_use_gpu_kernel:
|
||||||
kwargs = {'training': training}
|
kwargs = {'training': training}
|
||||||
self._maybe_reset_cell_dropout_mask(self.cell)
|
self._maybe_reset_cell_dropout_mask(self.cell)
|
||||||
|
|
||||||
|
@ -1109,7 +1111,9 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
|
||||||
input_shape = K.int_shape(inputs)
|
input_shape = K.int_shape(inputs)
|
||||||
timesteps = input_shape[0] if self.time_major else input_shape[1]
|
timesteps = input_shape[0] if self.time_major else input_shape[1]
|
||||||
|
|
||||||
if not self._could_use_gpu_kernel:
|
# TODO(b/156447398) Investigate why the cuDNN kernel fails with
|
||||||
|
# ragged inputs.
|
||||||
|
if is_ragged_input or not self._could_use_gpu_kernel:
|
||||||
# Fall back to use the normal LSTM.
|
# Fall back to use the normal LSTM.
|
||||||
kwargs = {'training': training}
|
kwargs = {'training': training}
|
||||||
self._maybe_reset_cell_dropout_mask(self.cell)
|
self._maybe_reset_cell_dropout_mask(self.cell)
|
||||||
|
|
|
@ -30,7 +30,9 @@ from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
|
from tensorflow.python.keras.layers import embeddings
|
||||||
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
|
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,6 +115,16 @@ class RNNV2Test(keras_parameterized.TestCase):
|
||||||
model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer')
|
model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer')
|
||||||
model.save(os.path.join(self.get_temp_dir(), 'model'), save_format='tf')
|
model.save(os.path.join(self.get_temp_dir(), 'model'), save_format='tf')
|
||||||
|
|
||||||
|
@parameterized.parameters([rnn_v2.LSTM, rnn_v2.GRU])
|
||||||
|
def test_ragged(self, layer):
|
||||||
|
vocab_size = 100
|
||||||
|
inputs = ragged_factory_ops.constant(
|
||||||
|
np.random.RandomState(0).randint(0, vocab_size, [128, 25]))
|
||||||
|
embedder = embeddings.Embedding(input_dim=vocab_size, output_dim=16)
|
||||||
|
embedded_inputs = embedder(inputs)
|
||||||
|
lstm = layer(32)
|
||||||
|
lstm(embedded_inputs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
Loading…
Reference in New Issue