Add finalizer callback to ExportStrategy

Change: 144897533
This commit is contained in:
David Soergel 2017-01-18 16:25:13 -08:00 committed by TensorFlower Gardener
parent d2eb7da54f
commit d8fd68242e
4 changed files with 117 additions and 54 deletions

View File

@ -355,48 +355,51 @@ class Experiment(object):
previous_path = None
eval_result = None
last_warning_time = 0
while (not self.continuous_eval_predicate_fn or
self.continuous_eval_predicate_fn(eval_result)):
start = time.time()
try:
while (not self.continuous_eval_predicate_fn or
self.continuous_eval_predicate_fn(eval_result)):
start = time.time()
error_msg = None
latest_path = saver.latest_checkpoint(self._estimator.model_dir)
if not latest_path:
error_msg = ("Estimator is not fitted yet. "
"Will start an evaluation when a checkpoint is ready.")
elif evaluate_checkpoint_only_once and latest_path == previous_path:
error_msg = "No new checkpoint ready for evaluation."
error_msg = None
latest_path = saver.latest_checkpoint(self._estimator.model_dir)
if not latest_path:
error_msg = ("Estimator is not fitted yet. "
"Will start an evaluation when a checkpoint is ready.")
elif evaluate_checkpoint_only_once and latest_path == previous_path:
error_msg = "No new checkpoint ready for evaluation."
if error_msg:
# Print warning message every 10 mins.
eval_result = {}
if time.time() - last_warning_time > 600:
logging.warning(error_msg)
last_warning_time = time.time()
else:
eval_result = self._estimator.evaluate(input_fn=input_fn,
steps=self._eval_steps,
metrics=self._eval_metrics,
name=name,
checkpoint_path=latest_path,
hooks=self._eval_hooks)
# Ensure eval result is not None for next round of evaluation.
if not eval_result:
if error_msg:
# Print warning message every 10 mins.
eval_result = {}
if time.time() - last_warning_time > 600:
logging.warning(error_msg)
last_warning_time = time.time()
else:
eval_result = self._estimator.evaluate(input_fn=input_fn,
steps=self._eval_steps,
metrics=self._eval_metrics,
name=name,
checkpoint_path=latest_path,
hooks=self._eval_hooks)
# Ensure eval result is not None for next round of evaluation.
if not eval_result:
eval_result = {}
# TODO(soergel): further throttle how often export happens?
self._maybe_export(eval_result)
# TODO(soergel): further throttle how often export happens?
self._maybe_export(eval_result)
# Clear warning timer and update last evaluated checkpoint
last_warning_time = 0
previous_path = latest_path
# Clear warning timer and update last evaluated checkpoint
last_warning_time = 0
previous_path = latest_path
duration = time.time() - start
if duration < throttle_delay_secs:
difference = throttle_delay_secs - duration
logging.info("Waiting %f secs before starting next eval run.",
difference)
time.sleep(difference)
duration = time.time() - start
if duration < throttle_delay_secs:
difference = throttle_delay_secs - duration
logging.info("Waiting %f secs before starting next eval run.",
difference)
time.sleep(difference)
finally:
self._finalize_exports()
def continuous_eval(self,
delay_secs=None,
@ -465,14 +468,18 @@ class Experiment(object):
name=eval_dir_suffix,
hooks=self._eval_hooks)
export_results = self._maybe_export(eval_result)
self._finalize_exports()
return eval_result, export_results
def _maybe_export(self, eval_result): # pylint: disable=unused-argument
"""Export the Estimator using export_fn, if defined."""
export_dir_base = os.path.join(
def _get_export_dir_base(self):
  """Returns the base directory under which exports are placed.

  The result is the Estimator's `model_dir` joined with an "export"
  component; both pieces are passed through `compat.as_bytes` before joining.
  """
  model_dir = compat.as_bytes(self._estimator.model_dir)
  export_subdir = compat.as_bytes("export")
  return os.path.join(model_dir, export_subdir)
def _maybe_export(self, eval_result): # pylint: disable=unused-argument
"""Export the Estimator using ExportStrategies, if defined."""
export_dir_base = self._get_export_dir_base()
export_results = []
for strategy in self._export_strategies:
# TODO(soergel): possibly, allow users to decide whether to export here
@ -487,6 +494,17 @@ class Experiment(object):
return export_results
def _finalize_exports(self):
  """Runs the `end_fn` of every ExportStrategy that defines one.

  Each callback is handed the strategy-specific export directory, i.e. the
  export base directory joined with the strategy's name (both as bytes).
  Strategies whose `end_fn` is falsy are skipped.
  """
  base_dir = self._get_export_dir_base()
  for strategy in self._export_strategies:
    finalizer = strategy.end_fn
    if not finalizer:
      # This strategy declared no final cleanup step.
      continue
    strategy_dir = os.path.join(
        compat.as_bytes(base_dir),
        compat.as_bytes(strategy.name))
    finalizer(strategy_dir)
def run_std_server(self):
"""Starts a TensorFlow server and joins the serving thread.

View File

@ -13,8 +13,7 @@
# limitations under the License.
# ==============================================================================
"""ExportStrategy class that provides strategies to export model so later it
can be used for TensorFlow serving."""
"""ExportStrategy class represents different flavors of model export."""
from __future__ import absolute_import
from __future__ import division
@ -26,8 +25,13 @@ __all__ = ['ExportStrategy']
class ExportStrategy(collections.namedtuple('ExportStrategy',
['name', 'export_fn'])):
['name',
'export_fn',
'end_fn'])):
def export(self, estimator, export_path):
  """Runs this strategy's `export_fn` and returns its result.

  Args:
    estimator: passed unchanged as the first argument of `export_fn`.
    export_path: passed unchanged as the second argument of `export_fn`.

  Returns:
    Whatever `export_fn` returns.
  """
  export_result = self.export_fn(estimator, export_path)
  return export_result
def end(self, export_path):
  """Runs this strategy's optional `end_fn` finalizer.

  Args:
    export_path: the strategy-specific export directory; passed unchanged to
      `end_fn`.

  Returns:
    Whatever `end_fn` returns, or None when this strategy has no `end_fn`.
  """
  # `end_fn` is optional and may be None; calling it unconditionally would
  # raise TypeError for strategies that define no finalizer.
  if not self.end_fn:
    return None
  return self.end_fn(export_path)

View File

@ -223,6 +223,33 @@ def get_timestamped_export_dir(export_dir_base):
return export_dir
# create a simple parser that pulls the export_version from the directory.
def _export_version_parser(path):
  """Fills in `export_version` from a timestamp-named path.

  Args:
    path: a namedtuple-style object with `path` and `export_version` fields
      that supports `_replace`.

  Returns:
    A copy of `path` whose `export_version` is the integer value of the final
    path component, or None when that component is not exactly ten digits.
  """
  basename = os.path.basename(path.path)
  looks_like_timestamp = len(basename) == 10 and basename.isdigit()
  if looks_like_timestamp:
    return path._replace(export_version=int(basename))
  return None
def get_most_recent_export(export_dir_base):
  """Finds the newest SavedModel export inside a directory of exports.

  Export subdirectories are assumed to be named with a timestamp (seconds
  since the epoch), as produced by get_timestamped_export_dir().

  Args:
    export_dir_base: A base directory containing multiple timestamped
      directories.

  Returns:
    A gc.Path, which is just a namedtuple of (path, export_version), for the
    export with the largest version, or None if the selection is empty.
  """
  keep_newest = gc.largest_export_versions(1)
  candidates = gc.get_paths(export_dir_base, parser=_export_version_parser)
  newest = keep_newest(candidates)
  # At most one entry is selected; hand back the first one if present.
  for export in newest or ():
    return export
  return None
def garbage_collect_exports(export_dir_base, exports_to_keep):
"""Deletes older exports, retaining only a given number of the most recent.
@ -239,15 +266,8 @@ def garbage_collect_exports(export_dir_base, exports_to_keep):
keep_filter = gc.largest_export_versions(exports_to_keep)
delete_filter = gc.negation(keep_filter)
# create a simple parser that pulls the export_version from the directory.
def parser(path):
filename = os.path.basename(path.path)
if not (len(filename) == 10 and filename.isdigit()):
return None
return path._replace(export_version=int(filename))
for p in delete_filter(gc.get_paths(export_dir_base, parser=parser)):
for p in delete_filter(gc.get_paths(export_dir_base,
parser=_export_version_parser)):
gfile.DeleteRecursively(p.path)
@ -255,11 +275,12 @@ def make_export_strategy(export_input_fn,
default_output_alternative_key='default',
assets_extra=None,
as_text=False,
exports_to_keep=5):
exports_to_keep=5,
end_fn=None):
"""Create an ExportStrategy for use with Experiment.
Args:
export_input_fn: A function that takes no argument and returns an
export_input_fn: A function that takes no arguments and returns an
`InputFnOps`.
default_output_alternative_key: the name of the head to serve when an
incoming serving request does not explicitly request a specific head.
@ -275,6 +296,10 @@ def make_export_strategy(export_input_fn,
exports_to_keep: Number of exports to keep. Older exports will be
garbage-collected. Defaults to 5. Set to None to disable garbage
collection.
end_fn: A function to be run at the end of training, taking a single
argument naming the ExportStrategy-specific export directory. This is
typically used to take some action regarding the most recent export, such
as copying it to another location.
Returns:
an ExportStrategy that can be passed to the Experiment constructor.
@ -301,4 +326,4 @@ def make_export_strategy(export_input_fn,
garbage_collect_exports(export_dir_base, exports_to_keep)
return export_result
return export_strategy.ExportStrategy('Servo', export_fn)
return export_strategy.ExportStrategy('Servo', export_fn, end_fn)

View File

@ -281,6 +281,22 @@ class SavedModelExportUtilsTest(test.TestCase):
self.assertTrue(gfile.Exists(export_dir_3))
self.assertTrue(gfile.Exists(export_dir_4))
def test_get_most_recent_export(self):
  """Checks that get_most_recent_export reports the last export created."""
  base_dir = tempfile.mkdtemp() + "export/"
  gfile.MkDir(base_dir)

  # Three earlier exports, followed by the one expected to be reported.
  for _ in range(3):
    _create_test_export_dir(base_dir)
  newest_dir = _create_test_export_dir(base_dir)

  found = saved_model_export_utils.get_most_recent_export(base_dir)
  found_dir, found_version = found

  self.assertEqual(newest_dir, found_dir)
  # The reported version must also map back to the same directory name.
  expected_dir = os.path.join(base_dir, str(found_version))
  self.assertEqual(newest_dir, expected_dir)
def test_make_export_strategy(self):
"""Only tests that an ExportStrategy instance is created."""
def _export_input_fn():