Add support to SidecarEvaluator for loading checkpoints saved by the ModelCheckpoint callback.
PiperOrigin-RevId: 348071878 Change-Id: Id9ae47394adecb497d60a2b902da288265a4b400
parent 48d3d6eb47
commit 1fbddee1b1
tensorflow/python/keras/distribute/sidecar_evaluator.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import re
+
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -32,6 +34,26 @@ _PRINT_EVAL_STEP_EVERY_SEC = 60.0
 _ITERATIONS_UNINITIALIZED = -1
 
 
+def list_checkpoint_attributes(ckpt_dir_or_file):
+  """Lists all the attributes in a checkpoint.
+
+  Checkpoint keys are paths in a checkpoint graph, and the attribute is the
+  first element in the path. For example, with the checkpoint key
+  "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE", "optimizer" is the attribute.
+  The attribute is also what is used to save/restore a variable in a
+  checkpoint, e.g. tf.train.Checkpoint(optimizer=optimizer, model=model).
+
+  Args:
+    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
+
+  Returns:
+    Set of attributes in a checkpoint.
+  """
+  reader = checkpoint_utils.load_checkpoint(ckpt_dir_or_file)
+  variable_map = reader.get_variable_to_shape_map()
+  return {name.split('/')[0] for name in variable_map.keys()}
+
+
 class SidecarEvaluator(object):
   """A class designed for a dedicated evaluator task.
 
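For context (not part of the diff), here is a minimal sketch of what this helper returns for the two checkpoint layouts the change distinguishes. It uses the public tf.train API instead of the internal checkpoint_utils module; the paths are made up and the printed sets are illustrative, not exact:

import tensorflow as tf

def attributes_of(ckpt_path):
  # Same logic as list_checkpoint_attributes above, via the public API.
  reader = tf.train.load_checkpoint(ckpt_path)
  return {name.split('/')[0] for name in reader.get_variable_to_shape_map()}

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')

# Layout 1: a full tf.train.Checkpoint has a 'model' attribute.
tf.train.Checkpoint(model=model, optimizer=model.optimizer).save('/tmp/full/ckpt')
print(attributes_of('/tmp/full/ckpt-1'))
# e.g. {'model', 'optimizer', 'save_counter', '_CHECKPOINTABLE_OBJECT_GRAPH'}

# Layout 2: weights written by ModelCheckpoint(save_weights_only=True) (or by
# model.save_weights) expose 'layer_with_weights-N' attributes, not 'model'.
model.save_weights('/tmp/weights/ckpt')
print(attributes_of('/tmp/weights/ckpt'))
# e.g. {'layer_with_weights-0', 'save_counter', '_CHECKPOINTABLE_OBJECT_GRAPH'}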
@@ -148,6 +170,21 @@ class SidecarEvaluator(object):
         # `expect_partial` because the checkpoint can have other `Trackable`s
         # such as `optimizer`.
         checkpoint.restore(latest_checkpoint).expect_partial()
+        checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
+        # The checkpoint should contain model and optimizer for
+        # SidecarEvaluator to work. But the model weights saved by the
+        # ModelCheckpoint callback do not contain a model as an attribute. To
+        # make SidecarEvaluator compatible with this case, if the model
+        # attribute is not found but a layer_with_weights attribute is, use
+        # model.load_weights to load the model's weights, while
+        # self._iterations is still restored from the checkpoint variable.
+        if 'model' not in checkpoint_attributes:
+          for attribute in checkpoint_attributes:
+            # Check whether the checkpoint has the required attributes for
+            # model.load_weights to work.
+            if re.match(r'^layer_with_weights-\d+', attribute) is not None:
+              self.model.load_weights(latest_checkpoint)
+              break
       except (errors_impl.OpError,) as e:
         # A couple errors can happen here with the coordinator racing to write
         # checkpoint:
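Read on its own, the restore path above amounts to the following standalone sketch (the helper name is hypothetical and the public tf.train API is assumed; this is not code from the commit):

import re
import tensorflow as tf

def restore_for_evaluation(model, checkpoint, latest_checkpoint):
  # `expect_partial` because the checkpoint may hold other Trackables.
  checkpoint.restore(latest_checkpoint).expect_partial()
  reader = tf.train.load_checkpoint(latest_checkpoint)
  attributes = {
      name.split('/')[0] for name in reader.get_variable_to_shape_map()
  }
  if 'model' not in attributes:
    # Weights-only checkpoint: layer variables live under
    # 'layer_with_weights-N' instead of under a 'model' attribute, so fall
    # back to Keras weight loading.
    if any(re.match(r'^layer_with_weights-\d+', a) for a in attributes):
      model.load_weights(latest_checkpoint)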
tensorflow/python/keras/distribute/sidecar_evaluator_test.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import unittest
 
 from absl import logging
 import numpy as np
@@ -36,6 +35,8 @@ from tensorflow.python.summary import summary_iterator
 from tensorflow.python.training import checkpoint_management
 from tensorflow.python.training.tracking import util as tracking_util
 
+_BATCH_SIZE = 32
+
 
 class SidecarEvaluatorTest(test.TestCase):
 
@@ -130,7 +131,6 @@ class SidecarEvaluatorTest(test.TestCase):
 
     self.assertSummaryEventsWritten(log_dir)
 
-  @unittest.skip('b/172976255')
   def testSidecarEvaluatorOutputsSummarySavedWithCallback(self):
     checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints')
     log_dir = os.path.join(self.get_temp_dir(), 'summary')
@@ -139,7 +139,7 @@
     data = np.random.random((1000, 32))
     labels = np.random.random((1000, 10))
     dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
-    dataset = dataset.batch(32)
+    dataset = dataset.batch(_BATCH_SIZE)
     save_callback = keras.callbacks.ModelCheckpoint(
         filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
         save_weights_only=True)
@@ -152,17 +152,22 @@
     # Create a new model used for evaluation.
     eval_model = self.createTestModel(compile_model=True)
     # Have a sidecar_evaluator evaluate once.
-    sidecar_evaluator_lib.SidecarEvaluator(
+    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
         eval_model,
         data=dataset,
         checkpoint_dir=checkpoint_dir,
         log_dir=log_dir,
-        max_evaluations=1).start()
+        max_evaluations=1)
+    sidecar_evaluator.start()
 
     # The eval model has been restored to the same state as the original
     # model, so their weights should match. If not, restoration of the model
     # didn't work.
     self.assertModelsSameVariables(model, eval_model)
 
+    # Check that the iterations variable is restored.
+    self.assertEqual(sidecar_evaluator._iterations.numpy(), _BATCH_SIZE)
+
     self.assertSummaryEventsWritten(log_dir)
 
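With this change, the end-to-end flow the test exercises looks roughly like the sketch below: a trainer process saves weights-only checkpoints through ModelCheckpoint, and a separate evaluator process consumes them with SidecarEvaluator. The paths and toy model are made up; the constructor arguments and the internal import path mirror what the test uses:

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.distribute import sidecar_evaluator as sidecar_evaluator_lib

checkpoint_dir = '/tmp/checkpoints'
log_dir = '/tmp/summary'

def build_model():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(32,))])
  model.compile(optimizer='adam', loss='mse', metrics=['mse'])
  return model

dataset = tf.data.Dataset.from_tensor_slices(
    (np.random.random((1000, 32)), np.random.random((1000, 10)))).batch(32)

# Trainer process: save weights-only checkpoints each epoch.
model = build_model()
model.fit(
    dataset,
    epochs=1,
    callbacks=[
        tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_dir + '/ckpt-{epoch}',
            save_weights_only=True)
    ])

# Evaluator process: these checkpoints lack a 'model' attribute, so
# SidecarEvaluator now falls back to model.load_weights internally.
eval_model = build_model()
sidecar_evaluator_lib.SidecarEvaluator(
    eval_model,
    data=dataset,
    checkpoint_dir=checkpoint_dir,
    log_dir=log_dir,
    max_evaluations=1).start()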