Preserve trainable status of layers when loading from SavedModel.

Also refactor some unit tests to make it easier to create a "complete" test case which tests all model types and compares evaluation of orig vs loaded model.

PiperOrigin-RevId: 359853088
Change-Id: I85597d67dbd3150ec3ed1571f3c66c7340fb465c
This commit is contained in:
Monica Song 2021-02-26 15:37:47 -08:00 committed by TensorFlower Gardener
parent dca9172e65
commit 70b6d7a275
3 changed files with 51 additions and 7 deletions

View File

@ -696,7 +696,7 @@ class KerasObjectLoader(object):
model.__init__(inputs, outputs, name=config['name'])
functional_lib.connect_ancillary_layers(model, created_layers)
# Set model dtype and trainable status.
# Set model dtype.
_set_network_attributes_from_metadata(model)
# Unblock models that are dependent on this model.
@ -1161,7 +1161,7 @@ def _set_network_attributes_from_metadata(revived_obj):
metadata = revived_obj._serialized_attributes['metadata']
if metadata.get('dtype') is not None:
revived_obj._set_dtype_policy(metadata['dtype'])
revived_obj.trainable = metadata['trainable']
revived_obj._trainable = metadata['trainable']
# pylint:enable=protected-access

View File

@ -124,7 +124,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
return os.path.join(temp_dir, dirname)
def _test_save_and_load(self, use_dataset=False):
def _get_model(self):
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.layers[-1].activity_regularizer = regularizers.get('l2')
model.activity_regularizer = regularizers.get('l2')
@ -134,7 +134,9 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
def callable_loss():
return math_ops.reduce_sum(model.weights[0])
model.add_loss(callable_loss)
return model
def _train_model(self, model, use_dataset=False):
x = np.random.random((1, 3))
y = np.random.random((1, 4))
@ -150,9 +152,14 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
else:
model.train_on_batch(x, y)
def _save_and_load(self, model):
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
loaded = keras_load.load(saved_model_dir)
return loaded
def _test_evaluation(self, model, loaded):
# Assert that original and loaded models have the same results when called.
self.evaluate(variables.variables_initializer(loaded.variables))
self.assertAllClose(self.evaluate(model.weights),
self.evaluate(loaded.weights))
@ -175,13 +182,20 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
def test_model_save_and_load(self):
self._test_save_and_load(use_dataset=True)
model = self._get_model()
self._train_model(model, use_dataset=False)
loaded = self._save_and_load(model)
self._test_evaluation(model, loaded)
@keras_parameterized.run_with_all_model_types
def test_model_save_and_load_dataset(self):
self._test_save_and_load(use_dataset=True)
model = self._get_model()
self._train_model(model, use_dataset=True)
loaded = self._save_and_load(model)
self._test_evaluation(model, loaded)
def test_trainable_weights(self):
"""Tests that trainable status of individual weights is preserved."""
layer = keras.layers.Dense(4, name='custom_layer')
layer.build([3,])
layer.add_weight(
@ -208,6 +222,31 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
self.assertAllClose(self.evaluate(getattr(layer, attr)),
self.evaluate(getattr(loaded, attr)))
@keras_parameterized.run_with_all_model_types
def test_trainable_layers(self):
"""Tests that trainable status of individual layers is preserved."""
model = model = self._get_model()
# Set the last layer to *not* be trainable.
model.layers[-1].trainable = False
self._train_model(model, use_dataset=True)
loaded = self._save_and_load(model)
self._test_evaluation(model, loaded)
self.assertFalse(model.layers[-1].trainable)
self.assertFalse(loaded.layers[-1].trainable)
def test_trainable_custom_model_false(self):
"""Tests that overall False trainable status of Model is preserved."""
# Set all layers to *not* be trainable.
model = testing_utils.SmallSubclassMLP(1, 4, trainable=False)
model.compile(loss='mse', optimizer='rmsprop')
self._train_model(model, use_dataset=False)
loaded = self._save_and_load(model)
self._test_evaluation(model, loaded)
self.assertEmpty(model.trainable_variables)
self.assertEmpty(loaded.trainable_variables)
def test_maintains_losses(self):
"""Tests that the layer losses do not change before and after export."""
model = keras.models.Sequential([LayerWithLoss()])

View File

@ -469,8 +469,13 @@ def get_small_functional_mlp(num_hidden, num_classes, input_dim):
class SmallSubclassMLP(models.Model):
"""A subclass model based small MLP."""
def __init__(self, num_hidden, num_classes, use_bn=False, use_dp=False):
super(SmallSubclassMLP, self).__init__(name='test_model')
def __init__(self,
num_hidden,
num_classes,
use_bn=False,
use_dp=False,
**kwargs):
super(SmallSubclassMLP, self).__init__(name='test_model', **kwargs)
self.use_bn = use_bn
self.use_dp = use_dp