diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index 4f9f9a812e1..59818f9fb23 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -696,7 +696,7 @@ class KerasObjectLoader(object): model.__init__(inputs, outputs, name=config['name']) functional_lib.connect_ancillary_layers(model, created_layers) - # Set model dtype and trainable status. + # Set model dtype. _set_network_attributes_from_metadata(model) # Unblock models that are dependent on this model. @@ -1161,7 +1161,7 @@ def _set_network_attributes_from_metadata(revived_obj): metadata = revived_obj._serialized_attributes['metadata'] if metadata.get('dtype') is not None: revived_obj._set_dtype_policy(metadata['dtype']) - revived_obj.trainable = metadata['trainable'] + revived_obj._trainable = metadata['trainable'] # pylint:enable=protected-access diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 915f2ec34e7..e217e926f5f 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -124,7 +124,7 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) return os.path.join(temp_dir, dirname) - def _test_save_and_load(self, use_dataset=False): + def _get_model(self): model = testing_utils.get_small_mlp(1, 4, input_dim=3) model.layers[-1].activity_regularizer = regularizers.get('l2') model.activity_regularizer = regularizers.get('l2') @@ -134,7 +134,9 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): def callable_loss(): return math_ops.reduce_sum(model.weights[0]) model.add_loss(callable_loss) + return model + def _train_model(self, model, use_dataset=False): x = np.random.random((1, 3)) y = np.random.random((1, 4)) @@ -150,9 +152,14 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): else: model.train_on_batch(x, y) + def _save_and_load(self, model): saved_model_dir = self._save_model_dir() tf_save.save(model, saved_model_dir) loaded = keras_load.load(saved_model_dir) + return loaded + + def _test_evaluation(self, model, loaded): + # Assert that original and loaded models have the same results when called. self.evaluate(variables.variables_initializer(loaded.variables)) self.assertAllClose(self.evaluate(model.weights), self.evaluate(loaded.weights)) @@ -175,13 +182,20 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types def test_model_save_and_load(self): - self._test_save_and_load(use_dataset=True) + model = self._get_model() + self._train_model(model, use_dataset=False) + loaded = self._save_and_load(model) + self._test_evaluation(model, loaded) @keras_parameterized.run_with_all_model_types def test_model_save_and_load_dataset(self): - self._test_save_and_load(use_dataset=True) + model = self._get_model() + self._train_model(model, use_dataset=True) + loaded = self._save_and_load(model) + self._test_evaluation(model, loaded) def test_trainable_weights(self): + """Tests that trainable status of individual weights is preserved.""" layer = keras.layers.Dense(4, name='custom_layer') layer.build([3,]) layer.add_weight( @@ -208,6 +222,31 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): self.assertAllClose(self.evaluate(getattr(layer, attr)), self.evaluate(getattr(loaded, attr))) + @keras_parameterized.run_with_all_model_types + def test_trainable_layers(self): + """Tests that trainable status of individual layers is preserved.""" + model = model = self._get_model() + # Set the last layer to *not* be trainable. + model.layers[-1].trainable = False + self._train_model(model, use_dataset=True) + loaded = self._save_and_load(model) + + self._test_evaluation(model, loaded) + self.assertFalse(model.layers[-1].trainable) + self.assertFalse(loaded.layers[-1].trainable) + + def test_trainable_custom_model_false(self): + """Tests that overall False trainable status of Model is preserved.""" + # Set all layers to *not* be trainable. + model = testing_utils.SmallSubclassMLP(1, 4, trainable=False) + model.compile(loss='mse', optimizer='rmsprop') + self._train_model(model, use_dataset=False) + loaded = self._save_and_load(model) + + self._test_evaluation(model, loaded) + self.assertEmpty(model.trainable_variables) + self.assertEmpty(loaded.trainable_variables) + def test_maintains_losses(self): """Tests that the layer losses do not change before and after export.""" model = keras.models.Sequential([LayerWithLoss()]) diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index bf3de9a7f37..785beeabb49 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -469,8 +469,13 @@ def get_small_functional_mlp(num_hidden, num_classes, input_dim): class SmallSubclassMLP(models.Model): """A subclass model based small MLP.""" - def __init__(self, num_hidden, num_classes, use_bn=False, use_dp=False): - super(SmallSubclassMLP, self).__init__(name='test_model') + def __init__(self, + num_hidden, + num_classes, + use_bn=False, + use_dp=False, + **kwargs): + super(SmallSubclassMLP, self).__init__(name='test_model', **kwargs) self.use_bn = use_bn self.use_dp = use_dp