(CL 1/2) Allow model_to_estimator to save object-based checkpoints.

This change adds an `optimizer_config` argument to `models.clone_and_build_model`, which allows the optimizer config to be passed in explicitly (instead of generating the config inside the function -- generating it there caused issues when `clone_and_build_model` is called in a different graph or session from the original model).

PiperOrigin-RevId: 248605885
This commit is contained in:
Katherine Wu 2019-05-16 14:56:01 -07:00 committed by TensorFlower Gardener
parent 1bb6236372
commit fe20616775
2 changed files with 35 additions and 3 deletions

View File

@ -480,7 +480,8 @@ def in_place_subclassed_model_state_restoration(model):
def clone_and_build_model(
model, input_tensors=None, target_tensors=None, custom_objects=None,
compile_clone=True, in_place_reset=False, optimizer_iterations=None):
compile_clone=True, in_place_reset=False, optimizer_iterations=None,
optimizer_config=None):
"""Clone a `Model` and build/compile it with the same settings used before.
This function can be run in the same graph or in a separate graph from the
@ -508,6 +509,10 @@ def clone_and_build_model(
optimizer if the clone is compiled. This argument is used when a Keras
model is cloned into an Estimator model function, because Estimators
create their own global step variable.
optimizer_config: Optimizer config dictionary returned from `get_config()`.
This argument should be defined if `clone_and_build_model` is called in
a different graph or session from the original model, and the optimizer is
an instance of `OptimizerV2`.
Returns:
Clone of the model.
@ -562,7 +567,7 @@ def clone_and_build_model(
orig_optimizer.optimizer, optimizer_iterations)
K.track_tf_optimizer(optimizer)
else:
optimizer_config = orig_optimizer.get_config()
optimizer_config = optimizer_config or orig_optimizer.get_config()
optimizer = orig_optimizer.__class__.from_config(optimizer_config)
if optimizer_iterations is not None:
optimizer.iterations = optimizer_iterations

View File

@ -352,10 +352,10 @@ class TestModelDeepCopy(test.TestCase):
model_copy.get_weights()[0]))
@keras_parameterized.run_all_keras_modes
class TestCloneAndBuildModel(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_clone_and_build_non_compiled_model(self):
inp = np.random.random((10, 4))
out = np.random.random((10, 4))
@ -436,6 +436,7 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
new_model.evaluate(inp, out)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_clone_and_build_compiled(self):
model = _get_model()
model.compile(
@ -445,6 +446,7 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
self._clone_and_build_test_helper(model, testing_utils.get_model_type())
@keras_parameterized.run_all_keras_modes
def test_clone_and_build_sequential_without_inputs_defined(self):
model = models.Sequential(_get_layers(input_shape=None))
model.compile(
@ -476,10 +478,12 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
self.assertEqual(K.eval(global_step), 124)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_replace_tf_optimizer_iterations_variable(self):
self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01))
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_replace_keras_optimizer_iterations_variable(self):
if testing_utils.should_run_eagerly():
# This needs to be updated to run with v2 optimizers.
@ -487,6 +491,29 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
self.assert_optimizer_iterations_increases('adam')
def test_clone_optimizer_in_different_graph(self):
with ops.Graph().as_default():
with self.session():
model = testing_utils.get_small_sequential_mlp(3, 4)
optimizer = keras.optimizer_v2.adam.Adam()
model.compile(
optimizer, 'mse', metrics=['acc', metrics.categorical_accuracy],
)
model.fit(
x=np.array([[1., 2., 3., 4.]]),
y=np.array([[1., 1., 1., 1.]]),
epochs=1)
optimizer_config = optimizer.get_config()
with ops.Graph().as_default():
with self.session():
with self.assertRaisesRegexp(ValueError,
'Cannot use the given session'):
models.clone_and_build_model(model, compile_clone=True)
# The optimizer_config object allows the model to be cloned in a
# different graph.
models.clone_and_build_model(model, compile_clone=True,
optimizer_config=optimizer_config)
if __name__ == '__main__':
test.main()