From fe206167759b89712bebd97363a7bad41171c0f4 Mon Sep 17 00:00:00 2001 From: Katherine Wu Date: Thu, 16 May 2019 14:56:01 -0700 Subject: [PATCH] (CL 1/2) Allow `model_to_estimator` to save object-based checkpoints. This change adds an `optimizer_config` argument to `models.clone_and_build_model`, which allows the optimizer config to be passed in (instead of generating the config in the function -- this caused issues when `clone_and_build_model` is called in a different context from the original model). PiperOrigin-RevId: 248605885 --- tensorflow/python/keras/models.py | 9 ++++++-- tensorflow/python/keras/models_test.py | 29 +++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index 9f8d6abbbc2..d699daf6b48 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -480,7 +480,8 @@ def in_place_subclassed_model_state_restoration(model): def clone_and_build_model( model, input_tensors=None, target_tensors=None, custom_objects=None, - compile_clone=True, in_place_reset=False, optimizer_iterations=None): + compile_clone=True, in_place_reset=False, optimizer_iterations=None, + optimizer_config=None): """Clone a `Model` and build/compile it with the same settings used before. This function can be be run in the same graph or in a separate graph from the @@ -508,6 +509,10 @@ def clone_and_build_model( optimizer if the clone is compiled. This argument is used when a Keras model is cloned into an Estimator model function, because Estimators create their own global step variable. + optimizer_config: Optimizer config dictionary returned from `get_config()`. + This argument should be defined if `clone_and_build_model` is called in + a different graph or session from the original model, and the optimizer is + an instance of `OptimizerV2`. Returns: Clone of the model. 
@@ -562,7 +567,7 @@ def clone_and_build_model( orig_optimizer.optimizer, optimizer_iterations) K.track_tf_optimizer(optimizer) else: - optimizer_config = orig_optimizer.get_config() + optimizer_config = optimizer_config or orig_optimizer.get_config() optimizer = orig_optimizer.__class__.from_config(optimizer_config) if optimizer_iterations is not None: optimizer.iterations = optimizer_iterations diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py index 0ef7323fe5e..4f8758ed5d0 100644 --- a/tensorflow/python/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -352,10 +352,10 @@ class TestModelDeepCopy(test.TestCase): model_copy.get_weights()[0])) -@keras_parameterized.run_all_keras_modes class TestCloneAndBuildModel(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types + @keras_parameterized.run_all_keras_modes def test_clone_and_build_non_compiled_model(self): inp = np.random.random((10, 4)) out = np.random.random((10, 4)) @@ -436,6 +436,7 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase): new_model.evaluate(inp, out) @keras_parameterized.run_with_all_model_types + @keras_parameterized.run_all_keras_modes def test_clone_and_build_compiled(self): model = _get_model() model.compile( @@ -445,6 +446,7 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase): self._clone_and_build_test_helper(model, testing_utils.get_model_type()) + @keras_parameterized.run_all_keras_modes def test_clone_and_build_sequential_without_inputs_defined(self): model = models.Sequential(_get_layers(input_shape=None)) model.compile( @@ -476,10 +478,12 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase): self.assertEqual(K.eval(global_step), 124) @keras_parameterized.run_with_all_model_types + @keras_parameterized.run_all_keras_modes def test_replace_tf_optimizer_iterations_variable(self): self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01)) 
@keras_parameterized.run_with_all_model_types + @keras_parameterized.run_all_keras_modes def test_replace_keras_optimizer_iterations_variable(self): if testing_utils.should_run_eagerly(): # This needs to be updated to run with v2 optimizers. @@ -487,6 +491,29 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase): self.assert_optimizer_iterations_increases('adam') + def test_clone_optimizer_in_different_graph(self): + with ops.Graph().as_default(): + with self.session(): + model = testing_utils.get_small_sequential_mlp(3, 4) + optimizer = keras.optimizer_v2.adam.Adam() + model.compile( + optimizer, 'mse', metrics=['acc', metrics.categorical_accuracy], + ) + model.fit( + x=np.array([[1., 2., 3., 4.]]), + y=np.array([[1., 1., 1., 1.]]), + epochs=1) + optimizer_config = optimizer.get_config() + with ops.Graph().as_default(): + with self.session(): + with self.assertRaisesRegexp(ValueError, + 'Cannot use the given session'): + models.clone_and_build_model(model, compile_clone=True) + # The optimizer_config object allows the model to be cloned in a + # different graph. + models.clone_and_build_model(model, compile_clone=True, + optimizer_config=optimizer_config) + if __name__ == '__main__': test.main()