Move 'make_export_strategy' utility function from .contrib to .core.
That function depends on the utilities related to garbage collection. They are moved too, but are kept private. PiperOrigin-RevId: 169927321
This commit is contained in:
parent
dbaf176e1f
commit
01620694d8
@ -8,25 +8,6 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_library(
|
||||
name = "estimator_py",
|
||||
srcs = ["estimator_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dnn",
|
||||
":dnn_linear_combined",
|
||||
":estimator",
|
||||
":export",
|
||||
":inputs",
|
||||
":linear",
|
||||
":model_fn",
|
||||
":parsing_utils",
|
||||
":run_config",
|
||||
":training",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
@ -39,6 +20,77 @@ filegroup(
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "estimator_py",
|
||||
srcs = ["estimator_lib.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dnn",
|
||||
":dnn_linear_combined",
|
||||
":estimator",
|
||||
":export",
|
||||
":export_strategy",
|
||||
":inputs",
|
||||
":linear",
|
||||
":model_fn",
|
||||
":parsing_utils",
|
||||
":run_config",
|
||||
":training",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "export_strategy",
|
||||
srcs = ["export_strategy.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":gc",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "export_strategy_test",
|
||||
size = "small",
|
||||
srcs = ["export_strategy_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":export_strategy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gc",
|
||||
srcs = ["gc.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "gc_test",
|
||||
size = "small",
|
||||
srcs = ["gc_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":gc",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model_fn",
|
||||
srcs = ["model_fn.py"],
|
||||
|
@ -19,10 +19,15 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
||||
from tensorflow.python.estimator import gc
|
||||
from tensorflow.python.estimator import util
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
__all__ = ['ExportStrategy']
|
||||
__all__ = ['ExportStrategy', 'make_export_strategy']
|
||||
|
||||
|
||||
class ExportStrategy(
|
||||
@ -81,3 +86,89 @@ class ExportStrategy(
|
||||
kwargs['eval_result'] = eval_result
|
||||
|
||||
return self.export_fn(estimator, export_path, **kwargs)
|
||||
|
||||
|
||||
def make_export_strategy(serving_input_fn,
                         assets_extra=None,
                         as_text=False,
                         exports_to_keep=5):
  """Create an ExportStrategy for use with tf.estimator.EvalSpec.

  Args:
    serving_input_fn: a function that takes no arguments and returns an
      `ServingInputReceiver`.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel.  Each key should give the destination
      path (including the filename) relative to the assets.extra directory.
      The corresponding value gives the full path of the source file to be
      copied.  For example, the simple case of copying a single file without
      renaming it is specified as
      `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
    as_text: whether to write the SavedModel proto in text format.
    exports_to_keep: Number of exports to keep.  Older exports will be
      garbage-collected.  Defaults to 5.  Set to None to disable garbage
      collection.

  Returns:
    An `ExportStrategy` named 'Servo' whose export function writes a
    SavedModel and then garbage-collects older exports.
  """

  def export_fn(estimator, export_dir_base, checkpoint_path=None):
    """Exports the given Estimator as a SavedModel.

    Any exception raised by `estimator.export_savedmodel` is not caught here
    and propagates to the caller; in that case no garbage collection runs.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the most recent checkpoint found within the model directory is chosen.

    Returns:
      The value returned by `estimator.export_savedmodel` (the string path to
      the exported directory).
    """
    export_result = estimator.export_savedmodel(
        export_dir_base,
        serving_input_fn,
        assets_extra=assets_extra,
        as_text=as_text,
        checkpoint_path=checkpoint_path)

    # Prune only after the new export has been written.
    _garbage_collect_exports(export_dir_base, exports_to_keep)
    return export_result

  return ExportStrategy('Servo', export_fn)
|
||||
|
||||
|
||||
def _garbage_collect_exports(export_dir_base, exports_to_keep):
|
||||
"""Deletes older exports, retaining only a given number of the most recent.
|
||||
|
||||
Export subdirectories are assumed to be named with monotonically increasing
|
||||
integers; the most recent are taken to be those with the largest values.
|
||||
|
||||
Args:
|
||||
export_dir_base: the base directory under which each export is in a
|
||||
versioned subdirectory.
|
||||
exports_to_keep: the number of recent exports to retain.
|
||||
"""
|
||||
if exports_to_keep is None:
|
||||
return
|
||||
|
||||
def _export_version_parser(path):
|
||||
# create a simple parser that pulls the export_version from the directory.
|
||||
filename = os.path.basename(path.path)
|
||||
if not (len(filename) == 10 and filename.isdigit()):
|
||||
return None
|
||||
return path._replace(export_version=int(filename))
|
||||
|
||||
keep_filter = gc._largest_export_versions(exports_to_keep)
|
||||
delete_filter = gc._negation(keep_filter)
|
||||
for p in delete_filter(
|
||||
gc._get_paths(export_dir_base, parser=_export_version_parser)):
|
||||
try:
|
||||
gfile.DeleteRecursively(p.path)
|
||||
except errors_impl.NotFoundError as e:
|
||||
tf_logging.warn('Can not delete %s recursively: %s', p.path, e)
|
||||
|
261
tensorflow/python/estimator/export_strategy_test.py
Normal file
261
tensorflow/python/estimator/export_strategy_test.py
Normal file
@ -0,0 +1,261 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `make_export_strategy`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from tensorflow.python.estimator import estimator as estimator_lib
|
||||
from tensorflow.python.estimator import export_strategy as export_strategy_lib
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class ExportStrategyTest(test.TestCase):
  """Checks which optional arguments `ExportStrategy.export` forwards."""

  def _new_strategy(self, export_fn):
    """Wraps `export_fn` in an `ExportStrategy` named "test"."""
    return export_strategy_lib.ExportStrategy(
        name="test", export_fn=export_fn)

  def testAcceptsNameAndFn(self):
    def export_fn(estimator, export_path):
      del estimator, export_path

    strategy = self._new_strategy(export_fn)

    self.assertEqual("test", strategy.name)
    self.assertEqual(export_fn, strategy.export_fn)

  def testCallsExportFnThatDoesntKnowExtraArguments(self):
    expected_estimator = {}

    def export_fn(estimator, export_path):
      self.assertEqual(expected_estimator, estimator)
      self.assertEqual("expected_path", export_path)

    strategy = self._new_strategy(export_fn)

    # The first call passes no extras.  The remaining calls pass arguments
    # that `export_fn` doesn't support; the lack of support is detected and
    # the arguments aren't passed.
    extras_per_call = (
        {},
        {"checkpoint_path": "unexpected_checkpoint_path"},
        {"eval_result": ()},
        {"checkpoint_path": "unexpected_checkpoint_path", "eval_result": ()},
    )
    for extras in extras_per_call:
      strategy.export(
          estimator=expected_estimator, export_path="expected_path", **extras)

  def testCallsExportFnThatKnowsAboutCheckpointPathButItsNotGiven(self):
    expected_estimator = {}

    def export_fn(estimator, export_path, checkpoint_path):
      self.assertEqual(expected_estimator, estimator)
      self.assertEqual("expected_path", export_path)
      self.assertEqual(None, checkpoint_path)

    strategy = self._new_strategy(export_fn)

    for extras in ({}, {"eval_result": ()}):
      strategy.export(
          estimator=expected_estimator, export_path="expected_path", **extras)

  def testCallsExportFnWithCheckpointPath(self):
    expected_estimator = {}

    def export_fn(estimator, export_path, checkpoint_path):
      self.assertEqual(expected_estimator, estimator)
      self.assertEqual("expected_path", export_path)
      self.assertEqual("expected_checkpoint_path", checkpoint_path)

    strategy = self._new_strategy(export_fn)

    extras_per_call = (
        {"checkpoint_path": "expected_checkpoint_path"},
        {"checkpoint_path": "expected_checkpoint_path", "eval_result": ()},
    )
    for extras in extras_per_call:
      strategy.export(
          estimator=expected_estimator, export_path="expected_path", **extras)

  def testCallsExportFnThatKnowsAboutEvalResultButItsNotGiven(self):
    expected_estimator = {}

    def export_fn(estimator, export_path, checkpoint_path, eval_result):
      self.assertEqual(expected_estimator, estimator)
      self.assertEqual("expected_path", export_path)
      self.assertEqual(None, checkpoint_path)
      self.assertEqual(None, eval_result)

    strategy = self._new_strategy(export_fn)

    strategy.export(estimator=expected_estimator, export_path="expected_path")

  def testCallsExportFnThatAcceptsEvalResultButNotCheckpoint(self):
    expected_estimator = {}

    def export_fn(estimator, export_path, eval_result):
      del estimator, export_path, eval_result
      raise RuntimeError("Should raise ValueError before this.")

    strategy = self._new_strategy(export_fn)

    expected_error_message = (
        "An export_fn accepting eval_result must also accept checkpoint_path")

    # Every combination of optional arguments must be rejected the same way.
    extras_per_call = (
        {},
        {"checkpoint_path": "unexpected_checkpoint_path"},
        {"eval_result": ()},
        {"checkpoint_path": "unexpected_checkpoint_path", "eval_result": ()},
    )
    for extras in extras_per_call:
      with self.assertRaisesRegexp(ValueError, expected_error_message):
        strategy.export(
            estimator=expected_estimator,
            export_path="expected_path",
            **extras)

  def testCallsExportFnWithEvalResultAndCheckpointPath(self):
    expected_estimator = {}
    expected_eval_result = {}

    def export_fn(estimator, export_path, checkpoint_path, eval_result):
      self.assertEqual(expected_estimator, estimator)
      self.assertEqual("expected_path", export_path)
      self.assertEqual("expected_checkpoint_path", checkpoint_path)
      self.assertEqual(expected_eval_result, eval_result)

    strategy = self._new_strategy(export_fn)

    strategy.export(
        estimator=expected_estimator,
        export_path="expected_path",
        checkpoint_path="expected_checkpoint_path",
        eval_result=expected_eval_result)
|
||||
|
||||
|
||||
class MakeExportStrategyTest(test.TestCase):
  """Tests `make_export_strategy` and the garbage collection it triggers."""

  def test_make_export_strategy(self):
    """The factory returns an `ExportStrategy` instance."""

    def _serving_input_fn():
      return array_ops.constant([1]), None

    export_strategy = export_strategy_lib.make_export_strategy(
        serving_input_fn=_serving_input_fn,
        assets_extra={"from/path": "to/path"},
        as_text=False,
        exports_to_keep=5)
    self.assertIsInstance(export_strategy, export_strategy_lib.ExportStrategy)

  def test_garbage_collect_exports(self):
    """Exporting with exports_to_keep=2 leaves only the two newest dirs."""
    # Join with a path separator so the export directory is created inside
    # the fresh temporary directory rather than as a sibling of it.
    export_dir_base = os.path.join(tempfile.mkdtemp(), "export")
    gfile.MkDir(export_dir_base)
    export_dir_1 = _create_test_export_dir(export_dir_base)
    export_dir_2 = _create_test_export_dir(export_dir_base)
    export_dir_3 = _create_test_export_dir(export_dir_base)
    export_dir_4 = _create_test_export_dir(export_dir_base)

    self.assertTrue(gfile.Exists(export_dir_1))
    self.assertTrue(gfile.Exists(export_dir_2))
    self.assertTrue(gfile.Exists(export_dir_3))
    self.assertTrue(gfile.Exists(export_dir_4))

    def _serving_input_fn():
      return array_ops.constant([1]), None

    export_strategy = export_strategy_lib.make_export_strategy(
        _serving_input_fn, exports_to_keep=2)
    # A mock stands in for the estimator, so only garbage collection is
    # exercised here.
    estimator = test.mock.Mock(spec=estimator_lib.Estimator)
    # Garbage collect all but the most recent 2 exports,
    # where recency is determined based on the timestamp directory names.
    export_strategy.export(estimator, export_dir_base)

    self.assertFalse(gfile.Exists(export_dir_1))
    self.assertFalse(gfile.Exists(export_dir_2))
    self.assertTrue(gfile.Exists(export_dir_3))
    self.assertTrue(gfile.Exists(export_dir_4))
|
||||
|
||||
|
||||
def _create_test_export_dir(export_dir_base):
  """Creates a new timestamp-named directory under `export_dir_base`.

  Args:
    export_dir_base: directory in which the new export directory is made.

  Returns:
    The path of the directory that was created.
  """
  new_export_dir = _get_timestamped_export_dir(export_dir_base)
  gfile.MkDir(new_export_dir)
  # Directory names come from whole-second timestamps; presumably the pause
  # is here so that consecutive calls get distinct, increasing names.
  time.sleep(2)
  return new_export_dir
|
||||
|
||||
|
||||
def _get_timestamped_export_dir(export_dir_base):
  """Picks a not-yet-existing `<export_dir_base>/<unix seconds>` path.

  The directory itself is not created here.

  Args:
    export_dir_base: directory under which the timestamped path is built.

  Returns:
    The candidate path, as bytes.

  Raises:
    RuntimeError: if every attempt produced a path that already exists.
  """
  # When we create a timestamped directory, there is a small chance that the
  # directory already exists because another worker is also writing exports.
  # In this case we just wait one second to get a new timestamp and try again.
  # If this fails several times in a row, then something is seriously wrong.
  max_directory_creation_attempts = 10

  for attempt in range(1, max_directory_creation_attempts + 1):
    export_timestamp = int(time.time())
    candidate = os.path.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(str(export_timestamp)))
    if not gfile.Exists(candidate):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return candidate
    time.sleep(1)
    logging.warn("Export directory {} already exists; retrying (attempt {}/{})".
                 format(candidate, attempt, max_directory_creation_attempts))

  raise RuntimeError("Failed to obtain a unique export directory name after "
                     "{} attempts.".format(max_directory_creation_attempts))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
209
tensorflow/python/estimator/gc.py
Normal file
209
tensorflow/python/estimator/gc.py
Normal file
@ -0,0 +1,209 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
r"""System for specifying garbage collection (GC) of path based data.
|
||||
|
||||
This framework allows for GC of data specified by path names, for example files
|
||||
on disk. gc.Path objects each represent a single item stored at a path and may
|
||||
be a base directory,
|
||||
/tmp/exports/0/...
|
||||
/tmp/exports/1/...
|
||||
...
|
||||
or a fully qualified file,
|
||||
/tmp/train-1.ckpt
|
||||
/tmp/train-2.ckpt
|
||||
...
|
||||
|
||||
A gc filter function takes and returns a list of gc.Path items. Filter
|
||||
functions are responsible for selecting Path items for preservation or deletion.
|
||||
Note that functions should always return a sorted list.
|
||||
|
||||
For example,
|
||||
base_dir = "/tmp"
|
||||
# Create the directories.
|
||||
for e in xrange(10):
|
||||
os.mkdir("%s/%d" % (base_dir, e), 0o755)
|
||||
|
||||
# Create a simple parser that pulls the export_version from the directory.
|
||||
path_regex = "^" + re.escape(base_dir) + "/(\\d+)$"
|
||||
def parser(path):
|
||||
match = re.match(path_regex, path.path)
|
||||
if not match:
|
||||
return None
|
||||
return path._replace(export_version=int(match.group(1)))
|
||||
|
||||
path_list = gc._get_paths("/tmp", parser) # contains all ten Paths
|
||||
|
||||
every_fifth = gc._mod_export_version(5)
|
||||
print(every_fifth(path_list)) # shows ["/tmp/0", "/tmp/5"]
|
||||
|
||||
largest_three = gc._largest_export_versions(3)
|
||||
print(largest_three(path_list)) # shows ["/tmp/7", "/tmp/8", "/tmp/9"]
|
||||
|
||||
both = gc._union(every_fifth, largest_three)
|
||||
print(both(path_list)) # shows ["/tmp/0", "/tmp/5",
|
||||
# "/tmp/7", "/tmp/8", "/tmp/9"]
|
||||
# Delete everything not in 'both'.
|
||||
to_delete = gc._negation(both)
|
||||
for p in to_delete(path_list):
|
||||
gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2",
|
||||
# "/tmp/3", "/tmp/4", "/tmp/6",
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import heapq
|
||||
import math
|
||||
import os
|
||||
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
Path = collections.namedtuple('Path', 'path export_version')
|
||||
|
||||
|
||||
def _largest_export_versions(n):
|
||||
"""Creates a filter that keeps the largest n export versions.
|
||||
|
||||
Args:
|
||||
n: number of versions to keep.
|
||||
|
||||
Returns:
|
||||
A filter function that keeps the n largest paths.
|
||||
"""
|
||||
def keep(paths):
|
||||
heap = []
|
||||
for idx, path in enumerate(paths):
|
||||
if path.export_version is not None:
|
||||
heapq.heappush(heap, (path.export_version, idx))
|
||||
keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
|
||||
return sorted(keepers)
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
def _one_of_every_n_export_versions(n):
|
||||
"""Creates a filter that keeps one of every n export versions.
|
||||
|
||||
Args:
|
||||
n: interval size.
|
||||
|
||||
Returns:
|
||||
A filter function that keeps exactly one path from each interval
|
||||
[0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an
|
||||
interval the largest is kept.
|
||||
"""
|
||||
def keep(paths):
|
||||
"""A filter function that keeps exactly one out of every n paths."""
|
||||
|
||||
keeper_map = {} # map from interval to largest path seen in that interval
|
||||
for p in paths:
|
||||
if p.export_version is None:
|
||||
# Skip missing export_versions.
|
||||
continue
|
||||
# Find the interval (with a special case to map export_version = 0 to
|
||||
# interval 0.
|
||||
interval = math.floor(
|
||||
(p.export_version - 1) / n) if p.export_version else 0
|
||||
existing = keeper_map.get(interval, None)
|
||||
if (not existing) or (existing.export_version < p.export_version):
|
||||
keeper_map[interval] = p
|
||||
return sorted(keeper_map.values())
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
def _mod_export_version(n):
|
||||
"""Creates a filter that keeps every export that is a multiple of n.
|
||||
|
||||
Args:
|
||||
n: step size.
|
||||
|
||||
Returns:
|
||||
A filter function that keeps paths where export_version % n == 0.
|
||||
"""
|
||||
def keep(paths):
|
||||
keepers = []
|
||||
for p in paths:
|
||||
if p.export_version % n == 0:
|
||||
keepers.append(p)
|
||||
return sorted(keepers)
|
||||
return keep
|
||||
|
||||
|
||||
def _union(lf, rf):
|
||||
"""Creates a filter that keeps the union of two filters.
|
||||
|
||||
Args:
|
||||
lf: first filter
|
||||
rf: second filter
|
||||
|
||||
Returns:
|
||||
A filter function that keeps the n largest paths.
|
||||
"""
|
||||
def keep(paths):
|
||||
l = set(lf(paths))
|
||||
r = set(rf(paths))
|
||||
return sorted(list(l|r))
|
||||
return keep
|
||||
|
||||
|
||||
def _negation(f):
|
||||
"""Negate a filter.
|
||||
|
||||
Args:
|
||||
f: filter function to invert
|
||||
|
||||
Returns:
|
||||
A filter function that returns the negation of f.
|
||||
"""
|
||||
def keep(paths):
|
||||
l = set(paths)
|
||||
r = set(f(paths))
|
||||
return sorted(list(l-r))
|
||||
return keep
|
||||
|
||||
|
||||
def _get_paths(base_dir, parser):
  """Lists the entries of a directory as parsed `Path` objects.

  Args:
    base_dir: directory.
    parser: a function which gets the raw Path and can augment it with
      information such as the export_version, or ignore the path by returning
      None.  An example parser may extract the export version from a path
      such as "/tmp/exports/100" and another may extract from a full file
      name such as "/tmp/checkpoint-99.out".

  Returns:
    A sorted list of the Paths contained in the base directory with the
    parsing function applied; entries for which the parser returns a falsy
    value are dropped.
    By default the following fields are populated,
      - Path.path
    The parsing function is responsible for populating,
      - Path.export_version
  """
  entry_names = gfile.ListDirectory(base_dir)
  # Each entry is handed to the parser with export_version still unset.
  parsed_entries = (
      parser(Path(os.path.join(compat.as_str_any(base_dir),
                               compat.as_str_any(entry_name)),
                  None))
      for entry_name in entry_names)
  return sorted(entry for entry in parsed_entries if entry)
|
145
tensorflow/python/estimator/gc_test.py
Normal file
145
tensorflow/python/estimator/gc_test.py
Normal file
@ -0,0 +1,145 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for garbage collection utilities."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.estimator import gc
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
def _create_parser(base_dir):
  """Creates a parser that pulls the export_version from the directory name.

  Args:
    base_dir: directory whose immediate, all-digit children are treated as
      export versions; may be str, bytes or unicode.

  Returns:
    A function mapping a `gc.Path` to a copy with `export_version` set to the
    integer directory name, or to None when the path is not of the form
    `<base_dir>/<digits>`.
  """

  def _normalize(raw_path):
    text = compat.as_str_any(raw_path)
    # Use forward slashes on Windows so one pattern fits both platforms.
    return text.replace("\\", "/") if os.name == "nt" else text

  # Escape the base directory so that regex metacharacters in the path are
  # matched literally; compile once rather than on every parse.
  path_pattern = re.compile(
      "^" + re.escape(_normalize(base_dir)) + "/(\\d+)$")

  def parser(path):
    match = path_pattern.match(_normalize(path.path))
    if not match:
      return None
    return path._replace(export_version=int(match.group(1)))

  return parser
|
||||
|
||||
|
||||
class GcTest(test_util.TensorFlowTestCase):
  """Unit tests for the private path filters and directory listing in `gc`."""

  @staticmethod
  def _foo(*versions):
    """Returns `gc.Path("/foo", v)` for each of `versions`, in order."""
    return [gc.Path("/foo", version) for version in versions]

  def testLargestExportVersions(self):
    keep_newest_two = gc._largest_export_versions(2)
    kept = keep_newest_two(self._foo(8, 9, 10))
    self.assertEqual(kept, self._foo(9, 10))

  def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
    keep_newest_two = gc._largest_export_versions(2)
    kept = keep_newest_two(self._foo(0, 3))
    self.assertEqual(kept, self._foo(0, 3))

  def testModExportVersion(self):
    paths = self._foo(4, 5, 6, 9)
    every_second = gc._mod_export_version(2)
    self.assertEqual(every_second(paths), self._foo(4, 6))
    every_third = gc._mod_export_version(3)
    self.assertEqual(every_third(paths), self._foo(6, 9))

  def testOneOfEveryNExportVersions(self):
    paths = self._foo(0, 1, 3, 5, 6, 7, 8, 33)
    one_per_three = gc._one_of_every_n_export_versions(3)
    self.assertEqual(one_per_three(paths), self._foo(3, 6, 8, 33))

  def testOneOfEveryNExportVersionsZero(self):
    # Zero is a special case since it gets rolled into the first interval.
    # Test that here.
    paths = self._foo(0, 4, 5)
    one_per_three = gc._one_of_every_n_export_versions(3)
    self.assertEqual(one_per_three(paths), self._foo(0, 5))

  def testUnion(self):
    paths = [gc.Path("/foo", i) for i in xrange(10)]
    either = gc._union(gc._largest_export_versions(3),
                       gc._mod_export_version(3))
    self.assertEqual(either(paths), self._foo(0, 3, 6, 7, 8, 9))

  def testNegation(self):
    paths = self._foo(4, 5, 6, 9)
    not_every_second = gc._negation(gc._mod_export_version(2))
    self.assertEqual(not_every_second(paths), self._foo(5, 9))
    not_every_third = gc._negation(gc._mod_export_version(3))
    self.assertEqual(not_every_third(paths), self._foo(4, 5))

  def testPathsWithParse(self):
    base_dir = os.path.join(test.get_temp_dir(), "paths_parse")
    self.assertFalse(gfile.Exists(base_dir))
    for version in xrange(3):
      gfile.MakeDirs(os.path.join(base_dir, "%d" % version))
    # add a base_directory to ignore
    gfile.MakeDirs(os.path.join(base_dir, "ignore"))

    listed = gc._get_paths(base_dir, _create_parser(base_dir))
    expected = [
        gc.Path(os.path.join(base_dir, "%d" % version), version)
        for version in xrange(3)
    ]
    self.assertEqual(listed, expected)

  def testMixedStrTypes(self):
    temp_dir = compat.as_bytes(test.get_temp_dir())

    for sub_dir in ["str", b"bytes", u"unicode"]:
      # Join bytes with bytes and text with text.
      if isinstance(sub_dir, bytes):
        parent = temp_dir
      else:
        parent = temp_dir.decode()
      base_dir = os.path.join(parent, sub_dir)
      self.assertFalse(gfile.Exists(base_dir))
      gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
      gc._get_paths(base_dir, _create_parser(base_dir))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user