From 01620694d85653d8cc836db17945b0e349838b8c Mon Sep 17 00:00:00 2001 From: Igor Saprykin Date: Mon, 25 Sep 2017 10:40:14 -0700 Subject: [PATCH] Move 'make_export_strategy' utility function from .contrib to .core. That function depends on the utilities related to garbage collection. They are moved too, but are kept private. PiperOrigin-RevId: 169927321 --- tensorflow/python/estimator/BUILD | 90 ++++-- .../python/estimator/export_strategy.py | 93 ++++++- .../python/estimator/export_strategy_test.py | 261 ++++++++++++++++++ tensorflow/python/estimator/gc.py | 209 ++++++++++++++ tensorflow/python/estimator/gc_test.py | 145 ++++++++++ 5 files changed, 778 insertions(+), 20 deletions(-) create mode 100644 tensorflow/python/estimator/export_strategy_test.py create mode 100644 tensorflow/python/estimator/gc.py create mode 100644 tensorflow/python/estimator/gc_test.py diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index b0bf6347af6..ccaa3379d3e 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -8,25 +8,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") -py_library( - name = "estimator_py", - srcs = ["estimator_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":dnn", - ":dnn_linear_combined", - ":estimator", - ":export", - ":inputs", - ":linear", - ":model_fn", - ":parsing_utils", - ":run_config", - ":training", - "//tensorflow/python:util", - ], -) - filegroup( name = "all_files", srcs = glob( @@ -39,6 +20,77 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +py_library( + name = "estimator_py", + srcs = ["estimator_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":dnn", + ":dnn_linear_combined", + ":estimator", + ":export", + ":export_strategy", + ":inputs", + ":linear", + ":model_fn", + ":parsing_utils", + ":run_config", + ":training", + "//tensorflow/python:util", + ], +) + +py_library( + name = "export_strategy", + srcs = ["export_strategy.py"], + srcs_version = "PY2AND3", + deps = [ + ":gc", + "//tensorflow/python:errors", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +py_test( + name = "export_strategy_test", + size = "small", + srcs = ["export_strategy_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":estimator", + ":export_strategy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +py_library( + name = "gc", + srcs = ["gc.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +py_test( + name = "gc_test", + size = "small", + srcs = ["gc_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":gc", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + py_library( name = "model_fn", srcs = ["model_fn.py"], diff --git a/tensorflow/python/estimator/export_strategy.py b/tensorflow/python/estimator/export_strategy.py index bfcd20d7796..a481ddcc8cd 100644 --- a/tensorflow/python/estimator/export_strategy.py +++ b/tensorflow/python/estimator/export_strategy.py @@ -19,10 +19,15 @@ from __future__ import division from __future__ import print_function import collections +import os +from tensorflow.python.estimator import gc from tensorflow.python.estimator import util +from tensorflow.python.framework import errors_impl +from tensorflow.python.platform import gfile +from 
tensorflow.python.platform import tf_logging -__all__ = ['ExportStrategy'] +__all__ = ['ExportStrategy', 'make_export_strategy'] class ExportStrategy( @@ -81,3 +86,89 @@ class ExportStrategy( kwargs['eval_result'] = eval_result return self.export_fn(estimator, export_path, **kwargs) + + +def make_export_strategy(serving_input_fn, + assets_extra=None, + as_text=False, + exports_to_keep=5): + """Create an ExportStrategy for use with tf.estimator.EvalSpec. + + Args: + serving_input_fn: a function that takes no arguments and returns an + `ServingInputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel. Each key should give the destination + path (including the filename) relative to the assets.extra directory. + The corresponding value gives the full path of the source file to be + copied. For example, the simple case of copying a single file without + renaming it is specified as + `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + as_text: whether to write the SavedModel proto in text format. + exports_to_keep: Number of exports to keep. Older exports will be + garbage-collected. Defaults to 5. Set to None to disable garbage + collection. + + Returns: + An `ExportStrategy` that can be passed to the Experiment constructor. + """ + + def export_fn(estimator, export_dir_base, checkpoint_path=None): + """Exports the given Estimator as a SavedModel. + + Args: + estimator: the Estimator to export. + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + checkpoint_path: The checkpoint path to export. If None (the default), + the most recent checkpoint found within the model directory is chosen. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: If `estimator` is a ${tf.estimator.Estimator} instance + and `default_output_alternative_key` was specified. + """ + export_result = estimator.export_savedmodel( + export_dir_base, + serving_input_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path) + + _garbage_collect_exports(export_dir_base, exports_to_keep) + return export_result + + return ExportStrategy('Servo', export_fn) + + +def _garbage_collect_exports(export_dir_base, exports_to_keep): + """Deletes older exports, retaining only a given number of the most recent. + + Export subdirectories are assumed to be named with monotonically increasing + integers; the most recent are taken to be those with the largest values. + + Args: + export_dir_base: the base directory under which each export is in a + versioned subdirectory. + exports_to_keep: the number of recent exports to retain. + """ + if exports_to_keep is None: + return + + def _export_version_parser(path): + # create a simple parser that pulls the export_version from the directory. 
+ filename = os.path.basename(path.path) + if not (len(filename) == 10 and filename.isdigit()): + return None + return path._replace(export_version=int(filename)) + + keep_filter = gc._largest_export_versions(exports_to_keep) + delete_filter = gc._negation(keep_filter) + for p in delete_filter( + gc._get_paths(export_dir_base, parser=_export_version_parser)): + try: + gfile.DeleteRecursively(p.path) + except errors_impl.NotFoundError as e: + tf_logging.warn('Can not delete %s recursively: %s', p.path, e) diff --git a/tensorflow/python/estimator/export_strategy_test.py b/tensorflow/python/estimator/export_strategy_test.py new file mode 100644 index 00000000000..32224a6913b --- /dev/null +++ b/tensorflow/python/estimator/export_strategy_test.py @@ -0,0 +1,261 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `make_export_strategy`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +import time + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import export_strategy as export_strategy_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat + + +class ExportStrategyTest(test.TestCase): + + def testAcceptsNameAndFn(self): + def export_fn(estimator, export_path): + del estimator, export_path + + export_strategy = export_strategy_lib.ExportStrategy( + name="test", export_fn=export_fn) + + self.assertEqual("test", export_strategy.name) + self.assertEqual(export_fn, export_strategy.export_fn) + + def testCallsExportFnThatDoesntKnowExtraArguments(self): + expected_estimator = {} + + def export_fn(estimator, export_path): + self.assertEqual(expected_estimator, estimator) + self.assertEqual("expected_path", export_path) + + export_strategy = export_strategy_lib.ExportStrategy( + name="test", export_fn=export_fn) + + export_strategy.export( + estimator=expected_estimator, export_path="expected_path") + + # Also works with additional arguments that `export_fn` doesn't support. + # The lack of support is detected and the arguments aren't passed. 
+ export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + checkpoint_path="unexpected_checkpoint_path") + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + eval_result=()) + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + checkpoint_path="unexpected_checkpoint_path", + eval_result=()) + + def testCallsExportFnThatKnowsAboutCheckpointPathButItsNotGiven(self): + expected_estimator = {} + + def export_fn(estimator, export_path, checkpoint_path): + self.assertEqual(expected_estimator, estimator) + self.assertEqual("expected_path", export_path) + self.assertEqual(None, checkpoint_path) + + export_strategy = export_strategy_lib.ExportStrategy( + name="test", export_fn=export_fn) + + export_strategy.export( + estimator=expected_estimator, export_path="expected_path") + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + eval_result=()) + + def testCallsExportFnWithCheckpointPath(self): + expected_estimator = {} + + def export_fn(estimator, export_path, checkpoint_path): + self.assertEqual(expected_estimator, estimator) + self.assertEqual("expected_path", export_path) + self.assertEqual("expected_checkpoint_path", checkpoint_path) + + export_strategy = export_strategy_lib.ExportStrategy( + name="test", export_fn=export_fn) + + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + checkpoint_path="expected_checkpoint_path") + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + checkpoint_path="expected_checkpoint_path", + eval_result=()) + + def testCallsExportFnThatKnowsAboutEvalResultButItsNotGiven(self): + expected_estimator = {} + + def export_fn(estimator, export_path, checkpoint_path, eval_result): + self.assertEqual(expected_estimator, estimator) + self.assertEqual("expected_path", export_path) + self.assertEqual(None, checkpoint_path) + self.assertEqual(None, eval_result) + + export_strategy = export_strategy_lib.ExportStrategy( + name="test", export_fn=export_fn) + + export_strategy.export( + estimator=expected_estimator, export_path="expected_path") + + def testCallsExportFnThatAcceptsEvalResultButNotCheckpoint(self): + expected_estimator = {} + + def export_fn(estimator, export_path, eval_result): + del estimator, export_path, eval_result + raise RuntimeError("Should raise ValueError before this.") + + export_strategy = export_strategy_lib.ExportStrategy( + name="test", export_fn=export_fn) + + expected_error_message = ( + "An export_fn accepting eval_result must also accept checkpoint_path") + + with self.assertRaisesRegexp(ValueError, expected_error_message): + export_strategy.export( + estimator=expected_estimator, export_path="expected_path") + + with self.assertRaisesRegexp(ValueError, expected_error_message): + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + checkpoint_path="unexpected_checkpoint_path") + + with self.assertRaisesRegexp(ValueError, expected_error_message): + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + eval_result=()) + + with self.assertRaisesRegexp(ValueError, expected_error_message): + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + checkpoint_path="unexpected_checkpoint_path", + eval_result=()) + + def testCallsExportFnWithEvalResultAndCheckpointPath(self): + expected_estimator = {} + expected_eval_result = {} + + def 
export_fn(estimator, export_path, checkpoint_path, eval_result): + self.assertEqual(expected_estimator, estimator) + self.assertEqual("expected_path", export_path) + self.assertEqual("expected_checkpoint_path", checkpoint_path) + self.assertEqual(expected_eval_result, eval_result) + + export_strategy = export_strategy_lib.ExportStrategy( + name="test", export_fn=export_fn) + + export_strategy.export( + estimator=expected_estimator, + export_path="expected_path", + checkpoint_path="expected_checkpoint_path", + eval_result=expected_eval_result) + + +class MakeExportStrategyTest(test.TestCase): + + def test_make_export_strategy(self): + def _serving_input_fn(): + return array_ops.constant([1]), None + + export_strategy = export_strategy_lib.make_export_strategy( + serving_input_fn=_serving_input_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + exports_to_keep=5) + self.assertTrue( + isinstance(export_strategy, export_strategy_lib.ExportStrategy)) + + def test_garbage_collect_exports(self): + export_dir_base = tempfile.mkdtemp() + "export/" + gfile.MkDir(export_dir_base) + export_dir_1 = _create_test_export_dir(export_dir_base) + export_dir_2 = _create_test_export_dir(export_dir_base) + export_dir_3 = _create_test_export_dir(export_dir_base) + export_dir_4 = _create_test_export_dir(export_dir_base) + + self.assertTrue(gfile.Exists(export_dir_1)) + self.assertTrue(gfile.Exists(export_dir_2)) + self.assertTrue(gfile.Exists(export_dir_3)) + self.assertTrue(gfile.Exists(export_dir_4)) + + def _serving_input_fn(): + return array_ops.constant([1]), None + export_strategy = export_strategy_lib.make_export_strategy( + _serving_input_fn, exports_to_keep=2) + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + # Garbage collect all but the most recent 2 exports, + # where recency is determined based on the timestamp directory names. + export_strategy.export(estimator, export_dir_base) + + self.assertFalse(gfile.Exists(export_dir_1)) + self.assertFalse(gfile.Exists(export_dir_2)) + self.assertTrue(gfile.Exists(export_dir_3)) + self.assertTrue(gfile.Exists(export_dir_4)) + + +def _create_test_export_dir(export_dir_base): + export_dir = _get_timestamped_export_dir(export_dir_base) + gfile.MkDir(export_dir) + time.sleep(2) + return export_dir + + +def _get_timestamped_export_dir(export_dir_base): + # When we create a timestamped directory, there is a small chance that the + # directory already exists because another worker is also writing exports. + # In this case we just wait one second to get a new timestamp and try again. + # If this fails several times in a row, then something is seriously wrong. + max_directory_creation_attempts = 10 + + attempts = 0 + while attempts < max_directory_creation_attempts: + export_timestamp = int(time.time()) + + export_dir = os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(str(export_timestamp))) + if not gfile.Exists(export_dir): + # Collisions are still possible (though extremely unlikely): this + # directory is not actually created yet, but it will be almost + # instantly on return from this function. + return export_dir + time.sleep(1) + attempts += 1 + logging.warn("Export directory {} already exists; retrying (attempt {}/{})". 
+ format(export_dir, attempts, max_directory_creation_attempts)) + raise RuntimeError("Failed to obtain a unique export directory name after " + "{} attempts.".format(max_directory_creation_attempts)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/estimator/gc.py b/tensorflow/python/estimator/gc.py new file mode 100644 index 00000000000..9f8a463ec1e --- /dev/null +++ b/tensorflow/python/estimator/gc.py @@ -0,0 +1,209 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +r"""System for specifying garbage collection (GC) of path based data. + +This framework allows for GC of data specified by path names, for example files +on disk. gc.Path objects each represent a single item stored at a path and may +be a base directory, + /tmp/exports/0/... + /tmp/exports/1/... + ... +or a fully qualified file, + /tmp/train-1.ckpt + /tmp/train-2.ckpt + ... + +A gc filter function takes and returns a list of gc.Path items. Filter +functions are responsible for selecting Path items for preservation or deletion. +Note that functions should always return a sorted list. + +For example, + base_dir = "/tmp" + # Create the directories. + for e in xrange(10): + os.mkdir("%s/%d" % (base_dir, e), 0o755) + + # Create a simple parser that pulls the export_version from the directory. + path_regex = "^" + re.escape(base_dir) + "/(\\d+)$" + def parser(path): + match = re.match(path_regex, path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + path_list = gc._get_paths("/tmp", parser) # contains all ten Paths + + every_fifth = gc._mod_export_version(5) + print(every_fifth(path_list)) # shows ["/tmp/0", "/tmp/5"] + + largest_three = gc.largest_export_versions(3) + print(largest_three(all_paths)) # shows ["/tmp/7", "/tmp/8", "/tmp/9"] + + both = gc._union(every_fifth, largest_three) + print(both(all_paths)) # shows ["/tmp/0", "/tmp/5", + # "/tmp/7", "/tmp/8", "/tmp/9"] + # Delete everything not in 'both'. + to_delete = gc._negation(both) + for p in to_delete(all_paths): + gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2", + # "/tmp/3", "/tmp/4", "/tmp/6", +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import heapq +import math +import os + +from tensorflow.python.platform import gfile +from tensorflow.python.util import compat + +Path = collections.namedtuple('Path', 'path export_version') + + +def _largest_export_versions(n): + """Creates a filter that keeps the largest n export versions. + + Args: + n: number of versions to keep. + + Returns: + A filter function that keeps the n largest paths. 
+ """ + def keep(paths): + heap = [] + for idx, path in enumerate(paths): + if path.export_version is not None: + heapq.heappush(heap, (path.export_version, idx)) + keepers = [paths[i] for _, i in heapq.nlargest(n, heap)] + return sorted(keepers) + + return keep + + +def _one_of_every_n_export_versions(n): + """Creates a filter that keeps one of every n export versions. + + Args: + n: interval size. + + Returns: + A filter function that keeps exactly one path from each interval + [0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an + interval the largest is kept. + """ + def keep(paths): + """A filter function that keeps exactly one out of every n paths.""" + + keeper_map = {} # map from interval to largest path seen in that interval + for p in paths: + if p.export_version is None: + # Skip missing export_versions. + continue + # Find the interval (with a special case to map export_version = 0 to + # interval 0. + interval = math.floor( + (p.export_version - 1) / n) if p.export_version else 0 + existing = keeper_map.get(interval, None) + if (not existing) or (existing.export_version < p.export_version): + keeper_map[interval] = p + return sorted(keeper_map.values()) + + return keep + + +def _mod_export_version(n): + """Creates a filter that keeps every export that is a multiple of n. + + Args: + n: step size. + + Returns: + A filter function that keeps paths where export_version % n == 0. + """ + def keep(paths): + keepers = [] + for p in paths: + if p.export_version % n == 0: + keepers.append(p) + return sorted(keepers) + return keep + + +def _union(lf, rf): + """Creates a filter that keeps the union of two filters. + + Args: + lf: first filter + rf: second filter + + Returns: + A filter function that keeps the n largest paths. + """ + def keep(paths): + l = set(lf(paths)) + r = set(rf(paths)) + return sorted(list(l|r)) + return keep + + +def _negation(f): + """Negate a filter. + + Args: + f: filter function to invert + + Returns: + A filter function that returns the negation of f. + """ + def keep(paths): + l = set(paths) + r = set(f(paths)) + return sorted(list(l-r)) + return keep + + +def _get_paths(base_dir, parser): + """Gets a list of Paths in a given directory. + + Args: + base_dir: directory. + parser: a function which gets the raw Path and can augment it with + information such as the export_version, or ignore the path by returning + None. An example parser may extract the export version from a path + such as "/tmp/exports/100" an another may extract from a full file + name such as "/tmp/checkpoint-99.out". + + Returns: + A list of Paths contained in the base directory with the parsing function + applied. + By default the following fields are populated, + - Path.path + The parsing function is responsible for populating, + - Path.export_version + """ + raw_paths = gfile.ListDirectory(base_dir) + paths = [] + for r in raw_paths: + p = parser(Path(os.path.join(compat.as_str_any(base_dir), + compat.as_str_any(r)), + None)) + if p: + paths.append(p) + return sorted(paths) diff --git a/tensorflow/python/estimator/gc_test.py b/tensorflow/python/estimator/gc_test.py new file mode 100644 index 00000000000..2cbdd511d11 --- /dev/null +++ b/tensorflow/python/estimator/gc_test.py @@ -0,0 +1,145 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for garbage collection utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.estimator import gc +from tensorflow.python.framework import test_util +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +def _create_parser(base_dir): + # create a simple parser that pulls the export_version from the directory. + def parser(path): + # Modify the path object for RegEx match for Windows Paths + if os.name == "nt": + match = re.match( + "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$", + compat.as_str_any(path.path).replace("\\", "/")) + else: + match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$", + compat.as_str_any(path.path)) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + return parser + + +class GcTest(test_util.TensorFlowTestCase): + + def testLargestExportVersions(self): + paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)] + newest = gc._largest_export_versions(2) + n = newest(paths) + self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)]) + + def testLargestExportVersionsDoesNotDeleteZeroFolder(self): + paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)] + newest = gc._largest_export_versions(2) + n = newest(paths) + self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)]) + + def testModExportVersion(self): + paths = [ + gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9) + ] + mod = gc._mod_export_version(2) + self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)]) + mod = gc._mod_export_version(3) + self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)]) + + def testOneOfEveryNExportVersions(self): + paths = [ + gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3), + gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 33) + ] + one_of = gc._one_of_every_n_export_versions(3) + self.assertEqual( + one_of(paths), [ + gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8), + gc.Path("/foo", 33) + ]) + + def testOneOfEveryNExportVersionsZero(self): + # Zero is a special case since it gets rolled into the first interval. + # Test that here. 
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)] + one_of = gc._one_of_every_n_export_versions(3) + self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)]) + + def testUnion(self): + paths = [] + for i in xrange(10): + paths.append(gc.Path("/foo", i)) + f = gc._union(gc._largest_export_versions(3), gc._mod_export_version(3)) + self.assertEqual( + f(paths), [ + gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6), + gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9) + ]) + + def testNegation(self): + paths = [ + gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9) + ] + mod = gc._negation(gc._mod_export_version(2)) + self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)]) + mod = gc._negation(gc._mod_export_version(3)) + self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)]) + + def testPathsWithParse(self): + base_dir = os.path.join(test.get_temp_dir(), "paths_parse") + self.assertFalse(gfile.Exists(base_dir)) + for p in xrange(3): + gfile.MakeDirs(os.path.join(base_dir, "%d" % p)) + # add a base_directory to ignore + gfile.MakeDirs(os.path.join(base_dir, "ignore")) + + self.assertEqual( + gc._get_paths(base_dir, _create_parser(base_dir)), + [ + gc.Path(os.path.join(base_dir, "0"), 0), + gc.Path(os.path.join(base_dir, "1"), 1), + gc.Path(os.path.join(base_dir, "2"), 2) + ]) + + def testMixedStrTypes(self): + temp_dir = compat.as_bytes(test.get_temp_dir()) + + for sub_dir in ["str", b"bytes", u"unicode"]: + base_dir = os.path.join( + (temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode()), + sub_dir) + self.assertFalse(gfile.Exists(base_dir)) + gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42")) + gc._get_paths(base_dir, _create_parser(base_dir)) + + +if __name__ == "__main__": + test.main()
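The ExportStrategy class itself is unchanged by this move; only make_export_strategy and the private gc helpers are new to core. As a quick illustration of the signature-inspection behavior exercised by ExportStrategyTest above, here is a minimal sketch using the in-tree module path introduced by this patch; the dict standing in for an estimator, the stub export_fn, and the /tmp export path are placeholders for illustration only, not part of the change.

from tensorflow.python.estimator import export_strategy as export_strategy_lib


def _export_fn(estimator, export_path, checkpoint_path=None):
  # Declares checkpoint_path, so ExportStrategy.export forwards it.
  # eval_result is not declared, so export() detects that and does not pass it.
  del estimator, checkpoint_path
  return export_path


strategy = export_strategy_lib.ExportStrategy(name="Servo", export_fn=_export_fn)

# A real caller passes a tf.estimator.Estimator and a writable export
# directory; the stub values here only demonstrate the call signature.
export_dir = strategy.export(
    estimator={}, export_path="/tmp/my_model/export", checkpoint_path=None)
print(export_dir)  # "/tmp/my_model/export"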
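The new _garbage_collect_exports helper in export_strategy.py is a straight composition of the private gc filters added in gc.py. The following self-contained sketch runs that composition against a scratch directory instead of real SavedModel exports; the zero-padded directory names and the keep count of 2 are arbitrary stand-ins for the 10-digit timestamp directories produced by export_savedmodel.

import os
import tempfile

from tensorflow.python.estimator import gc
from tensorflow.python.platform import gfile

base_dir = tempfile.mkdtemp()
for version in range(5):
  # 10-digit, all-numeric names, mirroring timestamped export directories.
  gfile.MakeDirs(os.path.join(base_dir, "%010d" % version))


def _parser(path):
  # Same check as _export_version_parser: keep only 10-digit numeric names.
  filename = os.path.basename(path.path)
  if not (len(filename) == 10 and filename.isdigit()):
    return None
  return path._replace(export_version=int(filename))


keep_filter = gc._largest_export_versions(2)   # the two newest exports
delete_filter = gc._negation(keep_filter)      # everything else
for p in delete_filter(gc._get_paths(base_dir, parser=_parser)):
  gfile.DeleteRecursively(p.path)

print(sorted(os.listdir(base_dir)))  # only the two largest versions remain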