(CL 1/2) Allow model_to_estimator to save object-based checkpoints.

This change adds an `optimizer_config` argument to `models.clone_and_build_model`, which allows the optimizer config to be passed in explicitly instead of being generated inside the function. Generating it there caused issues when `clone_and_build_model` is called in a different graph or session from the original model.

PiperOrigin-RevId: 248605885
Author:    Katherine Wu
Date:      2019-05-16 14:56:01 -07:00
Committer: TensorFlower Gardener
Commit:    fe20616775 (parent 1bb6236372)

2 changed files with 35 additions and 3 deletions
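
For orientation, here is a minimal sketch of the flow this change enables. It is not part of the commit: it assumes TF 1.x graph/session semantics and the internal `models` module touched by the diff below, and the toy model and data are illustrative only.

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.keras import models
    from tensorflow.python.keras.optimizer_v2 import adam

    # Build, compile, and fit a model in one graph, then snapshot the
    # optimizer config while that graph's session is still usable.
    with tf.Graph().as_default(), tf.Session().as_default():
      model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(4,))])
      model.compile(adam.Adam(), 'mse')
      model.fit(np.ones((1, 4)), np.ones((1, 4)), epochs=1)
      optimizer_config = model.optimizer.get_config()  # plain-Python dict

    # Clone in a fresh graph. Without optimizer_config, clone_and_build_model
    # would call get_config() on the original optimizer here and fail, because
    # that optimizer's hyperparameter variables live in the first graph.
    with tf.Graph().as_default(), tf.Session().as_default():
      clone = models.clone_and_build_model(
          model, compile_clone=True, optimizer_config=optimizer_config)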

diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py

@@ -480,7 +480,8 @@ def in_place_subclassed_model_state_restoration(model):
 def clone_and_build_model(
     model, input_tensors=None, target_tensors=None, custom_objects=None,
-    compile_clone=True, in_place_reset=False, optimizer_iterations=None):
+    compile_clone=True, in_place_reset=False, optimizer_iterations=None,
+    optimizer_config=None):
   """Clone a `Model` and build/compile it with the same settings used before.
 
   This function can be run in the same graph or in a separate graph from the
@@ -508,6 +509,10 @@ def clone_and_build_model(
       optimizer if the clone is compiled. This argument is used when a Keras
       model is cloned into an Estimator model function, because Estimators
       create their own global step variable.
+    optimizer_config: Optimizer config dictionary returned from `get_config()`.
+      This argument should be defined if `clone_and_build_model` is called in
+      a different graph or session from the original model, and the optimizer
+      is an instance of `OptimizerV2`.
 
   Returns:
     Clone of the model.
@@ -562,7 +567,7 @@ def clone_and_build_model(
           orig_optimizer.optimizer, optimizer_iterations)
       K.track_tf_optimizer(optimizer)
     else:
-      optimizer_config = orig_optimizer.get_config()
+      optimizer_config = optimizer_config or orig_optimizer.get_config()
       optimizer = orig_optimizer.__class__.from_config(optimizer_config)
       if optimizer_iterations is not None:
         optimizer.iterations = optimizer_iterations
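
The substantive change is the `else:` branch above: `clone_and_build_model` previously always called `orig_optimizer.get_config()` at clone time. For an `OptimizerV2` that has already been used (its hyperparameters have been turned into variables), `get_config()` evaluates those variables through the original graph's session, so calling it while a different graph is active raises the "Cannot use the given session" error exercised in the test below. The new argument lets the caller capture the config while the original session is alive and replay it through `from_config`. A sketch of that round-trip (dict contents abbreviated and illustrative):

    from tensorflow.python.keras.optimizer_v2 import adam

    optimizer = adam.Adam(learning_rate=0.01)
    # get_config() on a fresh optimizer returns a plain-Python dict, e.g.
    # {'name': 'Adam', 'learning_rate': 0.01, ...}. It carries no graph
    # state, so it can safely cross graph/session boundaries.
    config = optimizer.get_config()
    restored = adam.Adam.from_config(config)  # rebuilds an equivalent optimizer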

diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py

@@ -352,10 +352,10 @@ class TestModelDeepCopy(test.TestCase):
                          model_copy.get_weights()[0]))
 
 
-@keras_parameterized.run_all_keras_modes
 class TestCloneAndBuildModel(keras_parameterized.TestCase):
 
   @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
   def test_clone_and_build_non_compiled_model(self):
     inp = np.random.random((10, 4))
     out = np.random.random((10, 4))
@@ -436,6 +436,7 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
     new_model.evaluate(inp, out)
 
   @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
   def test_clone_and_build_compiled(self):
     model = _get_model()
     model.compile(
@@ -445,6 +446,7 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
     self._clone_and_build_test_helper(model, testing_utils.get_model_type())
 
+  @keras_parameterized.run_all_keras_modes
   def test_clone_and_build_sequential_without_inputs_defined(self):
     model = models.Sequential(_get_layers(input_shape=None))
     model.compile(
@@ -476,10 +478,12 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
     self.assertEqual(K.eval(global_step), 124)
 
   @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
   def test_replace_tf_optimizer_iterations_variable(self):
     self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01))
 
   @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
   def test_replace_keras_optimizer_iterations_variable(self):
     if testing_utils.should_run_eagerly():
       # This needs to be updated to run with v2 optimizers.
@@ -487,6 +491,29 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
     self.assert_optimizer_iterations_increases('adam')
 
+  def test_clone_optimizer_in_different_graph(self):
+    with ops.Graph().as_default():
+      with self.session():
+        model = testing_utils.get_small_sequential_mlp(3, 4)
+        optimizer = keras.optimizer_v2.adam.Adam()
+        model.compile(
+            optimizer, 'mse', metrics=['acc', metrics.categorical_accuracy],
+            )
+        model.fit(
+            x=np.array([[1., 2., 3., 4.]]),
+            y=np.array([[1., 1., 1., 1.]]),
+            epochs=1)
+        optimizer_config = optimizer.get_config()
+
+    with ops.Graph().as_default():
+      with self.session():
+        with self.assertRaisesRegexp(ValueError,
+                                     'Cannot use the given session'):
+          models.clone_and_build_model(model, compile_clone=True)
+        # The optimizer_config object allows the model to be cloned in a
+        # different graph.
+        models.clone_and_build_model(model, compile_clone=True,
+                                     optimizer_config=optimizer_config)
 
 
 if __name__ == '__main__':
   test.main()