PY3 Migration - //tensorflow/tools [2]

PiperOrigin-RevId: 275334989
Change-Id: Ia0b660e7ce7cde8e97f8d7cb3a39afe7fec63a7f
This commit is contained in:
Hye Soo Yang 2019-10-17 14:07:44 -07:00 committed by TensorFlower Gardener
parent b78675a3bd
commit c396546ca3
39 changed files with 412 additions and 258 deletions

View File

@ -23,5 +23,6 @@ py_library(
":api_objects_proto_py",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"@six_archive//:six",
],
)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -19,8 +20,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import enum
import sys
import six
from google.protobuf import message
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
@ -191,7 +195,8 @@ class PythonObjectToProtoVisitor(object):
if (_SkipMember(parent, member_name) or
isinstance(member_obj, deprecation.HiddenTfApiAttribute)):
return
if member_name == '__init__' or not member_name.startswith('_'):
if member_name == '__init__' or not six.ensure_str(
member_name).startswith('_'):
if tf_inspect.isroutine(member_obj):
new_method = proto.member_method.add()
new_method.name = member_name

View File

@ -14,13 +14,16 @@ py_library(
name = "public_api",
srcs = ["public_api.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow/python:util"],
deps = [
"//tensorflow/python:util",
"@six_archive//:six",
],
)
py_test(
name = "public_api_test",
srcs = ["public_api_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":public_api",
@ -32,13 +35,16 @@ py_library(
name = "traverse",
srcs = ["traverse.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflow/python:util"],
deps = [
"//tensorflow/python:util",
"@six_archive//:six",
],
)
py_test(
name = "traverse_test",
srcs = ["traverse_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":test_module1",

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,8 @@ from __future__ import print_function
import re
import six
from tensorflow.python.util import tf_inspect
@ -108,9 +111,9 @@ class PublicAPIVisitor(object):
"""Return whether a name is private."""
# TODO(wicke): Find out what names to exclude.
del obj # Unused.
return ((path in self._private_map and
name in self._private_map[path]) or
(name.startswith('_') and not re.match('__.*__$', name) or
return ((path in self._private_map and name in self._private_map[path]) or
(six.ensure_str(name).startswith('_') and
not re.match('__.*__$', six.ensure_str(name)) or
name in ['__base__', '__class__']))
def _do_not_descend(self, path, name):
@ -122,7 +125,8 @@ class PublicAPIVisitor(object):
"""Visitor interface, see `traverse` for details."""
# Avoid long waits in cases of pretty unambiguous failure.
if tf_inspect.ismodule(parent) and len(path.split('.')) > 10:
if tf_inspect.ismodule(parent) and len(
six.ensure_str(path).split('.')) > 10:
raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a '
'problem with an accidental public import.' %
(self._root_name, path))

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +22,8 @@ from __future__ import print_function
import enum
import sys
import six
from tensorflow.python.util import tf_inspect
__all__ = ['traverse']
@ -59,7 +62,8 @@ def _traverse_internal(root, visit, stack, path):
if any(child is item for item in new_stack): # `in`, but using `is`
continue
child_path = path + '.' + name if path else name
child_path = six.ensure_str(path) + '.' + six.ensure_str(
name) if path else name
_traverse_internal(child, visit, new_stack, child_path)

View File

@ -14,6 +14,7 @@ py_library(
name = "ipynb",
srcs = ["ipynb.py"],
srcs_version = "PY2AND3",
deps = ["@six_archive//:six"],
)
py_library(
@ -29,7 +30,7 @@ py_library(
py_test(
name = "ast_edits_test",
srcs = ["ast_edits_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":ast_edits",
@ -42,22 +43,28 @@ py_test(
py_binary(
name = "tf_upgrade",
srcs = ["tf_upgrade.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [":tf_upgrade_lib"],
deps = [
":tf_upgrade_lib",
"@six_archive//:six",
],
)
py_library(
name = "tf_upgrade_lib",
srcs = ["tf_upgrade.py"],
srcs_version = "PY2AND3",
deps = [":ast_edits"],
deps = [
":ast_edits",
"@six_archive//:six",
],
)
py_test(
name = "tf_upgrade_test",
srcs = ["tf_upgrade_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_pip",
@ -96,6 +103,7 @@ py_library(
py_test(
name = "all_renames_v2_test",
srcs = ["all_renames_v2_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":all_renames_v2",
@ -108,6 +116,7 @@ py_test(
py_library(
name = "module_deprecations_v2",
srcs = ["module_deprecations_v2.py"],
srcs_version = "PY2AND3",
deps = [":ast_edits"],
)
@ -145,13 +154,14 @@ py_binary(
":ipynb",
":tf_upgrade_v2_lib",
":tf_upgrade_v2_safety_lib",
"@six_archive//:six",
],
)
py_test(
name = "tf_upgrade_v2_test",
srcs = ["tf_upgrade_v2_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = ["v1only"],
deps = [
@ -169,6 +179,7 @@ py_test(
py_test(
name = "tf_upgrade_v2_safety_test",
srcs = ["tf_upgrade_v2_safety_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":tf_upgrade_v2_safety_lib",
@ -208,7 +219,7 @@ py_test(
name = "test_file_v1_0",
size = "small",
srcs = ["test_file_v1_0.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
@ -235,7 +246,7 @@ py_test(
name = "test_file_v1_12",
size = "small",
srcs = ["testdata/test_file_v1_12.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
tags = ["v1only"],
deps = [
@ -247,7 +258,7 @@ py_test(
name = "test_file_v2_0",
size = "small",
srcs = ["test_file_v2_0.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -29,6 +30,7 @@ import traceback
import pasta
import six
from six.moves import range
# Some regular expressions we will need for parsing
FIND_OPEN = re.compile(r"^\s*(\[).*$")
@ -56,7 +58,7 @@ def full_name_node(name, ctx=ast.Load()):
Returns:
A Name or Attribute node.
"""
names = name.split(".")
names = six.ensure_str(name).split(".")
names.reverse()
node = ast.Name(id=names.pop(), ctx=ast.Load())
while names:
@ -301,7 +303,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
function_transformers = getattr(self._api_change_spec,
transformer_field, {})
glob_name = "*." + name if name else None
glob_name = "*." + six.ensure_str(name) if name else None
transformers = []
if full_name in function_transformers:
transformers.append(function_transformers[full_name])
@ -318,7 +320,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
function_transformers = getattr(self._api_change_spec,
transformer_field, {})
glob_name = "*." + name if name else None
glob_name = "*." + six.ensure_str(name) if name else None
transformers = function_transformers.get("*", {}).copy()
transformers.update(function_transformers.get(glob_name, {}))
transformers.update(function_transformers.get(full_name, {}))
@ -351,7 +353,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
function_warnings = self._api_change_spec.function_warnings
if full_name in function_warnings:
level, message = function_warnings[full_name]
message = message.replace("<function name>", full_name)
message = six.ensure_str(message).replace("<function name>", full_name)
self.add_log(level, node.lineno, node.col_offset,
"%s requires manual check. %s" % (full_name, message))
return True
@ -363,7 +365,8 @@ class _PastaEditVisitor(ast.NodeVisitor):
warnings = self._api_change_spec.module_deprecations
if full_name in warnings:
level, message = warnings[full_name]
message = message.replace("<function name>", whole_name)
message = six.ensure_str(message).replace("<function name>",
six.ensure_str(whole_name))
self.add_log(level, node.lineno, node.col_offset,
"Using member %s in deprecated module %s. %s" % (whole_name,
full_name,
@ -394,7 +397,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
# an attribute.
warned = False
if isinstance(node.func, ast.Attribute):
warned = self._maybe_add_warning(node, "*." + name)
warned = self._maybe_add_warning(node, "*." + six.ensure_str(name))
# All arg warnings are handled here, since only we have the args
arg_warnings = self._get_applicable_dict("function_arg_warnings",
@ -406,7 +409,8 @@ class _PastaEditVisitor(ast.NodeVisitor):
present, _ = get_arg_value(node, kwarg, arg) or variadic_args
if present:
warned = True
warning_message = warning.replace("<function name>", full_name or name)
warning_message = six.ensure_str(warning).replace(
"<function name>", six.ensure_str(full_name or name))
template = "%s called with %s argument, requires manual check: %s"
if variadic_args:
template = ("%s called with *args or **kwargs that may include %s, "
@ -625,13 +629,13 @@ class _PastaEditVisitor(ast.NodeVisitor):
# This loop processes imports in the format
# import foo as f, bar as b
for import_alias in node.names:
all_import_components = import_alias.name.split(".")
all_import_components = six.ensure_str(import_alias.name).split(".")
# Look for rename, starting with longest import levels.
found_update = False
for i in reversed(range(1, max_submodule_depth + 1)):
for i in reversed(list(range(1, max_submodule_depth + 1))):
import_component = all_import_components[0]
for j in range(1, min(i, len(all_import_components))):
import_component += "." + all_import_components[j]
import_component += "." + six.ensure_str(all_import_components[j])
import_rename_spec = import_renames.get(import_component, None)
if not import_rename_spec or excluded_from_module_rename(
@ -674,7 +678,8 @@ class _PastaEditVisitor(ast.NodeVisitor):
if old_suffix is None:
old_suffix = os.linesep
if os.linesep not in old_suffix:
pasta.base.formatting.set(node, "suffix", old_suffix + os.linesep)
pasta.base.formatting.set(node, "suffix",
six.ensure_str(old_suffix) + os.linesep)
# Apply indentation to new node.
pasta.base.formatting.set(new_line_node, "prefix",
@ -720,7 +725,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
# Look for rename based on first component of from-import.
# i.e. based on foo in foo.bar.
from_import_first_component = from_import.split(".")[0]
from_import_first_component = six.ensure_str(from_import).split(".")[0]
import_renames = getattr(self._api_change_spec, "import_renames", {})
import_rename_spec = import_renames.get(from_import_first_component, None)
if not import_rename_spec:
@ -918,7 +923,7 @@ class ASTCodeUpgrader(object):
def format_log(self, log, in_filename):
log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3])
if in_filename:
return in_filename + ":" + log_string
return six.ensure_str(in_filename) + ":" + log_string
else:
return log_string
@ -945,12 +950,12 @@ class ASTCodeUpgrader(object):
return 1, pasta.dump(t), logs, errors
def _format_log(self, log, in_filename, out_filename):
text = "-" * 80 + "\n"
text = six.ensure_str("-" * 80) + "\n"
text += "Processing file %r\n outputting to %r\n" % (in_filename,
out_filename)
text += "-" * 80 + "\n\n"
text += six.ensure_str("-" * 80) + "\n\n"
text += "\n".join(log) + "\n"
text += "-" * 80 + "\n\n"
text += six.ensure_str("-" * 80) + "\n\n"
return text
def process_opened_file(self, in_filename, in_file, out_filename, out_file):
@ -1017,8 +1022,10 @@ class ASTCodeUpgrader(object):
files_to_process = []
files_to_copy = []
for dir_name, _, file_list in os.walk(root_directory):
py_files = [f for f in file_list if f.endswith(".py")]
copy_files = [f for f in file_list if not f.endswith(".py")]
py_files = [f for f in file_list if six.ensure_str(f).endswith(".py")]
copy_files = [
f for f in file_list if not six.ensure_str(f).endswith(".py")
]
for filename in py_files:
fullpath = os.path.join(dir_name, filename)
fullpath_output = os.path.join(output_root_directory,
@ -1036,9 +1043,9 @@ class ASTCodeUpgrader(object):
file_count = 0
tree_errors = {}
report = ""
report += ("=" * 80) + "\n"
report += six.ensure_str(("=" * 80)) + "\n"
report += "Input tree: %r\n" % root_directory
report += ("=" * 80) + "\n"
report += six.ensure_str(("=" * 80)) + "\n"
for input_path, output_path in files_to_process:
output_directory = os.path.dirname(output_path)
@ -1074,16 +1081,19 @@ class ASTCodeUpgrader(object):
"""Process a directory of python files in place."""
files_to_process = []
for dir_name, _, file_list in os.walk(root_directory):
py_files = [os.path.join(dir_name,
f) for f in file_list if f.endswith(".py")]
py_files = [
os.path.join(dir_name, f)
for f in file_list
if six.ensure_str(f).endswith(".py")
]
files_to_process += py_files
file_count = 0
tree_errors = {}
report = ""
report += ("=" * 80) + "\n"
report += six.ensure_str(("=" * 80)) + "\n"
report += "Input tree: %r\n" % root_directory
report += ("=" * 80) + "\n"
report += six.ensure_str(("=" * 80)) + "\n"
for path in files_to_process:
if os.path.islink(path):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -24,6 +25,7 @@ import json
import re
import shutil
import tempfile
import six
CodeLine = collections.namedtuple("CodeLine", ["cell_number", "code"])
@ -31,7 +33,8 @@ def is_python(cell):
"""Checks if the cell consists of Python code."""
return (cell["cell_type"] == "code" # code cells only
and cell["source"] # non-empty cells
and not cell["source"][0].startswith("%%")) # multiline eg: %%bash
and not six.ensure_str(cell["source"][0]).startswith("%%")
) # multiline eg: %%bash
def process_file(in_filename, out_filename, upgrader):
@ -47,8 +50,9 @@ def process_file(in_filename, out_filename, upgrader):
upgrader.update_string_pasta("\n".join(raw_lines), in_filename))
if temp_file and processed_file:
new_notebook = _update_notebook(notebook, raw_code,
new_file_content.split("\n"))
new_notebook = _update_notebook(
notebook, raw_code,
six.ensure_str(new_file_content).split("\n"))
json.dump(new_notebook, temp_file)
else:
raise SyntaxError(
@ -78,7 +82,7 @@ def skip_magic(code_line, magic_list):
"""
for magic in magic_list:
if code_line.startswith(magic):
if six.ensure_str(code_line).startswith(magic):
return True
return False
@ -120,7 +124,7 @@ def _get_code(input_file):
# Idea is to comment these lines, for upgrade time
if skip_magic(code_line, ["%", "!", "?"]) or is_line_split:
# Found a special character, need to "encode"
code_line = "###!!!" + code_line
code_line = "###!!!" + six.ensure_str(code_line)
# if this cell ends with `\` -> skip the next line
is_line_split = check_line_split(code_line)
@ -131,14 +135,16 @@ def _get_code(input_file):
# Sometimes, people leave \n at the end of cell
# in order to migrate only related things, and make the diff
# the smallest -> here is another hack
if (line_idx == len(cell_lines) - 1) and code_line.endswith("\n"):
code_line = code_line.replace("\n", "###===")
if (line_idx == len(cell_lines) -
1) and six.ensure_str(code_line).endswith("\n"):
code_line = six.ensure_str(code_line).replace("\n", "###===")
# sometimes a line would start with `\n` and content after
# that's the hack for this
raw_code.append(
CodeLine(cell_index,
code_line.rstrip().replace("\n", "###===")))
six.ensure_str(code_line.rstrip()).replace("\n",
"###===")))
cell_index += 1

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,8 @@ from __future__ import print_function
import argparse
import six
from tensorflow.tools.compatibility import ast_edits
@ -245,7 +248,7 @@ Simple usage:
else:
parser.print_help()
if report_text:
open(report_filename, "w").write(report_text)
open(report_filename, "w").write(six.ensure_str(report_text))
print("TensorFlow 1.0 Upgrade Script")
print("-----------------------------")
print("Converted %d files\n" % files_processed)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -46,14 +47,16 @@ class TestUpgrade(test_util.TensorFlowTestCase):
def testParseError(self):
_, report, unused_errors, unused_new_text = self._upgrade(
"import tensorflow as tf\na + \n")
self.assertTrue(report.find("Failed to parse") != -1)
self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1)
def testReport(self):
text = "tf.mul(a, b)\n"
_, report, unused_errors, unused_new_text = self._upgrade(text)
# This is not a complete test, but it is a sanity test that a report
# is generating information.
self.assertTrue(report.find("Renamed function `tf.mul` to `tf.multiply`"))
self.assertTrue(
six.ensure_str(report).find(
"Renamed function `tf.mul` to `tf.multiply`"))
def testRename(self):
text = "tf.mul(a, tf.sub(b, c))\n"

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -24,6 +25,7 @@ import functools
import sys
import pasta
import six
from tensorflow.tools.compatibility import all_renames_v2
from tensorflow.tools.compatibility import ast_edits
@ -47,8 +49,9 @@ class VersionedTFImport(ast_edits.AnalysisResult):
def __init__(self, version):
self.log_level = ast_edits.INFO
self.log_message = ("Not upgrading symbols because `tensorflow." + version
+ "` was directly imported as `tf`.")
self.log_message = ("Not upgrading symbols because `tensorflow." +
six.ensure_str(version) +
"` was directly imported as `tf`.")
class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
@ -1687,7 +1690,7 @@ def _rename_if_arg_found_transformer(parent, node, full_name, name, logs,
# All conditions met, insert v1 and log what we did.
# We must have a full name, so the func is an attribute.
new_name = full_name.replace("tf.", "tf.compat.v1.", 1)
new_name = six.ensure_str(full_name).replace("tf.", "tf.compat.v1.", 1)
node.func = ast_edits.full_name_node(new_name)
logs.append((
ast_edits.INFO, node.lineno, node.col_offset,
@ -1715,8 +1718,8 @@ def _iterator_transformer(parent, node, full_name, name, logs):
# (tf.compat.v1.data), or something which is handled in the rename
# (tf.data). This transformer only handles the method call to function call
# conversion.
if full_name and (full_name.startswith("tf.compat.v1.data") or
full_name.startswith("tf.data")):
if full_name and (six.ensure_str(full_name).startswith("tf.compat.v1.data") or
six.ensure_str(full_name).startswith("tf.data")):
return
# This should never happen, since we're only called for Attribute nodes.
@ -2460,7 +2463,7 @@ def _name_scope_transformer(parent, node, full_name, name, logs):
def _rename_to_compat_v1(node, full_name, logs, reason):
new_name = full_name.replace("tf.", "tf.compat.v1.", 1)
new_name = six.ensure_str(full_name).replace("tf.", "tf.compat.v1.", 1)
return _rename_func(node, full_name, new_name, logs, reason)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,10 +21,13 @@ from __future__ import print_function
import argparse
import six
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import ipynb
from tensorflow.tools.compatibility import tf_upgrade_v2
from tensorflow.tools.compatibility import tf_upgrade_v2_safety
# Make straightforward changes to convert to 2.0. In harder cases,
# use compat.v1.
_DEFAULT_MODE = "DEFAULT"
@ -35,10 +39,10 @@ _SAFETY_MODE = "SAFETY"
def process_file(in_filename, out_filename, upgrader):
"""Process a file of type `.py` or `.ipynb`."""
if in_filename.endswith(".py"):
if six.ensure_str(in_filename).endswith(".py"):
files_processed, report_text, errors = \
upgrader.process_file(in_filename, out_filename)
elif in_filename.endswith(".ipynb"):
elif six.ensure_str(in_filename).endswith(".ipynb"):
files_processed, report_text, errors = \
ipynb.process_file(in_filename, out_filename, upgrader)
else:
@ -157,24 +161,24 @@ Simple usage:
for f in errors:
if errors[f]:
num_errors += len(errors[f])
report.append("-" * 80 + "\n")
report.append(six.ensure_str("-" * 80) + "\n")
report.append("File: %s\n" % f)
report.append("-" * 80 + "\n")
report.append(six.ensure_str("-" * 80) + "\n")
report.append("\n".join(errors[f]) + "\n")
report = ("TensorFlow 2.0 Upgrade Script\n"
"-----------------------------\n"
"Converted %d files\n" % files_processed +
"Detected %d issues that require attention" % num_errors + "\n" +
"-" * 80 + "\n") + "".join(report)
detailed_report_header = "=" * 80 + "\n"
six.ensure_str("-" * 80) + "\n") + "".join(report)
detailed_report_header = six.ensure_str("=" * 80) + "\n"
detailed_report_header += "Detailed log follows:\n\n"
detailed_report_header += "=" * 80 + "\n"
detailed_report_header += six.ensure_str("=" * 80) + "\n"
with open(report_filename, "w") as report_file:
report_file.write(report)
report_file.write(detailed_report_header)
report_file.write(report_text)
report_file.write(six.ensure_str(report_text))
if args.print_all:
print(report)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -39,7 +40,7 @@ from tensorflow.tools.compatibility import tf_upgrade_v2
def get_symbol_for_name(root, name):
name_parts = name.split(".")
name_parts = six.ensure_str(name).split(".")
symbol = root
# Iterate starting with second item since 1st item is "tf.".
for part in name_parts[1:]:
@ -66,12 +67,13 @@ def get_func_and_args_from_str(call_str):
Returns:
(function_name, list of arg names) tuple.
"""
open_paren_index = call_str.find("(")
open_paren_index = six.ensure_str(call_str).find("(")
close_paren_index = call_str.rfind(")")
function_name = call_str[:call_str.find("(")]
args = call_str[open_paren_index+1:close_paren_index].split(",")
args = [arg.split("=")[0].strip() for arg in args]
function_name = call_str[:six.ensure_str(call_str).find("(")]
args = six.ensure_str(call_str[open_paren_index +
1:close_paren_index]).split(",")
args = [six.ensure_str(arg).split("=")[0].strip() for arg in args]
args = [arg for arg in args if arg] # filter out empty strings
return function_name, args
@ -96,7 +98,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
_, attr = tf_decorator.unwrap(child[1])
api_names_v2 = tf_export.get_v2_names(attr)
for name in api_names_v2:
cls.v2_symbols["tf." + name] = attr
cls.v2_symbols["tf." + six.ensure_str(name)] = attr
visitor = public_api.PublicAPIVisitor(symbol_collector)
visitor.private_map["tf.compat"] = ["v1"]
@ -109,7 +111,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
_, attr = tf_decorator.unwrap(child[1])
api_names_v1 = tf_export.get_v1_names(attr)
for name in api_names_v1:
cls.v1_symbols["tf." + name] = attr
cls.v1_symbols["tf." + six.ensure_str(name)] = attr
visitor = public_api.PublicAPIVisitor(symbol_collector_v1)
traverse.traverse(tf.compat.v1, visitor)
@ -138,15 +140,16 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
def testParseError(self):
_, report, unused_errors, unused_new_text = self._upgrade(
"import tensorflow as tf\na + \n")
self.assertTrue(report.find("Failed to parse") != -1)
self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1)
def testReport(self):
text = "tf.angle(a)\n"
_, report, unused_errors, unused_new_text = self._upgrade(text)
# This is not a complete test, but it is a sanity test that a report
# is generating information.
self.assertTrue(report.find("Renamed function `tf.angle` to "
"`tf.math.angle`"))
self.assertTrue(
six.ensure_str(report).find("Renamed function `tf.angle` to "
"`tf.math.angle`"))
def testRename(self):
text = "tf.conj(a)\n"
@ -169,7 +172,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
_, attr = tf_decorator.unwrap(child[1])
api_names = tf_export.get_v1_names(attr)
for name in api_names:
_, _, _, text = self._upgrade("tf." + name)
_, _, _, text = self._upgrade("tf." + six.ensure_str(name))
if (text and
not text.startswith("tf.compat.v1") and
not text.startswith("tf.compat.v2") and
@ -198,9 +201,9 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
api_names = tf_export.get_v1_names(attr)
for name in api_names:
if collect:
v1_symbols.add("tf." + name)
v1_symbols.add("tf." + six.ensure_str(name))
else:
_, _, _, text = self._upgrade("tf." + name)
_, _, _, text = self._upgrade("tf." + six.ensure_str(name))
if (text and
not text.startswith("tf.compat.v1") and
not text.startswith("tf.compat.v2") and
@ -337,16 +340,16 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
def testPositionsMatchArgGiven(self):
full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings
method_names = full_dict.keys()
method_names = list(full_dict.keys())
for method_name in method_names:
args = full_dict[method_name].keys()
args = list(full_dict[method_name].keys())
if "contrib" in method_name:
# Skip descending and fetching contrib methods during test. These are
# not available in the repo anymore.
continue
elif method_name.startswith("*."):
elif six.ensure_str(method_name).startswith("*."):
# special case for optimizer methods
method = method_name.replace("*", "tf.train.Optimizer")
method = six.ensure_str(method_name).replace("*", "tf.train.Optimizer")
else:
method = method_name
@ -354,7 +357,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
arg_spec = tf_inspect.getfullargspec(method)
for (arg, pos) in args:
# to deal with the self argument on methods on objects
if method_name.startswith("*."):
if six.ensure_str(method_name).startswith("*."):
pos += 1
self.assertEqual(arg_spec[0][pos], arg)

View File

@ -6,7 +6,7 @@ package(
py_binary(
name = "generate_v2_renames_map",
srcs = ["generate_v2_renames_map.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
@ -15,13 +15,14 @@ py_binary(
"//tensorflow/tools/common:public_api",
"//tensorflow/tools/common:traverse",
"//tensorflow/tools/compatibility:all_renames_v2",
"@six_archive//:six",
],
)
py_binary(
name = "generate_v2_reorders_map",
srcs = ["generate_v2_reorders_map.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -23,9 +24,9 @@ To update renames_v2.py, run:
# pylint: enable=line-too-long
import sys
import six
import tensorflow as tf
# This import is needed so that TensorFlow python modules are in sys.modules.
from tensorflow import python as tf_python # pylint: disable=unused-import
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import app
@ -35,6 +36,7 @@ from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
from tensorflow.tools.compatibility import all_renames_v2
# This import is needed so that TensorFlow python modules are in sys.modules.
_OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/renames_v2.py'
_FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
@ -178,7 +180,8 @@ def update_renames_v2(output_file_path):
rename_lines = [
get_rename_line(name, canonical_name)
for name, canonical_name in all_renames
if 'tf.' + name not in manual_renames]
if 'tf.' + six.ensure_str(name) not in manual_renames
]
renames_file_text = '%srenames = {\n%s\n}\n' % (
_FILE_HEADER, ',\n'.join(sorted(rename_lines)))
file_io.write_string_to_file(output_file_path, renames_file_text)

View File

@ -31,6 +31,7 @@ py_library(
"doc_generator_visitor.py",
],
srcs_version = "PY2AND3",
deps = ["@six_archive//:six"],
)
py_test(
@ -39,12 +40,13 @@ py_test(
srcs = [
"doc_generator_visitor_test.py",
],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":doc_generator_visitor",
":generate_lib",
"//tensorflow/python:platform_test",
"@six_archive//:six",
],
)
@ -59,7 +61,7 @@ py_test(
name = "doc_controls_test",
size = "small",
srcs = ["doc_controls_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":doc_controls",
@ -77,6 +79,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:util",
"@astor_archive//:astor",
"@six_archive//:six",
],
)
@ -84,11 +87,12 @@ py_test(
name = "parser_test",
size = "small",
srcs = ["parser_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":parser",
"//tensorflow/python:platform_test",
"@six_archive//:six",
],
)
@ -96,6 +100,7 @@ py_library(
name = "pretty_docs",
srcs = ["pretty_docs.py"],
srcs_version = "PY2AND3",
deps = ["@six_archive//:six"],
)
py_library(
@ -120,7 +125,7 @@ py_test(
name = "generate_lib_test",
size = "small",
srcs = ["generate_lib_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":generate_lib",
@ -132,7 +137,7 @@ py_test(
py_binary(
name = "generate",
srcs = ["generate.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":generate_lib",
@ -165,9 +170,12 @@ py_test(
py_binary(
name = "generate2",
srcs = ["generate2.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [":generate2_lib"],
deps = [
":generate2_lib",
"@six_archive//:six",
],
)
py_library(
@ -184,16 +192,18 @@ py_library(
name = "py_guide_parser",
srcs = ["py_guide_parser.py"],
srcs_version = "PY2AND3",
deps = ["@six_archive//:six"],
)
py_test(
name = "py_guide_parser_test",
size = "small",
srcs = ["py_guide_parser_test.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":py_guide_parser",
"//tensorflow/python:client_testlib",
"@six_archive//:six",
],
)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -48,7 +49,7 @@ class DocGeneratorVisitor(object):
def set_root_name(self, root_name):
"""Sets the root name for subsequent __call__s."""
self._root_name = root_name or ''
self._prefix = (root_name + '.') if root_name else ''
self._prefix = (six.ensure_str(root_name) + '.') if root_name else ''
@property
def index(self):
@ -178,7 +179,7 @@ class DocGeneratorVisitor(object):
A tuple of scores. When sorted the preferred name will have the lowest
value.
"""
parts = name.split('.')
parts = six.ensure_str(name).split('.')
short_name = parts[-1]
container = self._index['.'.join(parts[:-1])]

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,8 @@ from __future__ import print_function
import types
import six
from tensorflow.python.platform import googletest
from tensorflow.tools.docs import doc_generator_visitor
from tensorflow.tools.docs import generate_lib
@ -29,9 +32,9 @@ class NoDunderVisitor(doc_generator_visitor.DocGeneratorVisitor):
def __call__(self, parent_name, parent, children):
"""Drop all the dunder methods to make testing easier."""
children = [
(name, obj) for (name, obj) in children if not name.startswith('_')
]
children = [(name, obj)
for (name, obj) in children
if not six.ensure_str(name).startswith('_')]
super(NoDunderVisitor, self).__call__(parent_name, parent, children)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -147,8 +148,10 @@ def write_docs(output_dir,
duplicates = [item for item in duplicates if item != full_name]
for dup in duplicates:
from_path = os.path.join(site_api_path, dup.replace('.', '/'))
to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
from_path = os.path.join(site_api_path,
six.ensure_str(dup).replace('.', '/'))
to_path = os.path.join(site_api_path,
six.ensure_str(full_name).replace('.', '/'))
redirects.append((
os.path.join('/', from_path),
os.path.join('/', to_path)))
@ -167,7 +170,7 @@ def write_docs(output_dir,
# Generate table of contents
# Put modules in alphabetical order, case-insensitive
modules = sorted(module_children.keys(), key=lambda a: a.upper())
modules = sorted(list(module_children.keys()), key=lambda a: a.upper())
leftnav_path = os.path.join(output_dir, '_toc.yaml')
with open(leftnav_path, 'w') as f:
@ -183,16 +186,15 @@ def write_docs(output_dir,
if indent_num > 1:
# tf.contrib.baysflow.entropy will be under
# tf.contrib->baysflow->entropy
title = module.split('.')[-1]
title = six.ensure_str(module).split('.')[-1]
else:
title = module
header = [
'- title: ' + title,
' section:',
' - title: Overview',
' path: ' + os.path.join('/', site_api_path,
symbol_to_file[module])]
'- title: ' + six.ensure_str(title), ' section:',
' - title: Overview', ' path: ' +
os.path.join('/', site_api_path, symbol_to_file[module])
]
header = ''.join([indent+line+'\n' for line in header])
f.write(header)
@ -211,8 +213,9 @@ def write_docs(output_dir,
# Write a global index containing all full names with links.
with open(os.path.join(output_dir, 'index.md'), 'w') as f:
f.write(
parser.generate_global_index(root_title, parser_config.index,
parser_config.reference_resolver))
six.ensure_str(
parser.generate_global_index(root_title, parser_config.index,
parser_config.reference_resolver)))
def add_dict_to_dict(add_from, add_to):
@ -345,7 +348,7 @@ def build_doc_index(src_dir):
for dirpath, _, filenames in os.walk(src_dir):
suffix = os.path.relpath(path=dirpath, start=src_dir)
for base_name in filenames:
if not base_name.endswith('.md'):
if not six.ensure_str(base_name).endswith('.md'):
continue
title_parser = _GetMarkdownTitle()
title_parser.process(os.path.join(dirpath, base_name))
@ -353,7 +356,8 @@ def build_doc_index(src_dir):
msg = ('`{}` has no markdown title (# title)'.format(
os.path.join(dirpath, base_name)))
raise ValueError(msg)
key_parts = os.path.join(suffix, base_name[:-3]).split('/')
key_parts = six.ensure_str(os.path.join(suffix,
base_name[:-3])).split('/')
if key_parts[-1] == 'index':
key_parts = key_parts[:-1]
doc_info = _DocInfo(os.path.join(suffix, base_name), title_parser.title)
@ -367,8 +371,8 @@ def build_doc_index(src_dir):
class _GuideRef(object):
def __init__(self, base_name, title, section_title, section_tag):
self.url = 'api_guides/python/' + (('%s#%s' % (base_name, section_tag))
if section_tag else base_name)
self.url = 'api_guides/python/' + six.ensure_str(
(('%s#%s' % (base_name, section_tag)) if section_tag else base_name))
self.link_text = (('%s > %s' % (title, section_title))
if section_title else title)
@ -447,7 +451,7 @@ def update_id_tags_inplace(src_dir):
# modified file contents
content = tag_updater.process(full_path)
with open(full_path, 'w') as f:
f.write(content)
f.write(six.ensure_str(content))
EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt'])
@ -512,7 +516,7 @@ def replace_refs(src_dir,
content = reference_resolver.replace_references(content,
relative_path_to_root)
with open(full_out_path, 'wb') as f:
f.write(content.encode('utf-8'))
f.write(six.ensure_binary(content, 'utf-8'))
class DocGenerator(object):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -28,6 +29,7 @@ import re
import astor
import six
from six.moves import zip
from google.protobuf.message import Message as ProtoMessage
from tensorflow.python.platform import tf_logging as logging
@ -50,7 +52,7 @@ def is_free_function(py_object, full_name, index):
if not tf_inspect.isfunction(py_object):
return False
parent_name = full_name.rsplit('.', 1)[0]
parent_name = six.ensure_str(full_name).rsplit('.', 1)[0]
if tf_inspect.isclass(index[parent_name]):
return False
@ -112,14 +114,14 @@ def documentation_path(full_name, is_fragment=False):
Returns:
The file path to which to write the documentation for `full_name`.
"""
parts = full_name.split('.')
parts = six.ensure_str(full_name).split('.')
if is_fragment:
parts, fragment = parts[:-1], parts[-1]
result = os.path.join(*parts) + '.md'
result = six.ensure_str(os.path.join(*parts)) + '.md'
if is_fragment:
result = result + '#' + fragment
result = six.ensure_str(result) + '#' + six.ensure_str(fragment)
return result
@ -288,7 +290,7 @@ class ReferenceResolver(object):
self.add_error(e.message)
return 'BAD_LINK'
string = re.sub(SYMBOL_REFERENCE_RE, strict_one_ref, string)
string = re.sub(SYMBOL_REFERENCE_RE, strict_one_ref, six.ensure_str(string))
def sloppy_one_ref(match):
try:
@ -333,7 +335,7 @@ class ReferenceResolver(object):
@staticmethod
def _link_text_to_html(link_text):
code_re = '`(.*?)`'
return re.sub(code_re, r'<code>\1</code>', link_text)
return re.sub(code_re, r'<code>\1</code>', six.ensure_str(link_text))
def py_master_name(self, full_name):
"""Return the master name for a Python symbol name."""
@ -389,11 +391,11 @@ class ReferenceResolver(object):
manual_link_text = False
# Handle different types of references.
if string.startswith('$'): # Doc reference
if six.ensure_str(string).startswith('$'): # Doc reference
return self._doc_link(string, link_text, manual_link_text,
relative_path_to_root)
elif string.startswith('tensorflow::'):
elif six.ensure_str(string).startswith('tensorflow::'):
# C++ symbol
return self._cc_link(string, link_text, manual_link_text,
relative_path_to_root)
@ -401,7 +403,8 @@ class ReferenceResolver(object):
else:
is_python = False
for py_module_name in self._py_module_names:
if string == py_module_name or string.startswith(py_module_name + '.'):
if string == py_module_name or string.startswith(
six.ensure_str(py_module_name) + '.'):
is_python = True
break
if is_python: # Python symbol
@ -421,7 +424,7 @@ class ReferenceResolver(object):
string = string[1:] # remove leading $
# If string has a #, split that part into `hash_tag`
hash_pos = string.find('#')
hash_pos = six.ensure_str(string).find('#')
if hash_pos > -1:
hash_tag = string[hash_pos:]
string = string[:hash_pos]
@ -520,10 +523,10 @@ class _FunctionDetail(
def __str__(self):
"""Return the original string that represents the function detail."""
parts = [self.keyword + ':\n']
parts = [six.ensure_str(self.keyword) + ':\n']
parts.append(self.header)
for key, value in self.items:
parts.append(' ' + key + ': ')
parts.append(' ' + six.ensure_str(key) + ': ')
parts.append(value)
return ''.join(parts)
@ -587,7 +590,7 @@ def _parse_function_details(docstring):
item_re = re.compile(r'^ ? ?(\*?\*?\w[\w.]*?\s*):\s', re.MULTILINE)
for keyword, content in pairs:
content = item_re.split(content)
content = item_re.split(six.ensure_str(content))
header = content[0]
items = list(_gen_pairs(content[1:]))
@ -634,7 +637,8 @@ def _parse_md_docstring(py_object, relative_path_to_root, reference_resolver):
atat_re = re.compile(r' *@@[a-zA-Z_.0-9]+ *$')
raw_docstring = '\n'.join(
line for line in raw_docstring.split('\n') if not atat_re.match(line))
line for line in six.ensure_str(raw_docstring).split('\n')
if not atat_re.match(six.ensure_str(line)))
docstring, compatibility = _handle_compatibility(raw_docstring)
docstring, function_details = _parse_function_details(docstring)
@ -698,8 +702,9 @@ def _get_arg_spec(func):
def _remove_first_line_indent(string):
indent = len(re.match(r'^\s*', string).group(0))
return '\n'.join([line[indent:] for line in string.split('\n')])
indent = len(re.match(r'^\s*', six.ensure_str(string)).group(0))
return '\n'.join(
[line[indent:] for line in six.ensure_str(string).split('\n')])
PAREN_NUMBER_RE = re.compile(r'^\(([0-9.e-]+)\)')
@ -761,9 +766,9 @@ def _generate_signature(func, reverse_index):
default_text = reverse_index[id(default)]
elif ast_default is not None:
default_text = (
astor.to_source(ast_default).rstrip('\n').replace('\t', '\\t')
.replace('\n', '\\n').replace('"""', "'"))
default_text = PAREN_NUMBER_RE.sub('\\1', default_text)
six.ensure_str(astor.to_source(ast_default)).rstrip('\n').replace(
'\t', '\\t').replace('\n', '\\n').replace('"""', "'"))
default_text = PAREN_NUMBER_RE.sub('\\1', six.ensure_str(default_text))
if default_text != repr(default):
# This may be an internal name. If so, handle the ones we know about.
@ -797,9 +802,9 @@ def _generate_signature(func, reverse_index):
# Add *args and *kwargs.
if argspec.varargs:
args_list.append('*' + argspec.varargs)
args_list.append('*' + six.ensure_str(argspec.varargs))
if argspec.varkw:
args_list.append('**' + argspec.varkw)
args_list.append('**' + six.ensure_str(argspec.varkw))
return args_list
@ -879,7 +884,7 @@ class _FunctionPageInfo(object):
@property
def short_name(self):
return self._full_name.split('.')[-1]
return six.ensure_str(self._full_name).split('.')[-1]
@property
def defined_in(self):
@ -998,7 +1003,7 @@ class _ClassPageInfo(object):
@property
def short_name(self):
"""Returns the documented object's short name."""
return self._full_name.split('.')[-1]
return six.ensure_str(self._full_name).split('.')[-1]
@property
def defined_in(self):
@ -1091,9 +1096,12 @@ class _ClassPageInfo(object):
base_url = parser_config.reference_resolver.reference_to_url(
base_full_name, relative_path)
link_info = _LinkInfo(short_name=base_full_name.split('.')[-1],
full_name=base_full_name, obj=base,
doc=base_doc, url=base_url)
link_info = _LinkInfo(
short_name=six.ensure_str(base_full_name).split('.')[-1],
full_name=base_full_name,
obj=base,
doc=base_doc,
url=base_url)
bases.append(link_info)
self._bases = bases
@ -1121,7 +1129,7 @@ class _ClassPageInfo(object):
doc: The property's parsed docstring, a `_DocstringInfo`.
"""
# Hide useless namedtuple docs-trings
if re.match('Alias for field number [0-9]+', doc.docstring):
if re.match('Alias for field number [0-9]+', six.ensure_str(doc.docstring)):
doc = doc._replace(docstring='', brief='')
property_info = _PropertyInfo(short_name, full_name, obj, doc)
self._properties.append(property_info)
@ -1255,8 +1263,8 @@ class _ClassPageInfo(object):
# Omit methods defined by namedtuple.
original_method = defining_class.__dict__[short_name]
if (hasattr(original_method, '__module__') and
(original_method.__module__ or '').startswith('namedtuple')):
if (hasattr(original_method, '__module__') and six.ensure_str(
(original_method.__module__ or '')).startswith('namedtuple')):
continue
# Some methods are often overridden without documentation. Because it's
@ -1294,7 +1302,7 @@ class _ClassPageInfo(object):
else:
# Exclude members defined by protobuf that are useless
if issubclass(py_class, ProtoMessage):
if (short_name.endswith('_FIELD_NUMBER') or
if (six.ensure_str(short_name).endswith('_FIELD_NUMBER') or
short_name in ['__slots__', 'DESCRIPTOR']):
continue
@ -1332,7 +1340,7 @@ class _ModulePageInfo(object):
@property
def short_name(self):
return self._full_name.split('.')[-1]
return six.ensure_str(self._full_name).split('.')[-1]
@property
def defined_in(self):
@ -1425,7 +1433,8 @@ class _ModulePageInfo(object):
'__cached__', '__loader__', '__spec__']:
continue
member_full_name = self.full_name + '.' + name if self.full_name else name
member_full_name = six.ensure_str(self.full_name) + '.' + six.ensure_str(
name) if self.full_name else name
member = parser_config.py_name_to_object(member_full_name)
member_doc = _parse_md_docstring(member, relative_path,
@ -1680,20 +1689,21 @@ def _get_defined_in(py_object, parser_config):
# TODO(wicke): And make their source file predictable from the file name.
# In case this is compiled, point to the original
if path.endswith('.pyc'):
if six.ensure_str(path).endswith('.pyc'):
path = path[:-1]
# Never include links outside this code base.
if path.startswith('..') or re.search(r'\b_api\b', path):
if six.ensure_str(path).startswith('..') or re.search(r'\b_api\b',
six.ensure_str(path)):
return None
if re.match(r'.*/gen_[^/]*\.py$', path):
if re.match(r'.*/gen_[^/]*\.py$', six.ensure_str(path)):
return _GeneratedFile(path, parser_config)
if 'genfiles' in path or 'tools/api/generator' in path:
return _GeneratedFile(path, parser_config)
elif re.match(r'.*_pb2\.py$', path):
elif re.match(r'.*_pb2\.py$', six.ensure_str(path)):
# The _pb2.py files all appear right next to their defining .proto file.
return _ProtoFile(path[:-7] + '.proto', parser_config)
return _ProtoFile(six.ensure_str(path[:-7]) + '.proto', parser_config)
else:
return _PythonFile(path, parser_config)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -23,6 +24,8 @@ import functools
import os
import sys
import six
from tensorflow.python.platform import googletest
from tensorflow.python.util import tf_inspect
from tensorflow.tools.docs import doc_controls
@ -180,7 +183,8 @@ class ParserTest(googletest.TestCase):
# Make sure the brief docstring is present
self.assertEqual(
tf_inspect.getdoc(TestClass).split('\n')[0], page_info.doc.brief)
six.ensure_str(tf_inspect.getdoc(TestClass)).split('\n')[0],
page_info.doc.brief)
# Make sure the method is present
self.assertEqual(TestClass.a_method, page_info.methods[0].obj)
@ -236,7 +240,7 @@ class ParserTest(googletest.TestCase):
# 'Alias for field number ##'. These props are returned sorted.
def sort_key(prop_info):
return int(prop_info.obj.__doc__.split(' ')[-1])
return int(six.ensure_str(prop_info.obj.__doc__).split(' ')[-1])
self.assertSequenceEqual(page_info.properties,
sorted(page_info.properties, key=sort_key))
@ -378,7 +382,8 @@ class ParserTest(googletest.TestCase):
# Make sure the brief docstring is present
self.assertEqual(
tf_inspect.getdoc(test_module).split('\n')[0], page_info.doc.brief)
six.ensure_str(tf_inspect.getdoc(test_module)).split('\n')[0],
page_info.doc.brief)
# Make sure that the members are there
funcs = {f_info.obj for f_info in page_info.functions}
@ -422,7 +427,8 @@ class ParserTest(googletest.TestCase):
# Make sure the brief docstring is present
self.assertEqual(
tf_inspect.getdoc(test_function).split('\n')[0], page_info.doc.brief)
six.ensure_str(tf_inspect.getdoc(test_function)).split('\n')[0],
page_info.doc.brief)
# Make sure the extracted signature is good.
self.assertEqual(['unused_arg', "unused_kwarg='default'"],
@ -461,7 +467,8 @@ class ParserTest(googletest.TestCase):
# Make sure the brief docstring is present
self.assertEqual(
tf_inspect.getdoc(test_function_with_args_kwargs).split('\n')[0],
six.ensure_str(
tf_inspect.getdoc(test_function_with_args_kwargs)).split('\n')[0],
page_info.doc.brief)
# Make sure the extracted signature is good.
@ -751,7 +758,8 @@ class TestParseFunctionDetails(googletest.TestCase):
self.assertEqual(
RELU_DOC,
docstring + ''.join(str(detail) for detail in function_details))
six.ensure_str(docstring) +
''.join(str(detail) for detail in function_details))
class TestGenerateSignature(googletest.TestCase):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -28,6 +29,7 @@ from __future__ import division
from __future__ import print_function
import textwrap
import six
def build_md_page(page_info):
@ -83,7 +85,8 @@ def _build_class_page(page_info):
"""Given a ClassPageInfo object Return the page as an md string."""
parts = ['# {page_info.full_name}\n\n'.format(page_info=page_info)]
parts.append('## Class `%s`\n\n' % page_info.full_name.split('.')[-1])
parts.append('## Class `%s`\n\n' %
six.ensure_str(page_info.full_name).split('.')[-1])
if page_info.bases:
parts.append('Inherits From: ')
@ -222,7 +225,7 @@ def _build_module_page(page_info):
parts.append(template.format(**item._asdict()))
if item.doc.brief:
parts.append(': ' + item.doc.brief)
parts.append(': ' + six.ensure_str(item.doc.brief))
parts.append('\n\n')
@ -234,7 +237,7 @@ def _build_module_page(page_info):
parts.append(template.format(**item._asdict()))
if item.doc.brief:
parts.append(': ' + item.doc.brief)
parts.append(': ' + six.ensure_str(item.doc.brief))
parts.append('\n\n')
@ -246,7 +249,7 @@ def _build_module_page(page_info):
parts.append(template.format(**item._asdict()))
if item.doc.brief:
parts.append(': ' + item.doc.brief)
parts.append(': ' + six.ensure_str(item.doc.brief))
parts.append('\n\n')
@ -273,7 +276,7 @@ def _build_signature(obj_info, use_full_name=True):
'```\n\n')
parts = ['``` python']
parts.extend(['@' + dec for dec in obj_info.decorators])
parts.extend(['@' + six.ensure_str(dec) for dec in obj_info.decorators])
signature_template = '{name}({sig})'
if not obj_info.signature:
@ -313,7 +316,7 @@ def _build_function_details(function_details):
parts = []
for detail in function_details:
sub = []
sub.append('#### ' + detail.keyword + ':\n\n')
sub.append('#### ' + six.ensure_str(detail.keyword) + ':\n\n')
sub.append(textwrap.dedent(detail.header))
for key, value in detail.items:
sub.append('* <b>`%s`</b>: %s' % (key, value))

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,14 +22,16 @@ from __future__ import print_function
import os
import re
import six
def md_files_in_dir(py_guide_src_dir):
"""Returns a list of filename (full_path, base) pairs for guide files."""
all_in_dir = [(os.path.join(py_guide_src_dir, f), f)
for f in os.listdir(py_guide_src_dir)]
return [(full, f) for full, f in all_in_dir
if os.path.isfile(full) and f.endswith('.md')]
return [(full, f)
for full, f in all_in_dir
if os.path.isfile(full) and six.ensure_str(f).endswith('.md')]
class PyGuideParser(object):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,8 @@ from __future__ import print_function
import os
import six
from tensorflow.python.platform import test
from tensorflow.tools.docs import py_guide_parser
@ -38,7 +41,7 @@ class TestPyGuideParser(py_guide_parser.PyGuideParser):
def process_in_blockquote(self, line_number, line):
self.calls.append((line_number, 'b', line))
self.replace_line(line_number, line + ' BQ')
self.replace_line(line_number, six.ensure_str(line) + ' BQ')
def process_line(self, line_number, line):
self.calls.append((line_number, 'l', line))

View File

@ -11,6 +11,7 @@ package(
py_binary(
name = "gen_git_source",
srcs = ["gen_git_source.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = ["@six_archive//:six"],
)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -26,13 +27,16 @@ NOTE: this script is only used in opensource.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from builtins import bytes # pylint: disable=redefined-builtin
import argparse
from builtins import bytes # pylint: disable=redefined-builtin
import json
import os
import shutil
import subprocess
import six
def parse_branch_ref(filename):
"""Given a filename of a .git/HEAD file return ref path.
@ -161,10 +165,13 @@ def get_git_version(git_base_path, git_tag_override):
unknown_label = b"unknown"
try:
# Force to bytes so this works on python 2 and python 3
val = bytes(subprocess.check_output([
"git", str("--git-dir=%s/.git" % git_base_path),
str("--work-tree=" + git_base_path), "describe", "--long", "--tags"
]).strip())
val = bytes(
subprocess.check_output([
"git",
str("--git-dir=%s/.git" % git_base_path),
str("--work-tree=" + six.ensure_str(git_base_path)), "describe",
"--long", "--tags"
]).strip())
version_separator = b"-"
if git_tag_override and val:
split_val = val.split(version_separator)

View File

@ -3,7 +3,6 @@
load(
"//tensorflow:tensorflow.bzl",
"if_not_v2",
"if_not_windows",
"tf_cc_binary",
"tf_cc_test",

View File

@ -163,6 +163,7 @@ genrule(
"@com_google_protobuf//:LICENSE",
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
"@six_archive//:LICENSE",
] + select({
"//tensorflow:android": [],
"//tensorflow:ios": [],
@ -235,6 +236,7 @@ genrule(
"@zlib_archive//:zlib.h",
"@grpc//:LICENSE",
"@grpc//third_party/address_sorting:LICENSE",
"@six_archive//:LICENSE",
] + select({
"//tensorflow:android": [],
"//tensorflow:ios": [],

View File

@ -13,9 +13,12 @@ package(
],
)
licenses(["notice"]) # Apache 2.0
py_library(
name = "compat_checker",
srcs = ["compat_checker.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:platform",
"//tensorflow/python:util",
@ -29,6 +32,7 @@ py_test(
data = [
"//tensorflow/tools/tensorflow_builder/compat_checker:test_config",
],
python_version = "PY3",
tags = ["no_pip"],
deps = [
":compat_checker",

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -23,16 +24,11 @@ import re
import sys
import six
from six.moves import range
import six.moves.configparser
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
# pylint: disable=g-import-not-at-top
if six.PY2:
import ConfigParser
else:
import configparser as ConfigParser
# pylint: enable=g-import-not-at-top
PATH_TO_DIR = "tensorflow/tools/tensorflow_builder/compat_checker"
@ -56,8 +52,8 @@ def _compare_versions(v1, v2):
raise RuntimeError("Cannot compare `inf` to `inf`.")
rtn_dict = {"smaller": None, "larger": None}
v1_list = v1.split(".")
v2_list = v2.split(".")
v1_list = six.ensure_str(v1).split(".")
v2_list = six.ensure_str(v2).split(".")
# Take care of cases with infinity (arg=`inf`).
if v1_list[0] == "inf":
v1_list[0] = str(int(v2_list[0]) + 1)
@ -380,7 +376,7 @@ class ConfigCompatChecker(object):
curr_status = True
# Initialize config parser for parsing version requirements file.
parser = ConfigParser.ConfigParser()
parser = six.moves.configparser.ConfigParser()
parser.read(self.req_file)
if not parser.sections():
@ -643,7 +639,7 @@ class ConfigCompatChecker(object):
if filtered[-1] == "]":
filtered = filtered[:-1]
elif "]" in filtered[-1]:
filtered[-1] = filtered[-1].replace("]", "")
filtered[-1] = six.ensure_str(filtered[-1]).replace("]", "")
# If `]` is missing, then it could be a formatting issue with
# config file (.ini.). Add to warning.
else:
@ -792,7 +788,7 @@ class ConfigCompatChecker(object):
Boolean that is a status of the compatibility check result.
"""
# Check if all `Required` configs are found in user configs.
usr_keys = self.usr_config.keys()
usr_keys = list(self.usr_config.keys())
for k in six.iterkeys(self.usr_config):
if k not in usr_keys:
@ -809,10 +805,10 @@ class ConfigCompatChecker(object):
for config_name, spec in six.iteritems(self.usr_config):
temp_status = True
# Check under which section the user config is defined.
in_required = config_name in self.required.keys()
in_optional = config_name in self.optional.keys()
in_unsupported = config_name in self.unsupported.keys()
in_dependency = config_name in self.dependency.keys()
in_required = config_name in list(self.required.keys())
in_optional = config_name in list(self.optional.keys())
in_unsupported = config_name in list(self.unsupported.keys())
in_dependency = config_name in list(self.dependency.keys())
# Add to warning if user config is not specified in the config file.
if not (in_required or in_optional or in_unsupported or in_dependency):

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -88,7 +89,7 @@ class CompatCheckerTest(unittest.TestCase):
# Make sure no warning or error messages are recorded.
self.assertFalse(len(self.compat_checker.error_msg))
# Make sure total # of successes match total # of configs.
cnt = len(USER_CONFIG_IN_RANGE.keys())
cnt = len(list(USER_CONFIG_IN_RANGE.keys()))
self.assertEqual(len(self.compat_checker.successes), cnt)
def testWithUserConfigNotInRange(self):
@ -106,7 +107,7 @@ class CompatCheckerTest(unittest.TestCase):
err_msg_list = self.compat_checker.failures
self.assertTrue(len(err_msg_list))
# Make sure total # of failures match total # of configs.
cnt = len(USER_CONFIG_NOT_IN_RANGE.keys())
cnt = len(list(USER_CONFIG_NOT_IN_RANGE.keys()))
self.assertEqual(len(err_msg_list), cnt)
def testWithUserConfigMissing(self):

View File

@ -14,22 +14,24 @@ py_binary(
data = [
"//tensorflow/tools/tensorflow_builder/config_detector/data/golden:cuda_cc_golden",
],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":cuda_compute_capability",
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@six_archive//:six",
],
)
py_binary(
name = "cuda_compute_capability",
srcs = ["data/cuda_compute_capability.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@six_archive//:six",
],
)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -70,6 +71,7 @@ import subprocess
import sys
from absl import app
from absl import flags
import six
from tensorflow.tools.tensorflow_builder.config_detector.data import cuda_compute_capability
@ -182,7 +184,7 @@ def get_cpu_type():
"""
key = "cpu_type"
out, err = run_shell_cmd(cmds_all[PLATFORM][key])
cpu_detected = out.split(":")[1].strip()
cpu_detected = out.split(b":")[1].strip()
if err and FLAGS.debug:
print("Error in detecting CPU type:\n %s" % str(err))
@ -201,7 +203,7 @@ def get_cpu_arch():
if err and FLAGS.debug:
print("Error in detecting CPU arch:\n %s" % str(err))
return out.strip("\n")
return out.strip(b"\n")
def get_distrib():
@ -216,7 +218,7 @@ def get_distrib():
if err and FLAGS.debug:
print("Error in detecting distribution:\n %s" % str(err))
return out.strip("\n")
return out.strip(b"\n")
def get_distrib_version():
@ -233,7 +235,7 @@ def get_distrib_version():
"Error in detecting distribution version:\n %s" % str(err)
)
return out.strip("\n")
return out.strip(b"\n")
def get_gpu_type():
@ -251,7 +253,7 @@ def get_gpu_type():
key = "gpu_type_no_sudo"
gpu_dict = cuda_compute_capability.retrieve_from_golden()
out, err = run_shell_cmd(cmds_all[PLATFORM][key])
ret_val = out.split(" ")
ret_val = out.split(b" ")
gpu_id = ret_val[0]
if err and FLAGS.debug:
print("Error in detecting GPU type:\n %s" % str(err))
@ -261,10 +263,10 @@ def get_gpu_type():
return gpu_id, GPU_TYPE
else:
if "[" or "]" in ret_val[1]:
gpu_release = ret_val[1].replace("[", "") + " "
gpu_release += ret_val[2].replace("]", "").strip("\n")
gpu_release = ret_val[1].replace(b"[", b"") + b" "
gpu_release += ret_val[2].replace(b"]", b"").strip(b"\n")
else:
gpu_release = ret_val[1].replace("\n", " ")
gpu_release = six.ensure_str(ret_val[1]).replace("\n", " ")
if gpu_release not in gpu_dict:
GPU_TYPE = "unknown"
@ -285,7 +287,7 @@ def get_gpu_count():
if err and FLAGS.debug:
print("Error in detecting GPU count:\n %s" % str(err))
return out.strip("\n")
return out.strip(b"\n")
def get_cuda_version_all():
@ -303,7 +305,7 @@ def get_cuda_version_all():
"""
key = "cuda_ver_all"
out, err = run_shell_cmd(cmds_all[PLATFORM.lower()][key])
ret_val = out.split("\n")
ret_val = out.split(b"\n")
filtered = []
for item in ret_val:
if item not in ["\n", ""]:
@ -311,9 +313,9 @@ def get_cuda_version_all():
all_vers = []
for item in filtered:
ver_re = re.search(r".*/cuda(\-[\d]+\.[\d]+)?", item)
ver_re = re.search(r".*/cuda(\-[\d]+\.[\d]+)?", item.decode("utf-8"))
if ver_re.group(1):
all_vers.append(ver_re.group(1).strip("-"))
all_vers.append(six.ensure_str(ver_re.group(1)).strip("-"))
if err and FLAGS.debug:
print("Error in detecting CUDA version:\n %s" % str(err))
@ -409,13 +411,13 @@ def get_cudnn_version():
if err and FLAGS.debug:
print("Error in finding `cudnn.h`:\n %s" % str(err))
if len(out.split(" ")) > 1:
if len(out.split(b" ")) > 1:
cmd = cmds[0] + " | " + cmds[1]
out_re, err_re = run_shell_cmd(cmd)
if err_re and FLAGS.debug:
print("Error in detecting cuDNN version:\n %s" % str(err_re))
return out_re.strip("\n")
return out_re.strip(b"\n")
else:
return
@ -432,7 +434,7 @@ def get_gcc_version():
if err and FLAGS.debug:
print("Error in detecting GCC version:\n %s" % str(err))
return out.strip("\n")
return out.strip(b"\n")
def get_glibc_version():
@ -447,7 +449,7 @@ def get_glibc_version():
if err and FLAGS.debug:
print("Error in detecting GCC version:\n %s" % str(err))
return out.strip("\n")
return out.strip(b"\n")
def get_libstdcpp_version():
@ -462,7 +464,7 @@ def get_libstdcpp_version():
if err and FLAGS.debug:
print("Error in detecting libstdc++ version:\n %s" % str(err))
ver = out.split("_")[-1].replace("\n", "")
ver = out.split(b"_")[-1].replace(b"\n", b"")
return ver
@ -485,7 +487,7 @@ def get_cpu_isa_version():
found = []
missing = []
for isa in required_isa:
for sys_isa in ret_val.split(" "):
for sys_isa in ret_val.split(b" "):
if isa == sys_isa:
if isa not in found:
found.append(isa)
@ -539,7 +541,7 @@ def get_all_configs():
json_data = {}
missing = []
warning = []
for config, call_func in all_functions.iteritems():
for config, call_func in six.iteritems(all_functions):
ret_val = call_func
if not ret_val:
configs_found.append([config, "\033[91m\033[1mMissing\033[0m"])
@ -557,10 +559,10 @@ def get_all_configs():
configs_found.append([config, ret_val[0]])
json_data[config] = ret_val[0]
else:
configs_found.append(
[config,
"\033[91m\033[1mMissing " + str(ret_val[1])[1:-1] + "\033[0m"]
)
configs_found.append([
config, "\033[91m\033[1mMissing " +
six.ensure_str(str(ret_val[1])[1:-1]) + "\033[0m"
])
missing.append(
[config,
"\n\t=> Found %s but missing %s"
@ -587,7 +589,7 @@ def print_all_configs(configs, missing, warning):
llen = 65 # line length
for i, row in enumerate(configs):
if i != 0:
print_text += "-"*llen + "\n"
print_text += six.ensure_str("-" * llen) + "\n"
if isinstance(row[1], list):
val = ", ".join(row[1])
@ -629,7 +631,7 @@ def save_to_file(json_data, filename):
print("filename: %s" % filename)
filename += ".json"
with open(PATH_TO_DIR + "/" + filename, "w") as f:
with open(PATH_TO_DIR + "/" + six.ensure_str(filename), "w") as f:
json.dump(json_data, f, sort_keys=True, indent=4)
print(" Successfully wrote configs to file `%s`.\n" % (filename))

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -42,6 +43,7 @@ import re
from absl import app
from absl import flags
import six
import six.moves.urllib.request as urllib
FLAGS = flags.FLAGS
@ -61,21 +63,18 @@ def retrieve_from_web(generate_csv=False):
NVIDIA page. Order goes from top to bottom of the webpage content (.html).
"""
url = "https://developer.nvidia.com/cuda-gpus"
source = urllib.urlopen(url)
source = urllib.request.urlopen(url)
matches = []
while True:
line = source.readline()
if "</html>" in line:
break
else:
gpu = re.search(
r"<a href=.*>([\w\S\s\d\[\]\,]+[^*])</a>(<a href=.*)?.*",
line
)
gpu = re.search(r"<a href=.*>([\w\S\s\d\[\]\,]+[^*])</a>(<a href=.*)?.*",
six.ensure_str(line))
capability = re.search(
r"([\d]+).([\d]+)(/)?([\d]+)?(.)?([\d]+)?.*</td>.*",
line
)
six.ensure_str(line))
if gpu:
matches.append(gpu.group(1))
elif capability:
@ -155,15 +154,15 @@ def create_gpu_capa_map(match_list,
gpu = ""
cnt += 1
if len(gpu_capa.keys()) < cnt:
if len(list(gpu_capa.keys())) < cnt:
mismatch_cnt += 1
cnt = len(gpu_capa.keys())
cnt = len(list(gpu_capa.keys()))
else:
gpu = match
if generate_csv:
f_name = filename + ".csv"
f_name = six.ensure_str(filename) + ".csv"
write_csv_from_dict(f_name, gpu_capa)
return gpu_capa
@ -179,8 +178,8 @@ def write_csv_from_dict(filename, input_dict):
filename: String that is the output file name.
input_dict: Dictionary that is to be written out to a `.csv` file.
"""
f = open(PATH_TO_DIR + "/data/" + filename, "w")
for k, v in input_dict.iteritems():
f = open(PATH_TO_DIR + "/data/" + six.ensure_str(filename), "w")
for k, v in six.iteritems(input_dict):
line = k
for item in v:
line += "," + item
@ -203,7 +202,7 @@ def check_with_golden(filename):
Args:
filename: String that is the name of the newly created file.
"""
path_to_file = PATH_TO_DIR + "/data/" + filename
path_to_file = PATH_TO_DIR + "/data/" + six.ensure_str(filename)
if os.path.isfile(path_to_file) and os.path.isfile(CUDA_CC_GOLDEN_DIR):
with open(path_to_file, "r") as f_new:
with open(CUDA_CC_GOLDEN_DIR, "r") as f_golden:

View File

@ -22,17 +22,19 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:errors",
"//tensorflow/python:platform",
"@six_archive//:six",
],
)
py_binary(
name = "system_info",
srcs = ["system_info.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":system_info_lib",
@ -50,16 +52,20 @@ py_library(
":system_info_lib",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
"@six_archive//:six",
],
)
py_binary(
name = "run_and_gather_logs",
srcs = ["run_and_gather_logs.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [":run_and_gather_logs_main_lib"],
deps = [
":run_and_gather_logs_main_lib",
"@six_archive//:six",
],
)
py_library(
@ -72,6 +78,7 @@ py_library(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
"@six_archive//:six",
],
)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +22,9 @@ from __future__ import print_function
import ctypes as ct
import platform
import six
from six.moves import range
from tensorflow.core.util import test_log_pb2
from tensorflow.python.framework import errors
from tensorflow.python.platform import gfile
@ -30,10 +34,11 @@ def _gather_gpu_devices_proc():
"""Try to gather NVidia GPU device information via /proc/driver."""
dev_info = []
for f in gfile.Glob("/proc/driver/nvidia/gpus/*/information"):
bus_id = f.split("/")[5]
key_values = dict(line.rstrip().replace("\t", "").split(":", 1)
for line in gfile.GFile(f, "r"))
key_values = dict((k.lower(), v.strip(" ").rstrip(" "))
bus_id = six.ensure_str(f).split("/")[5]
key_values = dict(
six.ensure_str(line.rstrip()).replace("\t", "").split(":", 1)
for line in gfile.GFile(f, "r"))
key_values = dict((k.lower(), six.ensure_str(v).strip(" ").rstrip(" "))
for (k, v) in key_values.items())
info = test_log_pb2.GPUInfo()
info.model = key_values.get("model", "Unknown")

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -25,9 +26,10 @@ from string import maketrans
import sys
import time
import six
from google.protobuf import json_format
from google.protobuf import text_format
from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
@ -83,8 +85,9 @@ def main(unused_args):
if FLAGS.test_log_output_filename:
file_name = FLAGS.test_log_output_filename
else:
file_name = (name.strip("/").translate(maketrans("/:", "__")) +
time.strftime("%Y%m%d%H%M%S", time.gmtime()))
file_name = (
six.ensure_str(name).strip("/").translate(maketrans("/:", "__")) +
time.strftime("%Y%m%d%H%M%S", time.gmtime()))
if FLAGS.test_log_output_use_tmpdir:
tmpdir = test.get_temp_dir()
output_path = os.path.join(tmpdir, FLAGS.test_log_output_dir, file_name)
@ -92,7 +95,8 @@ def main(unused_args):
output_path = os.path.join(
os.path.abspath(FLAGS.test_log_output_dir), file_name)
json_test_results = json_format.MessageToJson(test_results)
gfile.GFile(output_path + ".json", "w").write(json_test_results)
gfile.GFile(six.ensure_str(output_path) + ".json",
"w").write(json_test_results)
tf_logging.info("Test results written to: %s" % output_path)

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -25,6 +26,8 @@ import subprocess
import tempfile
import time
import six
from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import gfile
from tensorflow.tools.test import gpu_info_lib
@ -118,12 +121,15 @@ def run_and_gather_logs(name, test_name, test_args,
IOError: If there are problems gathering test log output from the test.
MissingLogsError: If we couldn't find benchmark logs.
"""
if not (test_name and test_name.startswith("//") and ".." not in test_name and
not test_name.endswith(":") and not test_name.endswith(":all") and
not test_name.endswith("...") and len(test_name.split(":")) == 2):
if not (test_name and six.ensure_str(test_name).startswith("//") and
".." not in test_name and not six.ensure_str(test_name).endswith(":")
and not six.ensure_str(test_name).endswith(":all") and
not six.ensure_str(test_name).endswith("...") and
len(six.ensure_str(test_name).split(":")) == 2):
raise ValueError("Expected test_name parameter with a unique test, e.g.: "
"--test_name=//path/to:test")
test_executable = test_name.rstrip().strip("/").replace(":", "/")
test_executable = six.ensure_str(test_name.rstrip()).strip("/").replace(
":", "/")
if gfile.Exists(os.path.join("bazel-bin", test_executable)):
# Running in standalone mode from core of the repository
@ -136,14 +142,17 @@ def run_and_gather_logs(name, test_name, test_args,
gpu_config = gpu_info_lib.gather_gpu_devices()
if gpu_config:
gpu_name = gpu_config[0].model
gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", gpu_name)
gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)",
six.ensure_str(gpu_name))
if gpu_short_name_match:
gpu_short_name = gpu_short_name_match.group(0)
test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_")
test_adjusted_name = six.ensure_str(name) + "|" + gpu_short_name.replace(
" ", "_")
temp_directory = tempfile.mkdtemp(prefix="run_and_gather_logs")
mangled_test_name = (test_adjusted_name.strip("/")
.replace("|", "_").replace("/", "_").replace(":", "_"))
mangled_test_name = (
six.ensure_str(test_adjusted_name).strip("/").replace("|", "_").replace(
"/", "_").replace(":", "_"))
test_file_prefix = os.path.join(temp_directory, mangled_test_name)
test_file_prefix = "%s." % test_file_prefix

View File

@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -29,6 +30,8 @@ import socket
# OSS tree. They are installable via pip.
import cpuinfo
import psutil
import six
# pylint: enable=g-bad-import-order
from tensorflow.core.util import test_log_pb2
@ -81,7 +84,8 @@ def gather_cpu_info():
# Gather num_cores_allowed
try:
with gfile.GFile('/proc/self/status', 'rb') as fh:
nc = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', fh.read().decode('utf-8'))
nc = re.search(r'(?m)^Cpus_allowed:\s*(.*)$',
six.ensure_text(fh.read(), 'utf-8'))
if nc: # e.g. 'ff' => 8, 'fff' => 12
cpu_info.num_cores_allowed = (
bin(int(nc.group(1).replace(',', ''), 16)).count('1'))