Add finalizer callback to ExportStrategy
Change: 144897533
This commit is contained in:
parent
d2eb7da54f
commit
d8fd68242e
@ -355,48 +355,51 @@ class Experiment(object):
|
||||
previous_path = None
|
||||
eval_result = None
|
||||
last_warning_time = 0
|
||||
while (not self.continuous_eval_predicate_fn or
|
||||
self.continuous_eval_predicate_fn(eval_result)):
|
||||
start = time.time()
|
||||
try:
|
||||
while (not self.continuous_eval_predicate_fn or
|
||||
self.continuous_eval_predicate_fn(eval_result)):
|
||||
start = time.time()
|
||||
|
||||
error_msg = None
|
||||
latest_path = saver.latest_checkpoint(self._estimator.model_dir)
|
||||
if not latest_path:
|
||||
error_msg = ("Estimator is not fitted yet. "
|
||||
"Will start an evaluation when a checkpoint is ready.")
|
||||
elif evaluate_checkpoint_only_once and latest_path == previous_path:
|
||||
error_msg = "No new checkpoint ready for evaluation."
|
||||
error_msg = None
|
||||
latest_path = saver.latest_checkpoint(self._estimator.model_dir)
|
||||
if not latest_path:
|
||||
error_msg = ("Estimator is not fitted yet. "
|
||||
"Will start an evaluation when a checkpoint is ready.")
|
||||
elif evaluate_checkpoint_only_once and latest_path == previous_path:
|
||||
error_msg = "No new checkpoint ready for evaluation."
|
||||
|
||||
if error_msg:
|
||||
# Print warning message every 10 mins.
|
||||
eval_result = {}
|
||||
if time.time() - last_warning_time > 600:
|
||||
logging.warning(error_msg)
|
||||
last_warning_time = time.time()
|
||||
else:
|
||||
eval_result = self._estimator.evaluate(input_fn=input_fn,
|
||||
steps=self._eval_steps,
|
||||
metrics=self._eval_metrics,
|
||||
name=name,
|
||||
checkpoint_path=latest_path,
|
||||
hooks=self._eval_hooks)
|
||||
# Ensure eval result is not None for next round of evaluation.
|
||||
if not eval_result:
|
||||
if error_msg:
|
||||
# Print warning message every 10 mins.
|
||||
eval_result = {}
|
||||
if time.time() - last_warning_time > 600:
|
||||
logging.warning(error_msg)
|
||||
last_warning_time = time.time()
|
||||
else:
|
||||
eval_result = self._estimator.evaluate(input_fn=input_fn,
|
||||
steps=self._eval_steps,
|
||||
metrics=self._eval_metrics,
|
||||
name=name,
|
||||
checkpoint_path=latest_path,
|
||||
hooks=self._eval_hooks)
|
||||
# Ensure eval result is not None for next round of evaluation.
|
||||
if not eval_result:
|
||||
eval_result = {}
|
||||
|
||||
# TODO(soergel): further throttle how often export happens?
|
||||
self._maybe_export(eval_result)
|
||||
# TODO(soergel): further throttle how often export happens?
|
||||
self._maybe_export(eval_result)
|
||||
|
||||
# Clear warning timer and update last evaluated checkpoint
|
||||
last_warning_time = 0
|
||||
previous_path = latest_path
|
||||
# Clear warning timer and update last evaluated checkpoint
|
||||
last_warning_time = 0
|
||||
previous_path = latest_path
|
||||
|
||||
duration = time.time() - start
|
||||
if duration < throttle_delay_secs:
|
||||
difference = throttle_delay_secs - duration
|
||||
logging.info("Waiting %f secs before starting next eval run.",
|
||||
difference)
|
||||
time.sleep(difference)
|
||||
duration = time.time() - start
|
||||
if duration < throttle_delay_secs:
|
||||
difference = throttle_delay_secs - duration
|
||||
logging.info("Waiting %f secs before starting next eval run.",
|
||||
difference)
|
||||
time.sleep(difference)
|
||||
finally:
|
||||
self._finalize_exports()
|
||||
|
||||
def continuous_eval(self,
|
||||
delay_secs=None,
|
||||
@ -465,14 +468,18 @@ class Experiment(object):
|
||||
name=eval_dir_suffix,
|
||||
hooks=self._eval_hooks)
|
||||
export_results = self._maybe_export(eval_result)
|
||||
self._finalize_exports()
|
||||
return eval_result, export_results
|
||||
|
||||
def _maybe_export(self, eval_result): # pylint: disable=unused-argument
|
||||
"""Export the Estimator using export_fn, if defined."""
|
||||
export_dir_base = os.path.join(
|
||||
def _get_export_dir_base(self):
|
||||
return os.path.join(
|
||||
compat.as_bytes(self._estimator.model_dir),
|
||||
compat.as_bytes("export"))
|
||||
|
||||
def _maybe_export(self, eval_result): # pylint: disable=unused-argument
|
||||
"""Export the Estimator using ExportStrategies, if defined."""
|
||||
export_dir_base = self._get_export_dir_base()
|
||||
|
||||
export_results = []
|
||||
for strategy in self._export_strategies:
|
||||
# TODO(soergel): possibly, allow users to decide whether to export here
|
||||
@ -487,6 +494,17 @@ class Experiment(object):
|
||||
|
||||
return export_results
|
||||
|
||||
def _finalize_exports(self):
|
||||
"""Perform any final cleanup defined by the ExportStrategies."""
|
||||
export_dir_base = self._get_export_dir_base()
|
||||
|
||||
for strategy in self._export_strategies:
|
||||
if strategy.end_fn:
|
||||
strategy.end_fn(
|
||||
os.path.join(
|
||||
compat.as_bytes(export_dir_base),
|
||||
compat.as_bytes(strategy.name)))
|
||||
|
||||
def run_std_server(self):
|
||||
"""Starts a TensorFlow server and joins the serving thread.
|
||||
|
||||
|
@ -13,8 +13,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""ExportStrategy class that provides strategies to export model so later it
|
||||
can be used for TensorFlow serving."""
|
||||
"""ExportStrategy class represents different flavors of model export."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -26,8 +25,13 @@ __all__ = ['ExportStrategy']
|
||||
|
||||
|
||||
class ExportStrategy(collections.namedtuple('ExportStrategy',
|
||||
['name', 'export_fn'])):
|
||||
['name',
|
||||
'export_fn',
|
||||
'end_fn'])):
|
||||
|
||||
def export(self, estimator, export_path):
|
||||
return self.export_fn(estimator, export_path)
|
||||
|
||||
def end(self, export_path):
|
||||
return self.end_fn(export_path)
|
||||
|
||||
|
@ -223,6 +223,33 @@ def get_timestamped_export_dir(export_dir_base):
|
||||
return export_dir
|
||||
|
||||
|
||||
# create a simple parser that pulls the export_version from the directory.
|
||||
def _export_version_parser(path):
|
||||
filename = os.path.basename(path.path)
|
||||
if not (len(filename) == 10 and filename.isdigit()):
|
||||
return None
|
||||
return path._replace(export_version=int(filename))
|
||||
|
||||
|
||||
def get_most_recent_export(export_dir_base):
|
||||
"""Locate the most recent SavedModel export in a directory of many exports.
|
||||
|
||||
This method assumes that SavedModel subdirectories are named as a timestamp
|
||||
(seconds from epoch), as produced by get_timestamped_export_dir().
|
||||
|
||||
Args:
|
||||
export_dir_base: A base directory containing multiple timestamped
|
||||
directories.
|
||||
|
||||
Returns:
|
||||
A gc.Path, whith is just a namedtuple of (path, export_version).
|
||||
"""
|
||||
select_filter = gc.largest_export_versions(1)
|
||||
results = select_filter(gc.get_paths(export_dir_base,
|
||||
parser=_export_version_parser))
|
||||
return next(iter(results or []), None)
|
||||
|
||||
|
||||
def garbage_collect_exports(export_dir_base, exports_to_keep):
|
||||
"""Deletes older exports, retaining only a given number of the most recent.
|
||||
|
||||
@ -239,15 +266,8 @@ def garbage_collect_exports(export_dir_base, exports_to_keep):
|
||||
|
||||
keep_filter = gc.largest_export_versions(exports_to_keep)
|
||||
delete_filter = gc.negation(keep_filter)
|
||||
|
||||
# create a simple parser that pulls the export_version from the directory.
|
||||
def parser(path):
|
||||
filename = os.path.basename(path.path)
|
||||
if not (len(filename) == 10 and filename.isdigit()):
|
||||
return None
|
||||
return path._replace(export_version=int(filename))
|
||||
|
||||
for p in delete_filter(gc.get_paths(export_dir_base, parser=parser)):
|
||||
for p in delete_filter(gc.get_paths(export_dir_base,
|
||||
parser=_export_version_parser)):
|
||||
gfile.DeleteRecursively(p.path)
|
||||
|
||||
|
||||
@ -255,11 +275,12 @@ def make_export_strategy(export_input_fn,
|
||||
default_output_alternative_key='default',
|
||||
assets_extra=None,
|
||||
as_text=False,
|
||||
exports_to_keep=5):
|
||||
exports_to_keep=5,
|
||||
end_fn=None):
|
||||
"""Create an ExportStrategy for use with Experiment.
|
||||
|
||||
Args:
|
||||
export_input_fn: A function that takes no argument and returns an
|
||||
export_input_fn: A function that takes no arguments and returns an
|
||||
`InputFnOps`.
|
||||
default_output_alternative_key: the name of the head to serve when an
|
||||
incoming serving request does not explicitly request a specific head.
|
||||
@ -275,6 +296,10 @@ def make_export_strategy(export_input_fn,
|
||||
exports_to_keep: Number of exports to keep. Older exports will be
|
||||
garbage-collected. Defaults to 5. Set to None to disable garbage
|
||||
collection.
|
||||
end_fn: A function to be run at the end of training, taking a single
|
||||
argument naming the ExportStrategy-specific export directory. This is
|
||||
typically used to take some action regarding the most recent export, such
|
||||
as copying it to another location.
|
||||
|
||||
Returns:
|
||||
an ExportStrategy that can be passed to the Experiment constructor.
|
||||
@ -301,4 +326,4 @@ def make_export_strategy(export_input_fn,
|
||||
garbage_collect_exports(export_dir_base, exports_to_keep)
|
||||
return export_result
|
||||
|
||||
return export_strategy.ExportStrategy('Servo', export_fn)
|
||||
return export_strategy.ExportStrategy('Servo', export_fn, end_fn)
|
||||
|
@ -281,6 +281,22 @@ class SavedModelExportUtilsTest(test.TestCase):
|
||||
self.assertTrue(gfile.Exists(export_dir_3))
|
||||
self.assertTrue(gfile.Exists(export_dir_4))
|
||||
|
||||
def test_get_most_recent_export(self):
|
||||
export_dir_base = tempfile.mkdtemp() + "export/"
|
||||
gfile.MkDir(export_dir_base)
|
||||
_create_test_export_dir(export_dir_base)
|
||||
_create_test_export_dir(export_dir_base)
|
||||
_create_test_export_dir(export_dir_base)
|
||||
export_dir_4 = _create_test_export_dir(export_dir_base)
|
||||
|
||||
(most_recent_export_dir, most_recent_export_version) = (
|
||||
saved_model_export_utils.get_most_recent_export(export_dir_base))
|
||||
|
||||
self.assertEqual(export_dir_4, most_recent_export_dir)
|
||||
self.assertEqual(export_dir_4,
|
||||
os.path.join(export_dir_base,
|
||||
str(most_recent_export_version)))
|
||||
|
||||
def test_make_export_strategy(self):
|
||||
"""Only tests that an ExportStrategy instance is created."""
|
||||
def _export_input_fn():
|
||||
|
Loading…
x
Reference in New Issue
Block a user