Test LSTM v2 layer in eager mode in dist strat keras tests

PiperOrigin-RevId: 258916770
This commit is contained in:
Priya Gupta 2019-07-19 00:06:05 -07:00 committed by TensorFlower Gardener
parent 3e1e849214
commit 2ecc2fffad

View File

@@ -19,11 +19,16 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.distribute import combinations
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.layers import recurrent as rnn_v1
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
class DistributionStrategyLstmModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
@@ -35,13 +40,20 @@ class DistributionStrategyLstmModelCorrectnessTest(
                run_distributed=None,
                input_shapes=None):
    del input_shapes
if tf2.enabled():
if not context.executing_eagerly():
self.skipTest("LSTM v2 and legacy graph mode don't work together.")
lstm = rnn_v2.LSTM
else:
lstm = rnn_v1.LSTM
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      word_ids = keras.layers.Input(
          shape=(max_words,), dtype=np.int32, name='words')
      word_embed = keras.layers.Embedding(input_dim=20, output_dim=10)(word_ids)
      lstm_embed = lstm(units=4, return_sequences=False)(
          word_embed)
      preds = keras.layers.Dense(2, activation='softmax')(lstm_embed)
      model = keras.Model(inputs=[word_ids], outputs=[preds])