(CL 1/2) Allow model_to_estimator to save object-based checkpoints.

This change adds an `optimizer_config` argument to `models.clone_and_build_model`, which allows the optimizer config to be passed in instead of being generated inside the function. Generating the config in the function caused issues when `clone_and_build_model` is called in a different graph or session from the original model.

PiperOrigin-RevId: 248605885
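In practice the new argument is used as in the minimal sketch below. The tiny Sequential architecture is illustrative only; `clone_and_build_model` and its `optimizer_config` argument come from this change, and the internal import paths are assumed to match this revision of TensorFlow.

from tensorflow.python import keras
from tensorflow.python.keras import models

# A small stand-in model; the architecture here is illustrative only.
model = keras.Sequential([keras.layers.Dense(4, input_shape=(4,))])
model.compile(keras.optimizer_v2.adam.Adam(), 'mse')

# Capture the optimizer config in the graph/session where the model lives.
optimizer_config = model.optimizer.get_config()

# Later, possibly in a different graph or session, pass the saved config so
# clone_and_build_model does not regenerate it from the original optimizer.
clone = models.clone_and_build_model(
    model, compile_clone=True, optimizer_config=optimizer_config)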
parent 1bb6236372
commit fe20616775
@@ -480,7 +480,8 @@ def in_place_subclassed_model_state_restoration(model):
 
 def clone_and_build_model(
     model, input_tensors=None, target_tensors=None, custom_objects=None,
-    compile_clone=True, in_place_reset=False, optimizer_iterations=None):
+    compile_clone=True, in_place_reset=False, optimizer_iterations=None,
+    optimizer_config=None):
   """Clone a `Model` and build/compile it with the same settings used before.
 
   This function can be run in the same graph or in a separate graph from the
@@ -508,6 +509,10 @@ def clone_and_build_model(
       optimizer if the clone is compiled. This argument is used when a Keras
       model is cloned into an Estimator model function, because Estimators
       create their own global step variable.
+    optimizer_config: Optimizer config dictionary returned from `get_config()`.
+      This argument should be defined if `clone_and_build_model` is called in
+      a different graph or session from the original model, and the optimizer
+      is an instance of `OptimizerV2`.
 
   Returns:
     Clone of the model.
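For reference, here is a minimal sketch of what such a config dictionary holds for an `OptimizerV2` instance; the literal keys shown in the comment are an assumption, since they vary by optimizer and TensorFlow version:

from tensorflow.python.keras.optimizer_v2 import adam

optimizer = adam.Adam(learning_rate=0.01)
config = optimizer.get_config()
# config is a plain dict, roughly {'name': 'Adam', 'learning_rate': 0.01,
# 'beta_1': 0.9, ...} (exact keys are version-dependent). It can be carried
# across graphs and used to rebuild an equivalent optimizer:
rebuilt = adam.Adam.from_config(config)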
@@ -562,7 +567,7 @@ def clone_and_build_model(
           orig_optimizer.optimizer, optimizer_iterations)
       K.track_tf_optimizer(optimizer)
     else:
-      optimizer_config = orig_optimizer.get_config()
+      optimizer_config = optimizer_config or orig_optimizer.get_config()
       optimizer = orig_optimizer.__class__.from_config(optimizer_config)
       if optimizer_iterations is not None:
         optimizer.iterations = optimizer_iterations
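A note on the changed line: Python's `or` returns its right operand whenever the left one is falsy, so omitting `optimizer_config` preserves the old behavior; it also means an explicitly passed empty dict falls back to the original optimizer's config. A minimal illustration:

orig = {'name': 'Adam', 'learning_rate': 0.001}  # stand-in for orig_optimizer.get_config()
print(None or orig)  # falls back to orig when no config is passed
print({} or orig)    # an empty dict is falsy, so this falls back too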
@@ -352,10 +352,10 @@ class TestModelDeepCopy(test.TestCase):
           model_copy.get_weights()[0]))
 
 
-@keras_parameterized.run_all_keras_modes
 class TestCloneAndBuildModel(keras_parameterized.TestCase):
 
   @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
   def test_clone_and_build_non_compiled_model(self):
     inp = np.random.random((10, 4))
     out = np.random.random((10, 4))
@@ -436,6 +436,7 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
     new_model.evaluate(inp, out)
 
   @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
   def test_clone_and_build_compiled(self):
     model = _get_model()
     model.compile(
@@ -445,6 +446,7 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
 
     self._clone_and_build_test_helper(model, testing_utils.get_model_type())
 
+  @keras_parameterized.run_all_keras_modes
   def test_clone_and_build_sequential_without_inputs_defined(self):
     model = models.Sequential(_get_layers(input_shape=None))
     model.compile(
@@ -476,10 +478,12 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
     self.assertEqual(K.eval(global_step), 124)
 
   @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
   def test_replace_tf_optimizer_iterations_variable(self):
     self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01))
 
   @keras_parameterized.run_with_all_model_types
+  @keras_parameterized.run_all_keras_modes
   def test_replace_keras_optimizer_iterations_variable(self):
     if testing_utils.should_run_eagerly():
       # This needs to be updated to run with v2 optimizers.
@@ -487,6 +491,29 @@ class TestCloneAndBuildModel(keras_parameterized.TestCase):
 
     self.assert_optimizer_iterations_increases('adam')
 
+  def test_clone_optimizer_in_different_graph(self):
+    with ops.Graph().as_default():
+      with self.session():
+        model = testing_utils.get_small_sequential_mlp(3, 4)
+        optimizer = keras.optimizer_v2.adam.Adam()
+        model.compile(
+            optimizer, 'mse', metrics=['acc', metrics.categorical_accuracy],
+        )
+        model.fit(
+            x=np.array([[1., 2., 3., 4.]]),
+            y=np.array([[1., 1., 1., 1.]]),
+            epochs=1)
+        optimizer_config = optimizer.get_config()
+    with ops.Graph().as_default():
+      with self.session():
+        with self.assertRaisesRegexp(ValueError,
+                                     'Cannot use the given session'):
+          models.clone_and_build_model(model, compile_clone=True)
+        # The optimizer_config argument allows the model to be cloned in a
+        # different graph.
+        models.clone_and_build_model(model, compile_clone=True,
+                                     optimizer_config=optimizer_config)
+
 
 if __name__ == '__main__':
   test.main()