Add support for loading checkpoints saved by the ModelCheckpoint callback in sidecar_evaluator.

PiperOrigin-RevId: 348071878
Change-Id: Id9ae47394adecb497d60a2b902da288265a4b400
This commit is contained in:
A. Unique TensorFlower 2020-12-17 12:26:33 -08:00 committed by TensorFlower Gardener
parent 48d3d6eb47
commit 1fbddee1b1
2 changed files with 47 additions and 5 deletions
tensorflow/python/keras/distribute

View File

@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@ -32,6 +34,26 @@ _PRINT_EVAL_STEP_EVERY_SEC = 60.0
_ITERATIONS_UNINITIALIZED = -1
def list_checkpoint_attributes(ckpt_dir_or_file):
  """Collects the top-level attribute names stored in a checkpoint.

  Every checkpoint key is a path in the checkpoint graph, and the attribute is
  the first component of that path. For example, the key
  "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE" belongs to the attribute
  "optimizer". The same attribute names are the ones used when saving or
  restoring through a checkpoint object, e.g.
  tf.train.Checkpoint(optimizer=optimizer, model=model).

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    Set of attributes in a checkpoint.
  """
  shape_by_key = checkpoint_utils.load_checkpoint(
      ckpt_dir_or_file).get_variable_to_shape_map()
  attributes = set()
  for checkpoint_key in shape_by_key:
    # Everything before the first '/' is the attribute; a key with no '/' is
    # its own attribute.
    attributes.add(checkpoint_key.partition('/')[0])
  return attributes
class SidecarEvaluator(object):
"""A class designed for a dedicated evaluator task.
@ -148,6 +170,21 @@ class SidecarEvaluator(object):
# `expect_partial` because the checkpoint can have other `Trackable`s
# such as `optimizer`.
checkpoint.restore(latest_checkpoint).expect_partial()
checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
# The checkpoint should contain model and optimizer for SidecarEvaluator
# to work. But the model weights saved by the ModelCheckpoint callback do
# not contain model as an attribute. To keep SidecarEvaluator compatible
# with this case, if the model attribute is not found but a
# layer_with_weights attribute is found, use model.load_weights to load
# the model's weights, while self._iterations is still restored from the
# checkpoint variable.
if 'model' not in checkpoint_attributes:
for attribute in checkpoint_attributes:
# check whether the checkpoint has the required attributes for
# model.load_weights to work.
if re.match(r'^layer_with_weights-[\d+]', attribute) is not None:
self.model.load_weights(latest_checkpoint)
break
except (errors_impl.OpError,) as e:
# A couple errors can happen here with the coordinator racing to write
# checkpoint:

View File

@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function
import os
import unittest
from absl import logging
import numpy as np
@ -36,6 +35,8 @@ from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking_util
_BATCH_SIZE = 32
class SidecarEvaluatorTest(test.TestCase):
@ -130,7 +131,6 @@ class SidecarEvaluatorTest(test.TestCase):
self.assertSummaryEventsWritten(log_dir)
@unittest.skip('b/172976255')
def testSidecarEvaluatorOutputsSummarySavedWithCallback(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints')
log_dir = os.path.join(self.get_temp_dir(), 'summary')
@ -139,7 +139,7 @@ class SidecarEvaluatorTest(test.TestCase):
data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))
dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)
dataset = dataset.batch(_BATCH_SIZE)
save_callback = keras.callbacks.ModelCheckpoint(
filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
save_weights_only=True)
@ -152,17 +152,22 @@ class SidecarEvaluatorTest(test.TestCase):
# Create a new model used for evaluation.
eval_model = self.createTestModel(compile_model=True)
# Have a sidecar_evaluator evaluate once.
sidecar_evaluator_lib.SidecarEvaluator(
sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
eval_model,
data=dataset,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
max_evaluations=1).start()
max_evaluations=1)
sidecar_evaluator.start()
# Eval model has been restored to the same state as the original model, so
# their weights should match. If not, restoration of the model didn't
# work.
self.assertModelsSameVariables(model, eval_model)
# Check that the iteration count is restored.
self.assertEqual(sidecar_evaluator._iterations.numpy(), _BATCH_SIZE)
self.assertSummaryEventsWritten(log_dir)