Test LSTM v2 layer in eager mode in dist strat keras tests

PiperOrigin-RevId: 258916770
This commit is contained in:
Priya Gupta 2019-07-19 00:06:05 -07:00 committed by TensorFlower Gardener
parent 3e1e849214
commit 2ecc2fffad

View File

@ -19,11 +19,16 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.distribute import combinations
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.layers import recurrent as rnn_v1
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
class DistributionStrategyLstmModelCorrectnessTest(
keras_correctness_test_base
.TestDistributionStrategyEmbeddingModelCorrectnessBase):
@ -35,13 +40,20 @@ class DistributionStrategyLstmModelCorrectnessTest(
run_distributed=None,
input_shapes=None):
del input_shapes
if tf2.enabled():
if not context.executing_eagerly():
self.skipTest("LSTM v2 and legacy graph mode don't work together.")
lstm = rnn_v2.LSTM
else:
lstm = rnn_v1.LSTM
with keras_correctness_test_base.MaybeDistributionScope(distribution):
word_ids = keras.layers.Input(
shape=(max_words,), dtype=np.int32, name='words')
word_embed = keras.layers.Embedding(input_dim=20, output_dim=10)(word_ids)
lstm_embed = keras.layers.LSTM(
units=4, return_sequences=False)(
word_embed)
lstm_embed = lstm(units=4, return_sequences=False)(
word_embed)
preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
model = keras.Model(inputs=[word_ids], outputs=[preds])