Test LSTM v2 layer in eager mode in dist strat keras tests
PiperOrigin-RevId: 258916770
This commit is contained in:
parent
3e1e849214
commit
2ecc2fffad
@ -19,11 +19,16 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.keras.distribute import keras_correctness_test_base
|
from tensorflow.python.keras.distribute import keras_correctness_test_base
|
||||||
|
from tensorflow.python.keras.layers import recurrent as rnn_v1
|
||||||
|
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
|
||||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
|
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
|
||||||
|
|
||||||
|
|
||||||
class DistributionStrategyLstmModelCorrectnessTest(
|
class DistributionStrategyLstmModelCorrectnessTest(
|
||||||
keras_correctness_test_base
|
keras_correctness_test_base
|
||||||
.TestDistributionStrategyEmbeddingModelCorrectnessBase):
|
.TestDistributionStrategyEmbeddingModelCorrectnessBase):
|
||||||
@ -35,12 +40,19 @@ class DistributionStrategyLstmModelCorrectnessTest(
|
|||||||
run_distributed=None,
|
run_distributed=None,
|
||||||
input_shapes=None):
|
input_shapes=None):
|
||||||
del input_shapes
|
del input_shapes
|
||||||
|
|
||||||
|
if tf2.enabled():
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.skipTest("LSTM v2 and legacy graph mode don't work together.")
|
||||||
|
lstm = rnn_v2.LSTM
|
||||||
|
else:
|
||||||
|
lstm = rnn_v1.LSTM
|
||||||
|
|
||||||
with keras_correctness_test_base.MaybeDistributionScope(distribution):
|
with keras_correctness_test_base.MaybeDistributionScope(distribution):
|
||||||
word_ids = keras.layers.Input(
|
word_ids = keras.layers.Input(
|
||||||
shape=(max_words,), dtype=np.int32, name='words')
|
shape=(max_words,), dtype=np.int32, name='words')
|
||||||
word_embed = keras.layers.Embedding(input_dim=20, output_dim=10)(word_ids)
|
word_embed = keras.layers.Embedding(input_dim=20, output_dim=10)(word_ids)
|
||||||
lstm_embed = keras.layers.LSTM(
|
lstm_embed = lstm(units=4, return_sequences=False)(
|
||||||
units=4, return_sequences=False)(
|
|
||||||
word_embed)
|
word_embed)
|
||||||
|
|
||||||
preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
|
preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
|
||||||
|
Loading…
Reference in New Issue
Block a user