Preserve trainable status of layers when loading from SavedModel.
Also refactor some unit tests to make it easier to create a "complete" test case that covers all model types and compares evaluation of the original vs. the loaded model.

PiperOrigin-RevId: 359853088
Change-Id: I85597d67dbd3150ec3ed1571f3c66c7340fb465c
Parent: dca9172e65
Commit: 70b6d7a275
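Why the one-line loader change below fixes the bug (a minimal sketch, not part of the commit, assuming the standard tf.keras behavior where the public Layer.trainable setter propagates its value to every nested layer): the loader used to run `revived_obj.trainable = metadata['trainable']`, which reset the trainable flag of every child layer of the revived model; assigning the private `_trainable` attribute only touches the outer object and leaves per-layer flags alone.

# Sketch only (not from this commit): contrasts the old and new loader
# behavior using the public tf.keras API.
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(3,)),
    tf.keras.layers.Dense(4),
])
model.layers[-1].trainable = False   # freeze only the last layer

# Old behavior: the Layer.trainable setter walks all nested layers, so
# restoring the model-level flag also re-enables the frozen layer.
model.trainable = True
print(model.layers[-1].trainable)    # True -- per-layer status lost

# New behavior: assign the private attribute on the model object only.
model.layers[-1].trainable = False
model._trainable = True              # pylint: disable=protected-access
print(model.layers[-1].trainable)    # False -- per-layer status preserved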
@@ -696,7 +696,7 @@ class KerasObjectLoader(object):
       model.__init__(inputs, outputs, name=config['name'])
       functional_lib.connect_ancillary_layers(model, created_layers)

-    # Set model dtype and trainable status.
+    # Set model dtype.
     _set_network_attributes_from_metadata(model)

     # Unblock models that are dependent on this model.
@@ -1161,7 +1161,7 @@ def _set_network_attributes_from_metadata(revived_obj):
     metadata = revived_obj._serialized_attributes['metadata']
     if metadata.get('dtype') is not None:
       revived_obj._set_dtype_policy(metadata['dtype'])
-    revived_obj.trainable = metadata['trainable']
+    revived_obj._trainable = metadata['trainable']
     # pylint:enable=protected-access

@@ -124,7 +124,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
     self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
     return os.path.join(temp_dir, dirname)

-  def _test_save_and_load(self, use_dataset=False):
+  def _get_model(self):
     model = testing_utils.get_small_mlp(1, 4, input_dim=3)
     model.layers[-1].activity_regularizer = regularizers.get('l2')
     model.activity_regularizer = regularizers.get('l2')
@@ -134,7 +134,9 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
     def callable_loss():
       return math_ops.reduce_sum(model.weights[0])
     model.add_loss(callable_loss)
+    return model

+  def _train_model(self, model, use_dataset=False):
     x = np.random.random((1, 3))
     y = np.random.random((1, 4))

@@ -150,9 +152,14 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
     else:
       model.train_on_batch(x, y)

+  def _save_and_load(self, model):
     saved_model_dir = self._save_model_dir()
     tf_save.save(model, saved_model_dir)
     loaded = keras_load.load(saved_model_dir)
+    return loaded
+
+  def _test_evaluation(self, model, loaded):
+    # Assert that original and loaded models have the same results when called.
     self.evaluate(variables.variables_initializer(loaded.variables))
     self.assertAllClose(self.evaluate(model.weights),
                         self.evaluate(loaded.weights))
@@ -175,13 +182,20 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):

   @keras_parameterized.run_with_all_model_types
   def test_model_save_and_load(self):
-    self._test_save_and_load(use_dataset=True)
+    model = self._get_model()
+    self._train_model(model, use_dataset=False)
+    loaded = self._save_and_load(model)
+    self._test_evaluation(model, loaded)

   @keras_parameterized.run_with_all_model_types
   def test_model_save_and_load_dataset(self):
-    self._test_save_and_load(use_dataset=True)
+    model = self._get_model()
+    self._train_model(model, use_dataset=True)
+    loaded = self._save_and_load(model)
+    self._test_evaluation(model, loaded)

   def test_trainable_weights(self):
+    """Tests that trainable status of individual weights is preserved."""
     layer = keras.layers.Dense(4, name='custom_layer')
     layer.build([3,])
     layer.add_weight(
@@ -208,6 +222,31 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
       self.assertAllClose(self.evaluate(getattr(layer, attr)),
                           self.evaluate(getattr(loaded, attr)))

+  @keras_parameterized.run_with_all_model_types
+  def test_trainable_layers(self):
+    """Tests that trainable status of individual layers is preserved."""
+    model = self._get_model()
+    # Set the last layer to *not* be trainable.
+    model.layers[-1].trainable = False
+    self._train_model(model, use_dataset=True)
+    loaded = self._save_and_load(model)
+
+    self._test_evaluation(model, loaded)
+    self.assertFalse(model.layers[-1].trainable)
+    self.assertFalse(loaded.layers[-1].trainable)
+
+  def test_trainable_custom_model_false(self):
+    """Tests that overall False trainable status of Model is preserved."""
+    # Set all layers to *not* be trainable.
+    model = testing_utils.SmallSubclassMLP(1, 4, trainable=False)
+    model.compile(loss='mse', optimizer='rmsprop')
+    self._train_model(model, use_dataset=False)
+    loaded = self._save_and_load(model)
+
+    self._test_evaluation(model, loaded)
+    self.assertEmpty(model.trainable_variables)
+    self.assertEmpty(loaded.trainable_variables)
+
   def test_maintains_losses(self):
     """Tests that the layer losses do not change before and after export."""
     model = keras.models.Sequential([LayerWithLoss()])
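The helpers introduced above (`_get_model`, `_train_model`, `_save_and_load`, `_test_evaluation`) are the pieces of the "complete" test case mentioned in the commit message: a new test for a specific trainable configuration is just the four calls in sequence, as `test_trainable_layers` shows. A hypothetical combined helper (illustrative only, not in the commit) would look like:

# Hypothetical helper, not part of the commit: one end-to-end
# save/load/evaluate round trip built from the refactored pieces.
def _assert_save_load_roundtrip(self, use_dataset=False):
  model = self._get_model()
  self._train_model(model, use_dataset=use_dataset)
  loaded = self._save_and_load(model)
  self._test_evaluation(model, loaded)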
@@ -469,8 +469,13 @@ def get_small_functional_mlp(num_hidden, num_classes, input_dim):
 class SmallSubclassMLP(models.Model):
   """A subclass model based small MLP."""

-  def __init__(self, num_hidden, num_classes, use_bn=False, use_dp=False):
-    super(SmallSubclassMLP, self).__init__(name='test_model')
+  def __init__(self,
+               num_hidden,
+               num_classes,
+               use_bn=False,
+               use_dp=False,
+               **kwargs):
+    super(SmallSubclassMLP, self).__init__(name='test_model', **kwargs)
     self.use_bn = use_bn
     self.use_dp = use_dp

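The `**kwargs` now accepted by `SmallSubclassMLP.__init__` is forwarded to `models.Model.__init__`; that is what lets the new `test_trainable_custom_model_false` construct the model with `trainable=False`. A brief usage sketch (assumes eager execution and the same imports as the Keras tests):

# Usage sketch, assuming the updated SmallSubclassMLP signature above.
import numpy as np

model = testing_utils.SmallSubclassMLP(1, 4, trainable=False)
model(np.ones((1, 3)))                 # a forward pass builds the weights
assert model.weights                   # the weights exist...
assert not model.trainable_variables   # ...but none of them are trainable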