Add test for Keras LSTM model with multiple distribution strategies.
PiperOrigin-RevId: 290289057 Change-Id: Ia5f998b72a9815567ad226224c682c960d158ec0
This commit is contained in:
parent
ee24d4b059
commit
d5d92b241b
@ -19,6 +19,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations
|
||||
@ -558,5 +561,61 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertTrue(all(g is not None for g in grads))
|
||||
|
||||
|
||||
class KerasModelsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies,
|
||||
mode=["eager"]
|
||||
))
|
||||
def test_lstm(self, distribution):
|
||||
|
||||
batch_size = 32
|
||||
|
||||
def create_lstm_model():
|
||||
model = keras.models.Sequential()
|
||||
# We only have LSTM variables so we can detect no gradient issues more
|
||||
# easily.
|
||||
model.add(
|
||||
keras.layers.LSTM(1, return_sequences=False, input_shape=(10, 1)))
|
||||
return model
|
||||
|
||||
def create_lstm_data():
|
||||
seq_length = 10
|
||||
|
||||
x_train = np.random.rand(batch_size, seq_length, 1).astype("float32")
|
||||
y_train = np.random.rand(batch_size, 1).astype("float32")
|
||||
return x_train, y_train
|
||||
|
||||
x, y = create_lstm_data()
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
|
||||
dataset = dataset.batch(batch_size, drop_remainder=True)
|
||||
input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
|
||||
|
||||
with distribution.scope():
|
||||
model = create_lstm_model()
|
||||
optimizer = keras.optimizer_v2.gradient_descent.SGD()
|
||||
|
||||
@def_function.function
|
||||
def train_step(input_iterator):
|
||||
|
||||
def step_fn(inputs):
|
||||
inps, targ = inputs
|
||||
with backprop.GradientTape() as tape:
|
||||
output = model(inps)
|
||||
loss = math_ops.reduce_mean(
|
||||
keras.losses.binary_crossentropy(
|
||||
y_true=targ, y_pred=output, from_logits=False))
|
||||
grads = tape.gradient(loss, model.variables)
|
||||
optimizer.apply_gradients(zip(grads, model.variables))
|
||||
return loss
|
||||
|
||||
outputs = distribution.experimental_run_v2(
|
||||
step_fn, args=(next(input_iterator),))
|
||||
return distribution.experimental_local_results(outputs)
|
||||
|
||||
train_step(input_iterator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user