Add integration test for Sequential pop
workflow.
PiperOrigin-RevId: 251323411
This commit is contained in:
parent
d69b8192fb
commit
0030e1dbdd
@ -19,10 +19,12 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
@ -31,6 +33,24 @@ from tensorflow.python.ops import rnn_cell
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class KerasIntegrationTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
def _save_and_reload_model(self, model):
|
||||||
|
fpath = os.path.join(self.create_tempdir().full_path,
|
||||||
|
'test_model_%s' % (random.randint(0, 1e7),))
|
||||||
|
if context.executing_eagerly():
|
||||||
|
keras.saving.save._KERAS_SAVED_MODEL_STILL_EXPERIMENTAL = False
|
||||||
|
save_format = 'tf'
|
||||||
|
else:
|
||||||
|
if (not isinstance(model, keras.Sequential) and
|
||||||
|
not model._is_graph_network):
|
||||||
|
return model # Not supported
|
||||||
|
save_format = 'h5'
|
||||||
|
model.save(fpath, save_format=save_format)
|
||||||
|
model = keras.models.load_model(fpath)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
class VectorClassificationIntegrationTest(keras_parameterized.TestCase):
|
class VectorClassificationIntegrationTest(keras_parameterized.TestCase):
|
||||||
@ -103,6 +123,55 @@ class VectorClassificationIntegrationTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(predictions.shape, (x_train.shape[0], 2))
|
self.assertEqual(predictions.shape, (x_train.shape[0], 2))
|
||||||
|
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes
|
||||||
|
class SequentialIntegrationTest(KerasIntegrationTest):
|
||||||
|
|
||||||
|
def test_sequential_save_and_pop(self):
|
||||||
|
# Test the following sequence of actions:
|
||||||
|
# - construct a Sequential model and train it
|
||||||
|
# - save it
|
||||||
|
# - load it
|
||||||
|
# - pop its last layer and add a new layer instead
|
||||||
|
# - continue training
|
||||||
|
np.random.seed(1337)
|
||||||
|
(x_train, y_train), _ = testing_utils.get_test_data(
|
||||||
|
train_samples=100,
|
||||||
|
test_samples=0,
|
||||||
|
input_shape=(10,),
|
||||||
|
num_classes=2)
|
||||||
|
y_train = keras.utils.to_categorical(y_train)
|
||||||
|
model = keras.Sequential([
|
||||||
|
keras.layers.Dense(16, activation='relu'),
|
||||||
|
keras.layers.Dropout(0.1),
|
||||||
|
keras.layers.Dense(y_train.shape[-1], activation='softmax')
|
||||||
|
])
|
||||||
|
model.compile(
|
||||||
|
loss='categorical_crossentropy',
|
||||||
|
optimizer=keras.optimizer_v2.adam.Adam(0.005),
|
||||||
|
metrics=['acc'],
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
model.fit(x_train, y_train, epochs=1, batch_size=10,
|
||||||
|
validation_data=(x_train, y_train),
|
||||||
|
verbose=2)
|
||||||
|
model = self._save_and_reload_model(model)
|
||||||
|
model.pop()
|
||||||
|
model.add(keras.layers.Dense(y_train.shape[-1], activation='softmax'))
|
||||||
|
model.compile(
|
||||||
|
loss='categorical_crossentropy',
|
||||||
|
optimizer=keras.optimizer_v2.adam.Adam(0.005),
|
||||||
|
metrics=['acc'],
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
history = model.fit(x_train, y_train, epochs=10, batch_size=10,
|
||||||
|
validation_data=(x_train, y_train),
|
||||||
|
verbose=2)
|
||||||
|
self.assertGreater(history.history['val_acc'][-1], 0.7)
|
||||||
|
model = self._save_and_reload_model(model)
|
||||||
|
_, val_acc = model.evaluate(x_train, y_train)
|
||||||
|
self.assertAlmostEqual(history.history['val_acc'][-1], val_acc)
|
||||||
|
predictions = model.predict(x_train)
|
||||||
|
self.assertEqual(predictions.shape, (x_train.shape[0], 2))
|
||||||
|
|
||||||
|
|
||||||
# See b/122473407
|
# See b/122473407
|
||||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
class TimeseriesClassificationIntegrationTest(keras_parameterized.TestCase):
|
class TimeseriesClassificationIntegrationTest(keras_parameterized.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user