Test LSTM v2 layer in eager mode in dist strat keras tests
PiperOrigin-RevId: 258916770
This commit is contained in:
parent
3e1e849214
commit
2ecc2fffad
@ -19,11 +19,16 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.keras.distribute import keras_correctness_test_base
|
||||
from tensorflow.python.keras.layers import recurrent as rnn_v1
|
||||
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
|
||||
|
||||
|
||||
class DistributionStrategyLstmModelCorrectnessTest(
|
||||
keras_correctness_test_base
|
||||
.TestDistributionStrategyEmbeddingModelCorrectnessBase):
|
||||
@ -35,13 +40,20 @@ class DistributionStrategyLstmModelCorrectnessTest(
|
||||
run_distributed=None,
|
||||
input_shapes=None):
|
||||
del input_shapes
|
||||
|
||||
if tf2.enabled():
|
||||
if not context.executing_eagerly():
|
||||
self.skipTest("LSTM v2 and legacy graph mode don't work together.")
|
||||
lstm = rnn_v2.LSTM
|
||||
else:
|
||||
lstm = rnn_v1.LSTM
|
||||
|
||||
with keras_correctness_test_base.MaybeDistributionScope(distribution):
|
||||
word_ids = keras.layers.Input(
|
||||
shape=(max_words,), dtype=np.int32, name='words')
|
||||
word_embed = keras.layers.Embedding(input_dim=20, output_dim=10)(word_ids)
|
||||
lstm_embed = keras.layers.LSTM(
|
||||
units=4, return_sequences=False)(
|
||||
word_embed)
|
||||
lstm_embed = lstm(units=4, return_sequences=False)(
|
||||
word_embed)
|
||||
|
||||
preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
|
||||
model = keras.Model(inputs=[word_ids], outputs=[preds])
|
||||
|
Loading…
Reference in New Issue
Block a user