Add integration test for Sequential pop
workflow.
PiperOrigin-RevId: 251323411
This commit is contained in:
parent
d69b8192fb
commit
0030e1dbdd
@ -19,10 +19,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
@ -31,6 +33,24 @@ from tensorflow.python.ops import rnn_cell
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class KerasIntegrationTest(keras_parameterized.TestCase):
|
||||
|
||||
def _save_and_reload_model(self, model):
|
||||
fpath = os.path.join(self.create_tempdir().full_path,
|
||||
'test_model_%s' % (random.randint(0, 1e7),))
|
||||
if context.executing_eagerly():
|
||||
keras.saving.save._KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = False
|
||||
save_format = 'tf'
|
||||
else:
|
||||
if (not isinstance(model, keras.Sequential) and
|
||||
not model._is_graph_network):
|
||||
return model # Not supported
|
||||
save_format = 'h5'
|
||||
model.save(fpath, save_format=save_format)
|
||||
model = keras.models.load_model(fpath)
|
||||
return model
|
||||
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class VectorClassificationIntegrationTest(keras_parameterized.TestCase):
|
||||
@ -103,6 +123,55 @@ class VectorClassificationIntegrationTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(predictions.shape, (x_train.shape[0], 2))
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class SequentialIntegrationTest(KerasIntegrationTest):
|
||||
|
||||
def test_sequential_save_and_pop(self):
|
||||
# Test the following sequence of actions:
|
||||
# - construct a Sequential model and train it
|
||||
# - save it
|
||||
# - load it
|
||||
# - pop its last layer and add a new layer instead
|
||||
# - continue training
|
||||
np.random.seed(1337)
|
||||
(x_train, y_train), _ = testing_utils.get_test_data(
|
||||
train_samples=100,
|
||||
test_samples=0,
|
||||
input_shape=(10,),
|
||||
num_classes=2)
|
||||
y_train = keras.utils.to_categorical(y_train)
|
||||
model = keras.Sequential([
|
||||
keras.layers.Dense(16, activation='relu'),
|
||||
keras.layers.Dropout(0.1),
|
||||
keras.layers.Dense(y_train.shape[-1], activation='softmax')
|
||||
])
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer=keras.optimizer_v2.adam.Adam(0.005),
|
||||
metrics=['acc'],
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(x_train, y_train, epochs=1, batch_size=10,
|
||||
validation_data=(x_train, y_train),
|
||||
verbose=2)
|
||||
model = self._save_and_reload_model(model)
|
||||
model.pop()
|
||||
model.add(keras.layers.Dense(y_train.shape[-1], activation='softmax'))
|
||||
model.compile(
|
||||
loss='categorical_crossentropy',
|
||||
optimizer=keras.optimizer_v2.adam.Adam(0.005),
|
||||
metrics=['acc'],
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(x_train, y_train, epochs=10, batch_size=10,
|
||||
validation_data=(x_train, y_train),
|
||||
verbose=2)
|
||||
self.assertGreater(history.history['val_acc'][-1], 0.7)
|
||||
model = self._save_and_reload_model(model)
|
||||
_, val_acc = model.evaluate(x_train, y_train)
|
||||
self.assertAlmostEqual(history.history['val_acc'][-1], val_acc)
|
||||
predictions = model.predict(x_train)
|
||||
self.assertEqual(predictions.shape, (x_train.shape[0], 2))
|
||||
|
||||
|
||||
# See b/122473407
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
class TimeseriesClassificationIntegrationTest(keras_parameterized.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user