Add test for Keras LSTM model with multiple distribution strategies.

PiperOrigin-RevId: 290289057
Change-Id: Ia5f998b72a9815567ad226224c682c960d158ec0
Author: Ken Franko
Date: 2020-01-17 09:53:46 -08:00 (committed by TensorFlower Gardener)
parent ee24d4b059
commit d5d92b241b


@@ -19,6 +19,9 @@ from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
@@ -558,5 +561,61 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase):
    self.assertTrue(all(g is not None for g in grads))


class KerasModelsTest(test.TestCase, parameterized.TestCase):
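  # combinations.generate parameterizes this test: it runs once for each
  # strategy in strategy_combinations.all_strategies, in eager mode.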
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]))
  def test_lstm(self, distribution):
    batch_size = 32

    def create_lstm_model():
      model = keras.models.Sequential()
      # We only have LSTM variables so we can detect no gradient issues more
      # easily.
      model.add(
          keras.layers.LSTM(1, return_sequences=False, input_shape=(10, 1)))
      return model

    def create_lstm_data():
      seq_length = 10
      x_train = np.random.rand(batch_size, seq_length, 1).astype("float32")
      y_train = np.random.rand(batch_size, 1).astype("float32")
      return x_train, y_train

    x, y = create_lstm_data()
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
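    # drop_remainder=True keeps batch shapes static so the strategy can split
    # each global batch evenly across replicas.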
    dataset = dataset.batch(batch_size, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
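    # Create the model and optimizer under the strategy's scope so their
    # variables are distributed (e.g., mirrored) across replicas.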
    with distribution.scope():
      model = create_lstm_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD()
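    # def_function.function traces train_step into a graph; inside it,
    # experimental_run_v2 invokes step_fn once on each replica.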
    @def_function.function
    def train_step(input_iterator):

      def step_fn(inputs):
        inps, targ = inputs
        with backprop.GradientTape() as tape:
          output = model(inps)
          loss = math_ops.reduce_mean(
              keras.losses.binary_crossentropy(
                  y_true=targ, y_pred=output, from_logits=False))
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(input_iterator),))
      return distribution.experimental_local_results(outputs)
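    # A single training step is enough to surface missing gradients or failed
    # variable updates under each strategy.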
    train_step(input_iterator)


if __name__ == "__main__":
  test.main()
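For comparison, below is a minimal sketch of the same custom-training-loop pattern written against the public TensorFlow 2 API instead of the internal tensorflow.python modules used in the test. It assumes tf.distribute.MirroredStrategy and Strategy.run (the public name later given to experimental_run_v2); the model, data shapes, and loss mirror the test above, while the final reduction to a mean loss is an added illustration rather than part of the original test.

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
batch_size = 32

# Random sequence data shaped like the test's: (batch, time, features).
x = np.random.rand(batch_size, 10, 1).astype("float32")
y = np.random.rand(batch_size, 1).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(
    batch_size, drop_remainder=True)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Create variables under the strategy's scope so they are mirrored.
with strategy.scope():
  model = tf.keras.Sequential(
      [tf.keras.layers.LSTM(1, input_shape=(10, 1))])
  optimizer = tf.keras.optimizers.SGD()


@tf.function
def train_step(inputs):

  def step_fn(batch):
    features, labels = batch
    with tf.GradientTape() as tape:
      preds = model(features)
      loss = tf.reduce_mean(
          tf.keras.losses.binary_crossentropy(labels, preds))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  # Strategy.run executes step_fn once per replica; reduce the per-replica
  # losses to a single mean for reporting.
  per_replica_loss = strategy.run(step_fn, args=(inputs,))
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)


for batch in dist_dataset:
  print(train_step(batch).numpy())

In TF 2.2 and later, Strategy.experimental_run_v2 became Strategy.run; the structure of the loop is otherwise unchanged.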