Move 'make_export_strategy' utility function from .contrib to .core.

That function depends on the utilities related to garbage collection.  They are moved too, but are kept private.

PiperOrigin-RevId: 169927321
This commit is contained in:
Igor Saprykin 2017-09-25 10:40:14 -07:00 committed by TensorFlower Gardener
parent dbaf176e1f
commit 01620694d8
5 changed files with 778 additions and 20 deletions

View File

@ -8,25 +8,6 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "estimator_py",
srcs = ["estimator_lib.py"],
srcs_version = "PY2AND3",
deps = [
":dnn",
":dnn_linear_combined",
":estimator",
":export",
":inputs",
":linear",
":model_fn",
":parsing_utils",
":run_config",
":training",
"//tensorflow/python:util",
],
)
filegroup(
name = "all_files",
srcs = glob(
@ -39,6 +20,77 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
py_library(
name = "estimator_py",
srcs = ["estimator_lib.py"],
srcs_version = "PY2AND3",
deps = [
":dnn",
":dnn_linear_combined",
":estimator",
":export",
":export_strategy",
":inputs",
":linear",
":model_fn",
":parsing_utils",
":run_config",
":training",
"//tensorflow/python:util",
],
)
py_library(
name = "export_strategy",
srcs = ["export_strategy.py"],
srcs_version = "PY2AND3",
deps = [
":gc",
"//tensorflow/python:errors",
"//tensorflow/python:platform",
"//tensorflow/python:util",
],
)
py_test(
name = "export_strategy_test",
size = "small",
srcs = ["export_strategy_test.py"],
srcs_version = "PY2AND3",
deps = [
":estimator",
":export_strategy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
"//tensorflow/python:util",
],
)
py_library(
name = "gc",
srcs = ["gc.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:platform",
"//tensorflow/python:util",
],
)
py_test(
name = "gc_test",
size = "small",
srcs = ["gc_test.py"],
srcs_version = "PY2AND3",
deps = [
":gc",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python:util",
],
)
py_library(
name = "model_fn",
srcs = ["model_fn.py"],

View File

@ -19,10 +19,15 @@ from __future__ import division
from __future__ import print_function
import collections
import os
from tensorflow.python.estimator import gc
from tensorflow.python.estimator import util
from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
__all__ = ['ExportStrategy']
__all__ = ['ExportStrategy', 'make_export_strategy']
class ExportStrategy(
@ -81,3 +86,89 @@ class ExportStrategy(
kwargs['eval_result'] = eval_result
return self.export_fn(estimator, export_path, **kwargs)
def make_export_strategy(serving_input_fn,
assets_extra=None,
as_text=False,
exports_to_keep=5):
"""Create an ExportStrategy for use with tf.estimator.EvalSpec.
Args:
serving_input_fn: a function that takes no arguments and returns an
`ServingInputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
within the exported SavedModel. Each key should give the destination
path (including the filename) relative to the assets.extra directory.
The corresponding value gives the full path of the source file to be
copied. For example, the simple case of copying a single file without
renaming it is specified as
`{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
as_text: whether to write the SavedModel proto in text format.
exports_to_keep: Number of exports to keep. Older exports will be
garbage-collected. Defaults to 5. Set to None to disable garbage
collection.
Returns:
An `ExportStrategy` that can be passed to the Experiment constructor.
"""
def export_fn(estimator, export_dir_base, checkpoint_path=None):
"""Exports the given Estimator as a SavedModel.
Args:
estimator: the Estimator to export.
export_dir_base: A string containing a directory to write the exported
graph and checkpoints.
checkpoint_path: The checkpoint path to export. If None (the default),
the most recent checkpoint found within the model directory is chosen.
Returns:
The string path to the exported directory.
Raises:
ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
and `default_output_alternative_key` was specified.
"""
export_result = estimator.export_savedmodel(
export_dir_base,
serving_input_fn,
assets_extra=assets_extra,
as_text=as_text,
checkpoint_path=checkpoint_path)
_garbage_collect_exports(export_dir_base, exports_to_keep)
return export_result
return ExportStrategy('Servo', export_fn)
def _garbage_collect_exports(export_dir_base, exports_to_keep):
"""Deletes older exports, retaining only a given number of the most recent.
Export subdirectories are assumed to be named with monotonically increasing
integers; the most recent are taken to be those with the largest values.
Args:
export_dir_base: the base directory under which each export is in a
versioned subdirectory.
exports_to_keep: the number of recent exports to retain.
"""
if exports_to_keep is None:
return
def _export_version_parser(path):
# create a simple parser that pulls the export_version from the directory.
filename = os.path.basename(path.path)
if not (len(filename) == 10 and filename.isdigit()):
return None
return path._replace(export_version=int(filename))
keep_filter = gc._largest_export_versions(exports_to_keep)
delete_filter = gc._negation(keep_filter)
for p in delete_filter(
gc._get_paths(export_dir_base, parser=_export_version_parser)):
try:
gfile.DeleteRecursively(p.path)
except errors_impl.NotFoundError as e:
tf_logging.warn('Can not delete %s recursively: %s', p.path, e)

View File

@ -0,0 +1,261 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `make_export_strategy`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import time
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import export_strategy as export_strategy_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
class ExportStrategyTest(test.TestCase):
def testAcceptsNameAndFn(self):
def export_fn(estimator, export_path):
del estimator, export_path
export_strategy = export_strategy_lib.ExportStrategy(
name="test", export_fn=export_fn)
self.assertEqual("test", export_strategy.name)
self.assertEqual(export_fn, export_strategy.export_fn)
def testCallsExportFnThatDoesntKnowExtraArguments(self):
expected_estimator = {}
def export_fn(estimator, export_path):
self.assertEqual(expected_estimator, estimator)
self.assertEqual("expected_path", export_path)
export_strategy = export_strategy_lib.ExportStrategy(
name="test", export_fn=export_fn)
export_strategy.export(
estimator=expected_estimator, export_path="expected_path")
# Also works with additional arguments that `export_fn` doesn't support.
# The lack of support is detected and the arguments aren't passed.
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
checkpoint_path="unexpected_checkpoint_path")
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
eval_result=())
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
checkpoint_path="unexpected_checkpoint_path",
eval_result=())
def testCallsExportFnThatKnowsAboutCheckpointPathButItsNotGiven(self):
expected_estimator = {}
def export_fn(estimator, export_path, checkpoint_path):
self.assertEqual(expected_estimator, estimator)
self.assertEqual("expected_path", export_path)
self.assertEqual(None, checkpoint_path)
export_strategy = export_strategy_lib.ExportStrategy(
name="test", export_fn=export_fn)
export_strategy.export(
estimator=expected_estimator, export_path="expected_path")
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
eval_result=())
def testCallsExportFnWithCheckpointPath(self):
expected_estimator = {}
def export_fn(estimator, export_path, checkpoint_path):
self.assertEqual(expected_estimator, estimator)
self.assertEqual("expected_path", export_path)
self.assertEqual("expected_checkpoint_path", checkpoint_path)
export_strategy = export_strategy_lib.ExportStrategy(
name="test", export_fn=export_fn)
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
checkpoint_path="expected_checkpoint_path")
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
checkpoint_path="expected_checkpoint_path",
eval_result=())
def testCallsExportFnThatKnowsAboutEvalResultButItsNotGiven(self):
expected_estimator = {}
def export_fn(estimator, export_path, checkpoint_path, eval_result):
self.assertEqual(expected_estimator, estimator)
self.assertEqual("expected_path", export_path)
self.assertEqual(None, checkpoint_path)
self.assertEqual(None, eval_result)
export_strategy = export_strategy_lib.ExportStrategy(
name="test", export_fn=export_fn)
export_strategy.export(
estimator=expected_estimator, export_path="expected_path")
def testCallsExportFnThatAcceptsEvalResultButNotCheckpoint(self):
expected_estimator = {}
def export_fn(estimator, export_path, eval_result):
del estimator, export_path, eval_result
raise RuntimeError("Should raise ValueError before this.")
export_strategy = export_strategy_lib.ExportStrategy(
name="test", export_fn=export_fn)
expected_error_message = (
"An export_fn accepting eval_result must also accept checkpoint_path")
with self.assertRaisesRegexp(ValueError, expected_error_message):
export_strategy.export(
estimator=expected_estimator, export_path="expected_path")
with self.assertRaisesRegexp(ValueError, expected_error_message):
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
checkpoint_path="unexpected_checkpoint_path")
with self.assertRaisesRegexp(ValueError, expected_error_message):
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
eval_result=())
with self.assertRaisesRegexp(ValueError, expected_error_message):
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
checkpoint_path="unexpected_checkpoint_path",
eval_result=())
def testCallsExportFnWithEvalResultAndCheckpointPath(self):
expected_estimator = {}
expected_eval_result = {}
def export_fn(estimator, export_path, checkpoint_path, eval_result):
self.assertEqual(expected_estimator, estimator)
self.assertEqual("expected_path", export_path)
self.assertEqual("expected_checkpoint_path", checkpoint_path)
self.assertEqual(expected_eval_result, eval_result)
export_strategy = export_strategy_lib.ExportStrategy(
name="test", export_fn=export_fn)
export_strategy.export(
estimator=expected_estimator,
export_path="expected_path",
checkpoint_path="expected_checkpoint_path",
eval_result=expected_eval_result)
class MakeExportStrategyTest(test.TestCase):
def test_make_export_strategy(self):
def _serving_input_fn():
return array_ops.constant([1]), None
export_strategy = export_strategy_lib.make_export_strategy(
serving_input_fn=_serving_input_fn,
assets_extra={"from/path": "to/path"},
as_text=False,
exports_to_keep=5)
self.assertTrue(
isinstance(export_strategy, export_strategy_lib.ExportStrategy))
def test_garbage_collect_exports(self):
export_dir_base = tempfile.mkdtemp() + "export/"
gfile.MkDir(export_dir_base)
export_dir_1 = _create_test_export_dir(export_dir_base)
export_dir_2 = _create_test_export_dir(export_dir_base)
export_dir_3 = _create_test_export_dir(export_dir_base)
export_dir_4 = _create_test_export_dir(export_dir_base)
self.assertTrue(gfile.Exists(export_dir_1))
self.assertTrue(gfile.Exists(export_dir_2))
self.assertTrue(gfile.Exists(export_dir_3))
self.assertTrue(gfile.Exists(export_dir_4))
def _serving_input_fn():
return array_ops.constant([1]), None
export_strategy = export_strategy_lib.make_export_strategy(
_serving_input_fn, exports_to_keep=2)
estimator = test.mock.Mock(spec=estimator_lib.Estimator)
# Garbage collect all but the most recent 2 exports,
# where recency is determined based on the timestamp directory names.
export_strategy.export(estimator, export_dir_base)
self.assertFalse(gfile.Exists(export_dir_1))
self.assertFalse(gfile.Exists(export_dir_2))
self.assertTrue(gfile.Exists(export_dir_3))
self.assertTrue(gfile.Exists(export_dir_4))
def _create_test_export_dir(export_dir_base):
export_dir = _get_timestamped_export_dir(export_dir_base)
gfile.MkDir(export_dir)
time.sleep(2)
return export_dir
def _get_timestamped_export_dir(export_dir_base):
# When we create a timestamped directory, there is a small chance that the
# directory already exists because another worker is also writing exports.
# In this case we just wait one second to get a new timestamp and try again.
# If this fails several times in a row, then something is seriously wrong.
max_directory_creation_attempts = 10
attempts = 0
while attempts < max_directory_creation_attempts:
export_timestamp = int(time.time())
export_dir = os.path.join(
compat.as_bytes(export_dir_base),
compat.as_bytes(str(export_timestamp)))
if not gfile.Exists(export_dir):
# Collisions are still possible (though extremely unlikely): this
# directory is not actually created yet, but it will be almost
# instantly on return from this function.
return export_dir
time.sleep(1)
attempts += 1
logging.warn("Export directory {} already exists; retrying (attempt {}/{})".
format(export_dir, attempts, max_directory_creation_attempts))
raise RuntimeError("Failed to obtain a unique export directory name after "
"{} attempts.".format(max_directory_creation_attempts))
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,209 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""System for specifying garbage collection (GC) of path based data.
This framework allows for GC of data specified by path names, for example files
on disk. gc.Path objects each represent a single item stored at a path and may
be a base directory,
/tmp/exports/0/...
/tmp/exports/1/...
...
or a fully qualified file,
/tmp/train-1.ckpt
/tmp/train-2.ckpt
...
A gc filter function takes and returns a list of gc.Path items. Filter
functions are responsible for selecting Path items for preservation or deletion.
Note that functions should always return a sorted list.
For example,
base_dir = "/tmp"
# Create the directories.
for e in xrange(10):
os.mkdir("%s/%d" % (base_dir, e), 0o755)
# Create a simple parser that pulls the export_version from the directory.
path_regex = "^" + re.escape(base_dir) + "/(\\d+)$"
def parser(path):
match = re.match(path_regex, path.path)
if not match:
return None
return path._replace(export_version=int(match.group(1)))
path_list = gc._get_paths("/tmp", parser) # contains all ten Paths
every_fifth = gc._mod_export_version(5)
print(every_fifth(path_list)) # shows ["/tmp/0", "/tmp/5"]
largest_three = gc.largest_export_versions(3)
print(largest_three(all_paths)) # shows ["/tmp/7", "/tmp/8", "/tmp/9"]
both = gc._union(every_fifth, largest_three)
print(both(all_paths)) # shows ["/tmp/0", "/tmp/5",
# "/tmp/7", "/tmp/8", "/tmp/9"]
# Delete everything not in 'both'.
to_delete = gc._negation(both)
for p in to_delete(all_paths):
gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2",
# "/tmp/3", "/tmp/4", "/tmp/6",
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import heapq
import math
import os
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat
Path = collections.namedtuple('Path', 'path export_version')
def _largest_export_versions(n):
"""Creates a filter that keeps the largest n export versions.
Args:
n: number of versions to keep.
Returns:
A filter function that keeps the n largest paths.
"""
def keep(paths):
heap = []
for idx, path in enumerate(paths):
if path.export_version is not None:
heapq.heappush(heap, (path.export_version, idx))
keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
return sorted(keepers)
return keep
def _one_of_every_n_export_versions(n):
"""Creates a filter that keeps one of every n export versions.
Args:
n: interval size.
Returns:
A filter function that keeps exactly one path from each interval
[0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an
interval the largest is kept.
"""
def keep(paths):
"""A filter function that keeps exactly one out of every n paths."""
keeper_map = {} # map from interval to largest path seen in that interval
for p in paths:
if p.export_version is None:
# Skip missing export_versions.
continue
# Find the interval (with a special case to map export_version = 0 to
# interval 0.
interval = math.floor(
(p.export_version - 1) / n) if p.export_version else 0
existing = keeper_map.get(interval, None)
if (not existing) or (existing.export_version < p.export_version):
keeper_map[interval] = p
return sorted(keeper_map.values())
return keep
def _mod_export_version(n):
"""Creates a filter that keeps every export that is a multiple of n.
Args:
n: step size.
Returns:
A filter function that keeps paths where export_version % n == 0.
"""
def keep(paths):
keepers = []
for p in paths:
if p.export_version % n == 0:
keepers.append(p)
return sorted(keepers)
return keep
def _union(lf, rf):
"""Creates a filter that keeps the union of two filters.
Args:
lf: first filter
rf: second filter
Returns:
A filter function that keeps the n largest paths.
"""
def keep(paths):
l = set(lf(paths))
r = set(rf(paths))
return sorted(list(l|r))
return keep
def _negation(f):
"""Negate a filter.
Args:
f: filter function to invert
Returns:
A filter function that returns the negation of f.
"""
def keep(paths):
l = set(paths)
r = set(f(paths))
return sorted(list(l-r))
return keep
def _get_paths(base_dir, parser):
"""Gets a list of Paths in a given directory.
Args:
base_dir: directory.
parser: a function which gets the raw Path and can augment it with
information such as the export_version, or ignore the path by returning
None. An example parser may extract the export version from a path
such as "/tmp/exports/100" an another may extract from a full file
name such as "/tmp/checkpoint-99.out".
Returns:
A list of Paths contained in the base directory with the parsing function
applied.
By default the following fields are populated,
- Path.path
The parsing function is responsible for populating,
- Path.export_version
"""
raw_paths = gfile.ListDirectory(base_dir)
paths = []
for r in raw_paths:
p = parser(Path(os.path.join(compat.as_str_any(base_dir),
compat.as_str_any(r)),
None))
if p:
paths.append(p)
return sorted(paths)

View File

@ -0,0 +1,145 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for garbage collection utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.estimator import gc
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.util import compat
def _create_parser(base_dir):
# create a simple parser that pulls the export_version from the directory.
def parser(path):
# Modify the path object for RegEx match for Windows Paths
if os.name == "nt":
match = re.match(
"^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$",
compat.as_str_any(path.path).replace("\\", "/"))
else:
match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
compat.as_str_any(path.path))
if not match:
return None
return path._replace(export_version=int(match.group(1)))
return parser
class GcTest(test_util.TensorFlowTestCase):
def testLargestExportVersions(self):
paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
newest = gc._largest_export_versions(2)
n = newest(paths)
self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
newest = gc._largest_export_versions(2)
n = newest(paths)
self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
def testModExportVersion(self):
paths = [
gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
gc.Path("/foo", 9)
]
mod = gc._mod_export_version(2)
self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
mod = gc._mod_export_version(3)
self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
def testOneOfEveryNExportVersions(self):
paths = [
gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3),
gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7),
gc.Path("/foo", 8), gc.Path("/foo", 33)
]
one_of = gc._one_of_every_n_export_versions(3)
self.assertEqual(
one_of(paths), [
gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8),
gc.Path("/foo", 33)
])
def testOneOfEveryNExportVersionsZero(self):
# Zero is a special case since it gets rolled into the first interval.
# Test that here.
paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
one_of = gc._one_of_every_n_export_versions(3)
self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])
def testUnion(self):
paths = []
for i in xrange(10):
paths.append(gc.Path("/foo", i))
f = gc._union(gc._largest_export_versions(3), gc._mod_export_version(3))
self.assertEqual(
f(paths), [
gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6),
gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9)
])
def testNegation(self):
paths = [
gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
gc.Path("/foo", 9)
]
mod = gc._negation(gc._mod_export_version(2))
self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
mod = gc._negation(gc._mod_export_version(3))
self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
def testPathsWithParse(self):
base_dir = os.path.join(test.get_temp_dir(), "paths_parse")
self.assertFalse(gfile.Exists(base_dir))
for p in xrange(3):
gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
# add a base_directory to ignore
gfile.MakeDirs(os.path.join(base_dir, "ignore"))
self.assertEqual(
gc._get_paths(base_dir, _create_parser(base_dir)),
[
gc.Path(os.path.join(base_dir, "0"), 0),
gc.Path(os.path.join(base_dir, "1"), 1),
gc.Path(os.path.join(base_dir, "2"), 2)
])
def testMixedStrTypes(self):
temp_dir = compat.as_bytes(test.get_temp_dir())
for sub_dir in ["str", b"bytes", u"unicode"]:
base_dir = os.path.join(
(temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode()),
sub_dir)
self.assertFalse(gfile.Exists(base_dir))
gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
gc._get_paths(base_dir, _create_parser(base_dir))
if __name__ == "__main__":
test.main()