PY3 Migration - //tensorflow/tools [2]

PiperOrigin-RevId: 275334989
Change-Id: Ia0b660e7ce7cde8e97f8d7cb3a39afe7fec63a7f

parent b78675a3bd
commit c396546ca3
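Most of the Python changes below replace direct string operations with `six.ensure_str`, which normalizes bytes and text to the native `str` type on both interpreters. A minimal standalone sketch of the behavior this relies on (assuming six >= 1.12, where `ensure_str` was introduced):

```python
import six

# six.ensure_str coerces both byte strings and text to the native `str`
# type: bytes are decoded on Python 3, unicode is encoded on Python 2.
assert six.ensure_str(b"tf.mul") == "tf.mul"
assert six.ensure_str(u"tf.mul") == "tf.mul"
assert isinstance(six.ensure_str(b"tf.mul"), str)

# Why it matters: on Python 3, mixing bytes and str raises TypeError, so
# a bare `name.startswith('_')` breaks if `name` arrives as bytes.
print(six.ensure_str(b"_private").startswith("_"))  # True
```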
@@ -23,5 +23,6 @@ py_library(
         ":api_objects_proto_py",
         "//tensorflow/python:platform",
         "//tensorflow/python:util",
+        "@six_archive//:six",
     ],
 )

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,8 +20,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import sys
+import enum
+import sys

+import six
+
 from google.protobuf import message
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import deprecation
@@ -191,7 +195,8 @@ class PythonObjectToProtoVisitor(object):
     if (_SkipMember(parent, member_name) or
         isinstance(member_obj, deprecation.HiddenTfApiAttribute)):
       return
-    if member_name == '__init__' or not member_name.startswith('_'):
+    if member_name == '__init__' or not six.ensure_str(
+        member_name).startswith('_'):
       if tf_inspect.isroutine(member_obj):
         new_method = proto.member_method.add()
         new_method.name = member_name

@@ -14,13 +14,16 @@ py_library(
     name = "public_api",
     srcs = ["public_api.py"],
     srcs_version = "PY2AND3",
-    deps = ["//tensorflow/python:util"],
+    deps = [
+        "//tensorflow/python:util",
+        "@six_archive//:six",
+    ],
 )

 py_test(
     name = "public_api_test",
     srcs = ["public_api_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":public_api",
@@ -32,13 +35,16 @@ py_library(
     name = "traverse",
     srcs = ["traverse.py"],
     srcs_version = "PY2AND3",
-    deps = ["//tensorflow/python:util"],
+    deps = [
+        "//tensorflow/python:util",
+        "@six_archive//:six",
+    ],
 )

 py_test(
     name = "traverse_test",
     srcs = ["traverse_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":test_module1",

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +21,8 @@ from __future__ import print_function

 import re

+import six
+
 from tensorflow.python.util import tf_inspect


@@ -108,9 +111,9 @@ class PublicAPIVisitor(object):
     """Return whether a name is private."""
     # TODO(wicke): Find out what names to exclude.
     del obj  # Unused.
-    return ((path in self._private_map and
-             name in self._private_map[path]) or
-            (name.startswith('_') and not re.match('__.*__$', name) or
+    return ((path in self._private_map and name in self._private_map[path]) or
+            (six.ensure_str(name).startswith('_') and
+             not re.match('__.*__$', six.ensure_str(name)) or
             name in ['__base__', '__class__']))

   def _do_not_descend(self, path, name):
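The `_is_private` rewrite above keeps the same predicate while tolerating byte-string names. A simplified standalone sketch (dropping the `_private_map` lookup, which is unchanged):

```python
import re
import six

def is_private(name):
  # Underscore-prefixed names are private unless they are dunder names
  # (`__.*__$`); `__base__` and `__class__` are always treated as private.
  name = six.ensure_str(name)
  return (name.startswith('_') and not re.match('__.*__$', name)
          or name in ['__base__', '__class__'])

assert is_private('_helper')
assert not is_private('__init__')
assert is_private('__class__')
assert is_private(b'_also_private')  # bytes input now works too
```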
@@ -122,7 +125,8 @@ class PublicAPIVisitor(object):
     """Visitor interface, see `traverse` for details."""

     # Avoid long waits in cases of pretty unambiguous failure.
-    if tf_inspect.ismodule(parent) and len(path.split('.')) > 10:
+    if tf_inspect.ismodule(parent) and len(
+        six.ensure_str(path).split('.')) > 10:
       raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a '
                          'problem with an accidental public import.' %
                          (self._root_name, path))

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +22,8 @@ from __future__ import print_function
 import enum
 import sys

+import six
+
 from tensorflow.python.util import tf_inspect

 __all__ = ['traverse']
@@ -59,7 +62,8 @@ def _traverse_internal(root, visit, stack, path):
     if any(child is item for item in new_stack):  # `in`, but using `is`
       continue

-    child_path = path + '.' + name if path else name
+    child_path = six.ensure_str(path) + '.' + six.ensure_str(
+        name) if path else name
     _traverse_internal(child, visit, new_stack, child_path)


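The `child_path` change guards the one place `traverse` concatenates names, since mixing bytes and text is a hard error on Python 3. A small illustration of the failure mode being avoided (the values here are hypothetical):

```python
import six

path, name = b"tf.compat", u"v1"  # types can disagree on 2/3 code paths

# b"tf.compat" + "." raises TypeError on Python 3; normalizing both sides
# first makes the concatenation safe under either interpreter.
child_path = six.ensure_str(path) + "." + six.ensure_str(name)
assert child_path == "tf.compat.v1"
```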
@@ -14,6 +14,7 @@ py_library(
     name = "ipynb",
     srcs = ["ipynb.py"],
     srcs_version = "PY2AND3",
+    deps = ["@six_archive//:six"],
 )

 py_library(
@@ -29,7 +30,7 @@ py_library(
 py_test(
     name = "ast_edits_test",
     srcs = ["ast_edits_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":ast_edits",
@@ -42,22 +43,28 @@ py_test(
 py_binary(
     name = "tf_upgrade",
     srcs = ["tf_upgrade.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
-    deps = [":tf_upgrade_lib"],
+    deps = [
+        ":tf_upgrade_lib",
+        "@six_archive//:six",
+    ],
 )

 py_library(
     name = "tf_upgrade_lib",
     srcs = ["tf_upgrade.py"],
     srcs_version = "PY2AND3",
-    deps = [":ast_edits"],
+    deps = [
+        ":ast_edits",
+        "@six_archive//:six",
+    ],
 )

 py_test(
     name = "tf_upgrade_test",
     srcs = ["tf_upgrade_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     tags = [
         "no_pip",
@@ -96,6 +103,7 @@ py_library(
 py_test(
     name = "all_renames_v2_test",
     srcs = ["all_renames_v2_test.py"],
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":all_renames_v2",
@@ -108,6 +116,7 @@ py_test(
 py_library(
     name = "module_deprecations_v2",
     srcs = ["module_deprecations_v2.py"],
+    srcs_version = "PY2AND3",
     deps = [":ast_edits"],
 )

@@ -145,13 +154,14 @@ py_binary(
         ":ipynb",
         ":tf_upgrade_v2_lib",
         ":tf_upgrade_v2_safety_lib",
+        "@six_archive//:six",
     ],
 )

 py_test(
     name = "tf_upgrade_v2_test",
     srcs = ["tf_upgrade_v2_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     tags = ["v1only"],
     deps = [
@@ -169,6 +179,7 @@ py_test(
 py_test(
     name = "tf_upgrade_v2_safety_test",
     srcs = ["tf_upgrade_v2_safety_test.py"],
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":tf_upgrade_v2_safety_lib",
@@ -208,7 +219,7 @@ py_test(
     name = "test_file_v1_0",
     size = "small",
     srcs = ["test_file_v1_0.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow:tensorflow_py",
@@ -235,7 +246,7 @@ py_test(
     name = "test_file_v1_12",
     size = "small",
     srcs = ["testdata/test_file_v1_12.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     tags = ["v1only"],
     deps = [
@@ -247,7 +258,7 @@ py_test(
     name = "test_file_v2_0",
     size = "small",
     srcs = ["test_file_v2_0.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow:tensorflow_py",

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,6 +30,7 @@ import traceback

 import pasta
 import six
+from six.moves import range

 # Some regular expressions we will need for parsing
 FIND_OPEN = re.compile(r"^\s*(\[).*$")
@@ -56,7 +58,7 @@ def full_name_node(name, ctx=ast.Load()):
   Returns:
     A Name or Attribute node.
   """
-  names = name.split(".")
+  names = six.ensure_str(name).split(".")
   names.reverse()
   node = ast.Name(id=names.pop(), ctx=ast.Load())
   while names:
@@ -301,7 +303,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
     function_transformers = getattr(self._api_change_spec,
                                     transformer_field, {})

-    glob_name = "*." + name if name else None
+    glob_name = "*." + six.ensure_str(name) if name else None
     transformers = []
     if full_name in function_transformers:
       transformers.append(function_transformers[full_name])
@@ -318,7 +320,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
     function_transformers = getattr(self._api_change_spec,
                                     transformer_field, {})

-    glob_name = "*." + name if name else None
+    glob_name = "*." + six.ensure_str(name) if name else None
     transformers = function_transformers.get("*", {}).copy()
     transformers.update(function_transformers.get(glob_name, {}))
     transformers.update(function_transformers.get(full_name, {}))
@@ -351,7 +353,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
     function_warnings = self._api_change_spec.function_warnings
     if full_name in function_warnings:
       level, message = function_warnings[full_name]
-      message = message.replace("<function name>", full_name)
+      message = six.ensure_str(message).replace("<function name>", full_name)
       self.add_log(level, node.lineno, node.col_offset,
                    "%s requires manual check. %s" % (full_name, message))
       return True
@@ -363,7 +365,8 @@ class _PastaEditVisitor(ast.NodeVisitor):
     warnings = self._api_change_spec.module_deprecations
     if full_name in warnings:
       level, message = warnings[full_name]
-      message = message.replace("<function name>", whole_name)
+      message = six.ensure_str(message).replace("<function name>",
+                                                six.ensure_str(whole_name))
       self.add_log(level, node.lineno, node.col_offset,
                    "Using member %s in deprecated module %s. %s" % (whole_name,
                                                                     full_name,
@@ -394,7 +397,7 @@ class _PastaEditVisitor(ast.NodeVisitor):
     # an attribute.
     warned = False
     if isinstance(node.func, ast.Attribute):
-      warned = self._maybe_add_warning(node, "*." + name)
+      warned = self._maybe_add_warning(node, "*." + six.ensure_str(name))

     # All arg warnings are handled here, since only we have the args
     arg_warnings = self._get_applicable_dict("function_arg_warnings",
@@ -406,7 +409,8 @@ class _PastaEditVisitor(ast.NodeVisitor):
       present, _ = get_arg_value(node, kwarg, arg) or variadic_args
       if present:
         warned = True
-        warning_message = warning.replace("<function name>", full_name or name)
+        warning_message = six.ensure_str(warning).replace(
+            "<function name>", six.ensure_str(full_name or name))
         template = "%s called with %s argument, requires manual check: %s"
         if variadic_args:
           template = ("%s called with *args or **kwargs that may include %s, "
@@ -625,13 +629,13 @@ class _PastaEditVisitor(ast.NodeVisitor):
     # This loop processes imports in the format
     # import foo as f, bar as b
     for import_alias in node.names:
-      all_import_components = import_alias.name.split(".")
+      all_import_components = six.ensure_str(import_alias.name).split(".")
       # Look for rename, starting with longest import levels.
       found_update = False
-      for i in reversed(range(1, max_submodule_depth + 1)):
+      for i in reversed(list(range(1, max_submodule_depth + 1))):
        import_component = all_import_components[0]
        for j in range(1, min(i, len(all_import_components))):
-          import_component += "." + all_import_components[j]
+          import_component += "." + six.ensure_str(all_import_components[j])
        import_rename_spec = import_renames.get(import_component, None)

        if not import_rename_spec or excluded_from_module_rename(
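With `from six.moves import range` in effect, `range` is `xrange` on Python 2 and the lazy built-in on Python 3, so the upgrade conservatively materializes it with `list()` before reversing. A sketch of the resulting iteration order (the depth value here is hypothetical):

```python
from six.moves import range  # xrange on Python 2, built-in range on Python 3

max_submodule_depth = 3
# Materializing with list() guarantees reversed() gets a concrete sequence
# regardless of which interpreter's range implementation is in play.
for i in reversed(list(range(1, max_submodule_depth + 1))):
  print(i)  # prints 3, 2, 1: longest import prefixes are tried first
```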
@@ -674,7 +678,8 @@ class _PastaEditVisitor(ast.NodeVisitor):
     if old_suffix is None:
       old_suffix = os.linesep
     if os.linesep not in old_suffix:
-      pasta.base.formatting.set(node, "suffix", old_suffix + os.linesep)
+      pasta.base.formatting.set(node, "suffix",
+                                six.ensure_str(old_suffix) + os.linesep)

     # Apply indentation to new node.
     pasta.base.formatting.set(new_line_node, "prefix",
@@ -720,7 +725,7 @@ class _PastaEditVisitor(ast.NodeVisitor):

     # Look for rename based on first component of from-import.
     # i.e. based on foo in foo.bar.
-    from_import_first_component = from_import.split(".")[0]
+    from_import_first_component = six.ensure_str(from_import).split(".")[0]
     import_renames = getattr(self._api_change_spec, "import_renames", {})
     import_rename_spec = import_renames.get(from_import_first_component, None)
     if not import_rename_spec:
@@ -918,7 +923,7 @@ class ASTCodeUpgrader(object):
   def format_log(self, log, in_filename):
     log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3])
     if in_filename:
-      return in_filename + ":" + log_string
+      return six.ensure_str(in_filename) + ":" + log_string
     else:
       return log_string

@@ -945,12 +950,12 @@ class ASTCodeUpgrader(object):
     return 1, pasta.dump(t), logs, errors

   def _format_log(self, log, in_filename, out_filename):
-    text = "-" * 80 + "\n"
+    text = six.ensure_str("-" * 80) + "\n"
     text += "Processing file %r\n outputting to %r\n" % (in_filename,
                                                          out_filename)
-    text += "-" * 80 + "\n\n"
+    text += six.ensure_str("-" * 80) + "\n\n"
     text += "\n".join(log) + "\n"
-    text += "-" * 80 + "\n\n"
+    text += six.ensure_str("-" * 80) + "\n\n"
     return text

   def process_opened_file(self, in_filename, in_file, out_filename, out_file):
@@ -1017,8 +1022,10 @@ class ASTCodeUpgrader(object):
     files_to_process = []
     files_to_copy = []
     for dir_name, _, file_list in os.walk(root_directory):
-      py_files = [f for f in file_list if f.endswith(".py")]
-      copy_files = [f for f in file_list if not f.endswith(".py")]
+      py_files = [f for f in file_list if six.ensure_str(f).endswith(".py")]
+      copy_files = [
+          f for f in file_list if not six.ensure_str(f).endswith(".py")
+      ]
      for filename in py_files:
        fullpath = os.path.join(dir_name, filename)
        fullpath_output = os.path.join(output_root_directory,
@@ -1036,9 +1043,9 @@ class ASTCodeUpgrader(object):
     file_count = 0
     tree_errors = {}
     report = ""
-    report += ("=" * 80) + "\n"
+    report += six.ensure_str(("=" * 80)) + "\n"
     report += "Input tree: %r\n" % root_directory
-    report += ("=" * 80) + "\n"
+    report += six.ensure_str(("=" * 80)) + "\n"

     for input_path, output_path in files_to_process:
       output_directory = os.path.dirname(output_path)
@@ -1074,16 +1081,19 @@ class ASTCodeUpgrader(object):
     """Process a directory of python files in place."""
     files_to_process = []
     for dir_name, _, file_list in os.walk(root_directory):
-      py_files = [os.path.join(dir_name,
-                               f) for f in file_list if f.endswith(".py")]
+      py_files = [
+          os.path.join(dir_name, f)
+          for f in file_list
+          if six.ensure_str(f).endswith(".py")
+      ]
       files_to_process += py_files

     file_count = 0
     tree_errors = {}
     report = ""
-    report += ("=" * 80) + "\n"
+    report += six.ensure_str(("=" * 80)) + "\n"
     report += "Input tree: %r\n" % root_directory
-    report += ("=" * 80) + "\n"
+    report += six.ensure_str(("=" * 80)) + "\n"

     for path in files_to_process:
       if os.path.islink(path):
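Both directory-walking methods above now normalize each entry before the `.endswith(".py")` test. A standalone sketch of the same filtering pattern (the helper name `find_py_files` is hypothetical):

```python
import os
import six

def find_py_files(root_directory):
  """Collects .py paths under a tree."""
  files_to_process = []
  for dir_name, _, file_list in os.walk(root_directory):
    files_to_process += [
        os.path.join(dir_name, f)
        for f in file_list
        # the suffix test itself works for bytes or text names
        if six.ensure_str(f).endswith(".py")
    ]
  return files_to_process
```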
@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,6 +25,7 @@ import json
 import re
 import shutil
 import tempfile
+import six

 CodeLine = collections.namedtuple("CodeLine", ["cell_number", "code"])

@@ -31,7 +33,8 @@ def is_python(cell):
   """Checks if the cell consists of Python code."""
   return (cell["cell_type"] == "code"  # code cells only
           and cell["source"]  # non-empty cells
-          and not cell["source"][0].startswith("%%"))  # multiline eg: %%bash
+          and not six.ensure_str(cell["source"][0]).startswith("%%")
+         )  # multiline eg: %%bash


 def process_file(in_filename, out_filename, upgrader):
@@ -47,8 +50,9 @@ def process_file(in_filename, out_filename, upgrader):
       upgrader.update_string_pasta("\n".join(raw_lines), in_filename))

   if temp_file and processed_file:
-    new_notebook = _update_notebook(notebook, raw_code,
-                                    new_file_content.split("\n"))
+    new_notebook = _update_notebook(
+        notebook, raw_code,
+        six.ensure_str(new_file_content).split("\n"))
     json.dump(new_notebook, temp_file)
   else:
     raise SyntaxError(
@@ -78,7 +82,7 @@ def skip_magic(code_line, magic_list):
   """

   for magic in magic_list:
-    if code_line.startswith(magic):
+    if six.ensure_str(code_line).startswith(magic):
       return True

   return False
@@ -120,7 +124,7 @@ def _get_code(input_file):
       # Idea is to comment these lines, for upgrade time
       if skip_magic(code_line, ["%", "!", "?"]) or is_line_split:
         # Found a special character, need to "encode"
-        code_line = "###!!!" + code_line
+        code_line = "###!!!" + six.ensure_str(code_line)

       # if this cell ends with `\` -> skip the next line
       is_line_split = check_line_split(code_line)
@@ -131,14 +135,16 @@ def _get_code(input_file):
       # Sometimes, people leave \n at the end of cell
       # in order to migrate only related things, and make the diff
       # the smallest -> here is another hack
-      if (line_idx == len(cell_lines) - 1) and code_line.endswith("\n"):
-        code_line = code_line.replace("\n", "###===")
+      if (line_idx == len(cell_lines) -
+          1) and six.ensure_str(code_line).endswith("\n"):
+        code_line = six.ensure_str(code_line).replace("\n", "###===")

       # sometimes a line would start with `\n` and content after
       # that's the hack for this
       raw_code.append(
           CodeLine(cell_index,
-                   code_line.rstrip().replace("\n", "###===")))
+                   six.ensure_str(code_line.rstrip()).replace("\n",
+                                                              "###===")))

     cell_index += 1

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +21,8 @@ from __future__ import print_function

 import argparse

+import six
+
 from tensorflow.tools.compatibility import ast_edits

@@ -245,7 +248,7 @@ Simple usage:
   else:
     parser.print_help()
   if report_text:
-    open(report_filename, "w").write(report_text)
+    open(report_filename, "w").write(six.ensure_str(report_text))
     print("TensorFlow 1.0 Upgrade Script")
     print("-----------------------------")
     print("Converted %d files\n" % files_processed)

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -46,14 +47,16 @@ class TestUpgrade(test_util.TensorFlowTestCase):
   def testParseError(self):
     _, report, unused_errors, unused_new_text = self._upgrade(
         "import tensorflow as tf\na + \n")
-    self.assertTrue(report.find("Failed to parse") != -1)
+    self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1)

   def testReport(self):
     text = "tf.mul(a, b)\n"
     _, report, unused_errors, unused_new_text = self._upgrade(text)
     # This is not a complete test, but it is a sanity test that a report
     # is generating information.
-    self.assertTrue(report.find("Renamed function `tf.mul` to `tf.multiply`"))
+    self.assertTrue(
+        six.ensure_str(report).find(
+            "Renamed function `tf.mul` to `tf.multiply`"))

   def testRename(self):
     text = "tf.mul(a, tf.sub(b, c))\n"
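Note the asymmetry in the test rewrite above: `testParseError` becomes a real assertion (`assertNotEqual(..., -1)`), while `testReport` keeps `assertTrue(report.find(...))`, which is vacuous because `str.find` returns -1 (a truthy value) on a miss. A minimal illustration:

```python
report = "no matches here"

# str.find returns -1 when the substring is absent, and -1 is truthy,
# so assertTrue(report.find(...)) can never fail.
assert report.find("Failed to parse") == -1
assert bool(report.find("Failed to parse"))  # truthy despite the miss

# The explicit comparison is the meaningful check.
assert not (report.find("Failed to parse") != -1)
```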
@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,6 +25,7 @@ import functools
 import sys

 import pasta
+import six

 from tensorflow.tools.compatibility import all_renames_v2
 from tensorflow.tools.compatibility import ast_edits
@@ -47,8 +49,9 @@ class VersionedTFImport(ast_edits.AnalysisResult):

   def __init__(self, version):
     self.log_level = ast_edits.INFO
-    self.log_message = ("Not upgrading symbols because `tensorflow." + version
-                        + "` was directly imported as `tf`.")
+    self.log_message = ("Not upgrading symbols because `tensorflow." +
+                        six.ensure_str(version) +
+                        "` was directly imported as `tf`.")


 class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec):
@@ -1687,7 +1690,7 @@ def _rename_if_arg_found_transformer(parent, node, full_name, name, logs,

   # All conditions met, insert v1 and log what we did.
   # We must have a full name, so the func is an attribute.
-  new_name = full_name.replace("tf.", "tf.compat.v1.", 1)
+  new_name = six.ensure_str(full_name).replace("tf.", "tf.compat.v1.", 1)
   node.func = ast_edits.full_name_node(new_name)
   logs.append((
       ast_edits.INFO, node.lineno, node.col_offset,
@@ -1715,8 +1718,8 @@ def _iterator_transformer(parent, node, full_name, name, logs):
   # (tf.compat.v1.data), or something which is handled in the rename
   # (tf.data). This transformer only handles the method call to function call
   # conversion.
-  if full_name and (full_name.startswith("tf.compat.v1.data") or
-                    full_name.startswith("tf.data")):
+  if full_name and (six.ensure_str(full_name).startswith("tf.compat.v1.data") or
+                    six.ensure_str(full_name).startswith("tf.data")):
     return

   # This should never happen, since we're only called for Attribute nodes.
@@ -2460,7 +2463,7 @@ def _name_scope_transformer(parent, node, full_name, name, logs):


 def _rename_to_compat_v1(node, full_name, logs, reason):
-  new_name = full_name.replace("tf.", "tf.compat.v1.", 1)
+  new_name = six.ensure_str(full_name).replace("tf.", "tf.compat.v1.", 1)
   return _rename_func(node, full_name, new_name, logs, reason)


@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,10 +21,13 @@ from __future__ import print_function

 import argparse

+import six
+
 from tensorflow.tools.compatibility import ast_edits
 from tensorflow.tools.compatibility import ipynb
 from tensorflow.tools.compatibility import tf_upgrade_v2
 from tensorflow.tools.compatibility import tf_upgrade_v2_safety

 # Make straightforward changes to convert to 2.0. In harder cases,
 # use compat.v1.
 _DEFAULT_MODE = "DEFAULT"
@@ -35,10 +39,10 @@ _SAFETY_MODE = "SAFETY"
 def process_file(in_filename, out_filename, upgrader):
   """Process a file of type `.py` or `.ipynb`."""

-  if in_filename.endswith(".py"):
+  if six.ensure_str(in_filename).endswith(".py"):
     files_processed, report_text, errors = \
       upgrader.process_file(in_filename, out_filename)
-  elif in_filename.endswith(".ipynb"):
+  elif six.ensure_str(in_filename).endswith(".ipynb"):
     files_processed, report_text, errors = \
       ipynb.process_file(in_filename, out_filename, upgrader)
   else:
@@ -157,24 +161,24 @@ Simple usage:
   for f in errors:
     if errors[f]:
       num_errors += len(errors[f])
-      report.append("-" * 80 + "\n")
+      report.append(six.ensure_str("-" * 80) + "\n")
       report.append("File: %s\n" % f)
-      report.append("-" * 80 + "\n")
+      report.append(six.ensure_str("-" * 80) + "\n")
       report.append("\n".join(errors[f]) + "\n")

   report = ("TensorFlow 2.0 Upgrade Script\n"
             "-----------------------------\n"
             "Converted %d files\n" % files_processed +
             "Detected %d issues that require attention" % num_errors + "\n" +
-            "-" * 80 + "\n") + "".join(report)
-  detailed_report_header = "=" * 80 + "\n"
+            six.ensure_str("-" * 80) + "\n") + "".join(report)
+  detailed_report_header = six.ensure_str("=" * 80) + "\n"
   detailed_report_header += "Detailed log follows:\n\n"
-  detailed_report_header += "=" * 80 + "\n"
+  detailed_report_header += six.ensure_str("=" * 80) + "\n"

   with open(report_filename, "w") as report_file:
     report_file.write(report)
     report_file.write(detailed_report_header)
-    report_file.write(report_text)
+    report_file.write(six.ensure_str(report_text))

   if args.print_all:
     print(report)

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -39,7 +40,7 @@ from tensorflow.tools.compatibility import tf_upgrade_v2


 def get_symbol_for_name(root, name):
-  name_parts = name.split(".")
+  name_parts = six.ensure_str(name).split(".")
   symbol = root
   # Iterate starting with second item since 1st item is "tf.".
   for part in name_parts[1:]:
@@ -66,12 +67,13 @@ def get_func_and_args_from_str(call_str):
   Returns:
     (function_name, list of arg names) tuple.
   """
-  open_paren_index = call_str.find("(")
+  open_paren_index = six.ensure_str(call_str).find("(")
   close_paren_index = call_str.rfind(")")

-  function_name = call_str[:call_str.find("(")]
-  args = call_str[open_paren_index+1:close_paren_index].split(",")
-  args = [arg.split("=")[0].strip() for arg in args]
+  function_name = call_str[:six.ensure_str(call_str).find("(")]
+  args = six.ensure_str(call_str[open_paren_index +
+                                 1:close_paren_index]).split(",")
+  args = [six.ensure_str(arg).split("=")[0].strip() for arg in args]
   args = [arg for arg in args if arg]  # filter out empty strings
   return function_name, args

@@ -96,7 +98,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
         _, attr = tf_decorator.unwrap(child[1])
         api_names_v2 = tf_export.get_v2_names(attr)
         for name in api_names_v2:
-          cls.v2_symbols["tf." + name] = attr
+          cls.v2_symbols["tf." + six.ensure_str(name)] = attr

     visitor = public_api.PublicAPIVisitor(symbol_collector)
     visitor.private_map["tf.compat"] = ["v1"]
@@ -109,7 +111,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
         _, attr = tf_decorator.unwrap(child[1])
         api_names_v1 = tf_export.get_v1_names(attr)
         for name in api_names_v1:
-          cls.v1_symbols["tf." + name] = attr
+          cls.v1_symbols["tf." + six.ensure_str(name)] = attr

     visitor = public_api.PublicAPIVisitor(symbol_collector_v1)
     traverse.traverse(tf.compat.v1, visitor)
@@ -138,15 +140,16 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
   def testParseError(self):
     _, report, unused_errors, unused_new_text = self._upgrade(
         "import tensorflow as tf\na + \n")
-    self.assertTrue(report.find("Failed to parse") != -1)
+    self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1)

   def testReport(self):
     text = "tf.angle(a)\n"
     _, report, unused_errors, unused_new_text = self._upgrade(text)
     # This is not a complete test, but it is a sanity test that a report
     # is generating information.
-    self.assertTrue(report.find("Renamed function `tf.angle` to "
-                                "`tf.math.angle`"))
+    self.assertTrue(
+        six.ensure_str(report).find("Renamed function `tf.angle` to "
+                                    "`tf.math.angle`"))

   def testRename(self):
     text = "tf.conj(a)\n"
@@ -169,7 +172,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
         _, attr = tf_decorator.unwrap(child[1])
         api_names = tf_export.get_v1_names(attr)
         for name in api_names:
-          _, _, _, text = self._upgrade("tf." + name)
+          _, _, _, text = self._upgrade("tf." + six.ensure_str(name))
           if (text and
               not text.startswith("tf.compat.v1") and
               not text.startswith("tf.compat.v2") and
@@ -198,9 +201,9 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
         api_names = tf_export.get_v1_names(attr)
         for name in api_names:
           if collect:
-            v1_symbols.add("tf." + name)
+            v1_symbols.add("tf." + six.ensure_str(name))
           else:
-            _, _, _, text = self._upgrade("tf." + name)
+            _, _, _, text = self._upgrade("tf." + six.ensure_str(name))
             if (text and
                 not text.startswith("tf.compat.v1") and
                 not text.startswith("tf.compat.v2") and
@@ -337,16 +340,16 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):

   def testPositionsMatchArgGiven(self):
     full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings
-    method_names = full_dict.keys()
+    method_names = list(full_dict.keys())
     for method_name in method_names:
-      args = full_dict[method_name].keys()
+      args = list(full_dict[method_name].keys())
       if "contrib" in method_name:
         # Skip descending and fetching contrib methods during test. These are
         # not available in the repo anymore.
         continue
-      elif method_name.startswith("*."):
+      elif six.ensure_str(method_name).startswith("*."):
         # special case for optimizer methods
-        method = method_name.replace("*", "tf.train.Optimizer")
+        method = six.ensure_str(method_name).replace("*", "tf.train.Optimizer")
       else:
         method = method_name

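On Python 3, `dict.keys()` returns a live view rather than a list, so the test above now materializes the keys with `list()` before iterating. A small sketch of the pattern (the dictionary contents are hypothetical stand-ins for `function_arg_warnings`):

```python
full_dict = {
    "tf.gradients": {("colocate_gradients_with_ops", 4): "warning"},
    "*.minimize": {("colocate_gradients_with_ops", 5): "warning"},
}

# list() snapshots the view, so the loop behaves identically on both
# interpreters, even if the dict were mutated while iterating.
method_names = list(full_dict.keys())
for method_name in method_names:
  args = list(full_dict[method_name].keys())
  print(method_name, args)
```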
@@ -354,7 +357,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
       arg_spec = tf_inspect.getfullargspec(method)
       for (arg, pos) in args:
         # to deal with the self argument on methods on objects
-        if method_name.startswith("*."):
+        if six.ensure_str(method_name).startswith("*."):
           pos += 1
         self.assertEqual(arg_spec[0][pos], arg)

@@ -6,7 +6,7 @@ package(
 py_binary(
     name = "generate_v2_renames_map",
     srcs = ["generate_v2_renames_map.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow:tensorflow_py",
@@ -15,13 +15,14 @@ py_binary(
         "//tensorflow/tools/common:public_api",
         "//tensorflow/tools/common:traverse",
         "//tensorflow/tools/compatibility:all_renames_v2",
+        "@six_archive//:six",
     ],
 )

 py_binary(
     name = "generate_v2_reorders_map",
     srcs = ["generate_v2_reorders_map.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow:tensorflow_py",

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,9 +24,9 @@ To update renames_v2.py, run:
 # pylint: enable=line-too-long
 import sys

+import six
 import tensorflow as tf

-# This import is needed so that TensorFlow python modules are in sys.modules.
 from tensorflow import python as tf_python  # pylint: disable=unused-import
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.platform import app
@@ -35,6 +36,7 @@ from tensorflow.tools.common import public_api
 from tensorflow.tools.common import traverse
 from tensorflow.tools.compatibility import all_renames_v2

+# This import is needed so that TensorFlow python modules are in sys.modules.

 _OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/renames_v2.py'
 _FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
@@ -178,7 +180,8 @@ def update_renames_v2(output_file_path):
   rename_lines = [
       get_rename_line(name, canonical_name)
       for name, canonical_name in all_renames
-      if 'tf.' + name not in manual_renames]
+      if 'tf.' + six.ensure_str(name) not in manual_renames
+  ]
   renames_file_text = '%srenames = {\n%s\n}\n' % (
       _FILE_HEADER, ',\n'.join(sorted(rename_lines)))
   file_io.write_string_to_file(output_file_path, renames_file_text)

@@ -31,6 +31,7 @@ py_library(
         "doc_generator_visitor.py",
     ],
     srcs_version = "PY2AND3",
+    deps = ["@six_archive//:six"],
 )

 py_test(
@@ -39,12 +40,13 @@ py_test(
     srcs = [
         "doc_generator_visitor_test.py",
     ],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":doc_generator_visitor",
         ":generate_lib",
         "//tensorflow/python:platform_test",
+        "@six_archive//:six",
     ],
 )

@@ -59,7 +61,7 @@ py_test(
     name = "doc_controls_test",
     size = "small",
     srcs = ["doc_controls_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":doc_controls",
@@ -77,6 +79,7 @@ py_library(
         "//tensorflow/python:platform",
         "//tensorflow/python:util",
         "@astor_archive//:astor",
+        "@six_archive//:six",
     ],
 )

@@ -84,11 +87,12 @@ py_test(
     name = "parser_test",
     size = "small",
     srcs = ["parser_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":parser",
         "//tensorflow/python:platform_test",
+        "@six_archive//:six",
     ],
 )

@@ -96,6 +100,7 @@ py_library(
     name = "pretty_docs",
     srcs = ["pretty_docs.py"],
     srcs_version = "PY2AND3",
+    deps = ["@six_archive//:six"],
 )

 py_library(
@@ -120,7 +125,7 @@ py_test(
     name = "generate_lib_test",
     size = "small",
     srcs = ["generate_lib_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":generate_lib",
@@ -132,7 +137,7 @@ py_test(
 py_binary(
     name = "generate",
     srcs = ["generate.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":generate_lib",
@@ -165,9 +170,12 @@ py_test(
 py_binary(
     name = "generate2",
     srcs = ["generate2.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
-    deps = [":generate2_lib"],
+    deps = [
+        ":generate2_lib",
+        "@six_archive//:six",
+    ],
 )

 py_library(
@@ -184,16 +192,18 @@ py_library(
     name = "py_guide_parser",
     srcs = ["py_guide_parser.py"],
     srcs_version = "PY2AND3",
+    deps = ["@six_archive//:six"],
 )

 py_test(
     name = "py_guide_parser_test",
     size = "small",
     srcs = ["py_guide_parser_test.py"],
-    python_version = "PY2",
+    python_version = "PY3",
     srcs_version = "PY2AND3",
     deps = [
         ":py_guide_parser",
         "//tensorflow/python:client_testlib",
+        "@six_archive//:six",
     ],
 )

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -48,7 +49,7 @@ class DocGeneratorVisitor(object):
   def set_root_name(self, root_name):
     """Sets the root name for subsequent __call__s."""
     self._root_name = root_name or ''
-    self._prefix = (root_name + '.') if root_name else ''
+    self._prefix = (six.ensure_str(root_name) + '.') if root_name else ''

   @property
   def index(self):
@@ -178,7 +179,7 @@ class DocGeneratorVisitor(object):
       A tuple of scores. When sorted the preferred name will have the lowest
       value.
     """
-    parts = name.split('.')
+    parts = six.ensure_str(name).split('.')
     short_name = parts[-1]

     container = self._index['.'.join(parts[:-1])]

@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +21,8 @@ from __future__ import print_function

 import types

+import six
+
 from tensorflow.python.platform import googletest
 from tensorflow.tools.docs import doc_generator_visitor
 from tensorflow.tools.docs import generate_lib
@@ -29,9 +32,9 @@ class NoDunderVisitor(doc_generator_visitor.DocGeneratorVisitor):

   def __call__(self, parent_name, parent, children):
     """Drop all the dunder methods to make testing easier."""
-    children = [
-        (name, obj) for (name, obj) in children if not name.startswith('_')
-    ]
+    children = [(name, obj)
+                for (name, obj) in children
+                if not six.ensure_str(name).startswith('_')]
     super(NoDunderVisitor, self).__call__(parent_name, parent, children)


@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -147,8 +148,10 @@ def write_docs(output_dir,
       duplicates = [item for item in duplicates if item != full_name]

       for dup in duplicates:
-        from_path = os.path.join(site_api_path, dup.replace('.', '/'))
-        to_path = os.path.join(site_api_path, full_name.replace('.', '/'))
+        from_path = os.path.join(site_api_path,
+                                 six.ensure_str(dup).replace('.', '/'))
+        to_path = os.path.join(site_api_path,
+                               six.ensure_str(full_name).replace('.', '/'))
         redirects.append((
             os.path.join('/', from_path),
             os.path.join('/', to_path)))
@@ -167,7 +170,7 @@ def write_docs(output_dir,
   # Generate table of contents

   # Put modules in alphabetical order, case-insensitive
-  modules = sorted(module_children.keys(), key=lambda a: a.upper())
+  modules = sorted(list(module_children.keys()), key=lambda a: a.upper())

   leftnav_path = os.path.join(output_dir, '_toc.yaml')
   with open(leftnav_path, 'w') as f:
@@ -183,16 +186,15 @@ def write_docs(output_dir,
       if indent_num > 1:
         # tf.contrib.baysflow.entropy will be under
         #   tf.contrib->baysflow->entropy
-        title = module.split('.')[-1]
+        title = six.ensure_str(module).split('.')[-1]
       else:
         title = module

       header = [
-          '- title: ' + title,
-          '  section:',
-          '  - title: Overview',
-          '    path: ' + os.path.join('/', site_api_path,
-                                      symbol_to_file[module])]
+          '- title: ' + six.ensure_str(title), '  section:',
+          '  - title: Overview', '    path: ' +
+          os.path.join('/', site_api_path, symbol_to_file[module])
+      ]
       header = ''.join([indent+line+'\n' for line in header])
       f.write(header)

@@ -211,8 +213,9 @@ def write_docs(output_dir,
   # Write a global index containing all full names with links.
   with open(os.path.join(output_dir, 'index.md'), 'w') as f:
     f.write(
-        parser.generate_global_index(root_title, parser_config.index,
-                                     parser_config.reference_resolver))
+        six.ensure_str(
+            parser.generate_global_index(root_title, parser_config.index,
+                                         parser_config.reference_resolver)))


 def add_dict_to_dict(add_from, add_to):
@@ -345,7 +348,7 @@ def build_doc_index(src_dir):
   for dirpath, _, filenames in os.walk(src_dir):
     suffix = os.path.relpath(path=dirpath, start=src_dir)
     for base_name in filenames:
-      if not base_name.endswith('.md'):
+      if not six.ensure_str(base_name).endswith('.md'):
         continue
       title_parser = _GetMarkdownTitle()
       title_parser.process(os.path.join(dirpath, base_name))
@@ -353,7 +356,8 @@ def build_doc_index(src_dir):
         msg = ('`{}` has no markdown title (# title)'.format(
             os.path.join(dirpath, base_name)))
         raise ValueError(msg)
-      key_parts = os.path.join(suffix, base_name[:-3]).split('/')
+      key_parts = six.ensure_str(os.path.join(suffix,
+                                              base_name[:-3])).split('/')
       if key_parts[-1] == 'index':
         key_parts = key_parts[:-1]
       doc_info = _DocInfo(os.path.join(suffix, base_name), title_parser.title)
@@ -367,8 +371,8 @@ def build_doc_index(src_dir):
 class _GuideRef(object):

   def __init__(self, base_name, title, section_title, section_tag):
-    self.url = 'api_guides/python/' + (('%s#%s' % (base_name, section_tag))
-                                       if section_tag else base_name)
+    self.url = 'api_guides/python/' + six.ensure_str(
+        (('%s#%s' % (base_name, section_tag)) if section_tag else base_name))
     self.link_text = (('%s > %s' % (title, section_title))
                       if section_title else title)

@@ -447,7 +451,7 @@ def update_id_tags_inplace(src_dir):
       # modified file contents
       content = tag_updater.process(full_path)
       with open(full_path, 'w') as f:
-        f.write(content)
+        f.write(six.ensure_str(content))


 EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt'])
@@ -512,7 +516,7 @@ def replace_refs(src_dir,
       content = reference_resolver.replace_references(content,
                                                       relative_path_to_root)
       with open(full_out_path, 'wb') as f:
-        f.write(content.encode('utf-8'))
+        f.write(six.ensure_binary(content, 'utf-8'))


 class DocGenerator(object):
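`six.ensure_binary` is the mirror image of `ensure_str`: it yields `bytes` on both interpreters, so it can stand in for `content.encode('utf-8')`, which on Python 2 would first implicitly ascii-decode a byte string and can raise `UnicodeDecodeError` for non-ASCII content. A minimal sketch (the content string is hypothetical):

```python
import six

content = u"# tf.multiply\nRenamed from `tf.mul`.\n"

# ensure_binary encodes text to bytes and passes bytes through unchanged,
# so the write is safe whether `content` arrives as bytes or text.
data = six.ensure_binary(content, "utf-8")
assert isinstance(data, bytes)
```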
@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +29,7 @@ import re

 import astor
+import six
 from six.moves import zip

 from google.protobuf.message import Message as ProtoMessage
 from tensorflow.python.platform import tf_logging as logging
@@ -50,7 +52,7 @@ def is_free_function(py_object, full_name, index):
   if not tf_inspect.isfunction(py_object):
     return False

-  parent_name = full_name.rsplit('.', 1)[0]
+  parent_name = six.ensure_str(full_name).rsplit('.', 1)[0]
   if tf_inspect.isclass(index[parent_name]):
     return False

@@ -112,14 +114,14 @@ def documentation_path(full_name, is_fragment=False):
   Returns:
     The file path to which to write the documentation for `full_name`.
   """
-  parts = full_name.split('.')
+  parts = six.ensure_str(full_name).split('.')
   if is_fragment:
     parts, fragment = parts[:-1], parts[-1]

-  result = os.path.join(*parts) + '.md'
+  result = six.ensure_str(os.path.join(*parts)) + '.md'

   if is_fragment:
-    result = result + '#' + fragment
+    result = six.ensure_str(result) + '#' + six.ensure_str(fragment)

   return result

@@ -288,7 +290,7 @@ class ReferenceResolver(object):
       self.add_error(e.message)
       return 'BAD_LINK'

-    string = re.sub(SYMBOL_REFERENCE_RE, strict_one_ref, string)
+    string = re.sub(SYMBOL_REFERENCE_RE, strict_one_ref, six.ensure_str(string))

     def sloppy_one_ref(match):
       try:
@@ -333,7 +335,7 @@ class ReferenceResolver(object):
   @staticmethod
   def _link_text_to_html(link_text):
     code_re = '`(.*?)`'
-    return re.sub(code_re, r'<code>\1</code>', link_text)
+    return re.sub(code_re, r'<code>\1</code>', six.ensure_str(link_text))

   def py_master_name(self, full_name):
     """Return the master name for a Python symbol name."""
@@ -389,11 +391,11 @@ class ReferenceResolver(object):
     manual_link_text = False

     # Handle different types of references.
-    if string.startswith('$'):  # Doc reference
+    if six.ensure_str(string).startswith('$'):  # Doc reference
       return self._doc_link(string, link_text, manual_link_text,
                             relative_path_to_root)

-    elif string.startswith('tensorflow::'):
+    elif six.ensure_str(string).startswith('tensorflow::'):
       # C++ symbol
       return self._cc_link(string, link_text, manual_link_text,
                            relative_path_to_root)
@@ -401,7 +403,8 @@ class ReferenceResolver(object):
     else:
       is_python = False
       for py_module_name in self._py_module_names:
-        if string == py_module_name or string.startswith(py_module_name + '.'):
+        if string == py_module_name or string.startswith(
+            six.ensure_str(py_module_name) + '.'):
           is_python = True
           break
       if is_python:  # Python symbol
@@ -421,7 +424,7 @@ class ReferenceResolver(object):
     string = string[1:]  # remove leading $

     # If string has a #, split that part into `hash_tag`
-    hash_pos = string.find('#')
+    hash_pos = six.ensure_str(string).find('#')
     if hash_pos > -1:
       hash_tag = string[hash_pos:]
       string = string[:hash_pos]
@@ -520,10 +523,10 @@ class _FunctionDetail(

   def __str__(self):
     """Return the original string that represents the function detail."""
-    parts = [self.keyword + ':\n']
+    parts = [six.ensure_str(self.keyword) + ':\n']
     parts.append(self.header)
     for key, value in self.items:
-      parts.append('  ' + key + ': ')
+      parts.append('  ' + six.ensure_str(key) + ': ')
      parts.append(value)

    return ''.join(parts)
@@ -587,7 +590,7 @@ def _parse_function_details(docstring):
   item_re = re.compile(r'^  ? ?(\*?\*?\w[\w.]*?\s*):\s', re.MULTILINE)

   for keyword, content in pairs:
-    content = item_re.split(content)
+    content = item_re.split(six.ensure_str(content))
     header = content[0]
     items = list(_gen_pairs(content[1:]))

@@ -634,7 +637,8 @@ def _parse_md_docstring(py_object, relative_path_to_root, reference_resolver):

   atat_re = re.compile(r' *@@[a-zA-Z_.0-9]+ *$')
   raw_docstring = '\n'.join(
-      line for line in raw_docstring.split('\n') if not atat_re.match(line))
+      line for line in six.ensure_str(raw_docstring).split('\n')
+      if not atat_re.match(six.ensure_str(line)))

   docstring, compatibility = _handle_compatibility(raw_docstring)
   docstring, function_details = _parse_function_details(docstring)
@@ -698,8 +702,9 @@ def _get_arg_spec(func):


 def _remove_first_line_indent(string):
-  indent = len(re.match(r'^\s*', string).group(0))
-  return '\n'.join([line[indent:] for line in string.split('\n')])
+  indent = len(re.match(r'^\s*', six.ensure_str(string)).group(0))
+  return '\n'.join(
+      [line[indent:] for line in six.ensure_str(string).split('\n')])


 PAREN_NUMBER_RE = re.compile(r'^\(([0-9.e-]+)\)')
@@ -761,9 +766,9 @@ def _generate_signature(func, reverse_index):
         default_text = reverse_index[id(default)]
       elif ast_default is not None:
         default_text = (
-            astor.to_source(ast_default).rstrip('\n').replace('\t', '\\t')
-            .replace('\n', '\\n').replace('"""', "'"))
-        default_text = PAREN_NUMBER_RE.sub('\\1', default_text)
+            six.ensure_str(astor.to_source(ast_default)).rstrip('\n').replace(
+                '\t', '\\t').replace('\n', '\\n').replace('"""', "'"))
+        default_text = PAREN_NUMBER_RE.sub('\\1', six.ensure_str(default_text))

         if default_text != repr(default):
           # This may be an internal name. If so, handle the ones we know about.
@@ -797,9 +802,9 @@ def _generate_signature(func, reverse_index):

   # Add *args and *kwargs.
   if argspec.varargs:
-    args_list.append('*' + argspec.varargs)
+    args_list.append('*' + six.ensure_str(argspec.varargs))
   if argspec.varkw:
-    args_list.append('**' + argspec.varkw)
+    args_list.append('**' + six.ensure_str(argspec.varkw))

   return args_list

@@ -879,7 +884,7 @@ class _FunctionPageInfo(object):

   @property
   def short_name(self):
-    return self._full_name.split('.')[-1]
+    return six.ensure_str(self._full_name).split('.')[-1]

   @property
   def defined_in(self):
@@ -998,7 +1003,7 @@ class _ClassPageInfo(object):
   @property
   def short_name(self):
     """Returns the documented object's short name."""
-    return self._full_name.split('.')[-1]
+    return six.ensure_str(self._full_name).split('.')[-1]

   @property
   def defined_in(self):
@@ -1091,9 +1096,12 @@ class _ClassPageInfo(object):
       base_url = parser_config.reference_resolver.reference_to_url(
           base_full_name, relative_path)

-      link_info = _LinkInfo(short_name=base_full_name.split('.')[-1],
-                            full_name=base_full_name, obj=base,
-                            doc=base_doc, url=base_url)
+      link_info = _LinkInfo(
+          short_name=six.ensure_str(base_full_name).split('.')[-1],
+          full_name=base_full_name,
+          obj=base,
+          doc=base_doc,
+          url=base_url)
       bases.append(link_info)

     self._bases = bases
@@ -1121,7 +1129,7 @@ class _ClassPageInfo(object):
       doc: The property's parsed docstring, a `_DocstringInfo`.
     """
     # Hide useless namedtuple docs-trings
-    if re.match('Alias for field number [0-9]+', doc.docstring):
+    if re.match('Alias for field number [0-9]+', six.ensure_str(doc.docstring)):
       doc = doc._replace(docstring='', brief='')
     property_info = _PropertyInfo(short_name, full_name, obj, doc)
     self._properties.append(property_info)
@@ -1255,8 +1263,8 @@ class _ClassPageInfo(object):

       # Omit methods defined by namedtuple.
       original_method = defining_class.__dict__[short_name]
-      if (hasattr(original_method, '__module__') and
-          (original_method.__module__ or '').startswith('namedtuple')):
+      if (hasattr(original_method, '__module__') and six.ensure_str(
+          (original_method.__module__ or '')).startswith('namedtuple')):
         continue

       # Some methods are often overridden without documentation. Because it's
@@ -1294,7 +1302,7 @@ class _ClassPageInfo(object):
       else:
         # Exclude members defined by protobuf that are useless
         if issubclass(py_class, ProtoMessage):
-          if (short_name.endswith('_FIELD_NUMBER') or
+          if (six.ensure_str(short_name).endswith('_FIELD_NUMBER') or
              short_name in ['__slots__', 'DESCRIPTOR']):
            continue

@@ -1332,7 +1340,7 @@ class _ModulePageInfo(object):

   @property
   def short_name(self):
-    return self._full_name.split('.')[-1]
+    return six.ensure_str(self._full_name).split('.')[-1]

   @property
   def defined_in(self):
@@ -1425,7 +1433,8 @@ class _ModulePageInfo(object):
                    '__cached__', '__loader__', '__spec__']:
         continue

-      member_full_name = self.full_name + '.' + name if self.full_name else name
+      member_full_name = six.ensure_str(self.full_name) + '.' + six.ensure_str(
+          name) if self.full_name else name
       member = parser_config.py_name_to_object(member_full_name)

       member_doc = _parse_md_docstring(member, relative_path,
@@ -1680,20 +1689,21 @@ def _get_defined_in(py_object, parser_config):
   # TODO(wicke): And make their source file predictable from the file name.

   # In case this is compiled, point to the original
-  if path.endswith('.pyc'):
+  if six.ensure_str(path).endswith('.pyc'):
     path = path[:-1]

   # Never include links outside this code base.
-  if path.startswith('..') or re.search(r'\b_api\b', path):
+  if six.ensure_str(path).startswith('..') or re.search(r'\b_api\b',
+                                                        six.ensure_str(path)):
     return None

-  if re.match(r'.*/gen_[^/]*\.py$', path):
+  if re.match(r'.*/gen_[^/]*\.py$', six.ensure_str(path)):
     return _GeneratedFile(path, parser_config)
   if 'genfiles' in path or 'tools/api/generator' in path:
     return _GeneratedFile(path, parser_config)
-  elif re.match(r'.*_pb2\.py$', path):
+  elif re.match(r'.*_pb2\.py$', six.ensure_str(path)):
     # The _pb2.py files all appear right next to their defining .proto file.
-    return _ProtoFile(path[:-7] + '.proto', parser_config)
+    return _ProtoFile(six.ensure_str(path[:-7]) + '.proto', parser_config)
   else:
     return _PythonFile(path, parser_config)
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -23,6 +24,8 @@ import functools
import os
import sys

import six

from tensorflow.python.platform import googletest
from tensorflow.python.util import tf_inspect
from tensorflow.tools.docs import doc_controls
@ -180,7 +183,8 @@ class ParserTest(googletest.TestCase):

# Make sure the brief docstring is present
self.assertEqual(
tf_inspect.getdoc(TestClass).split('\n')[0], page_info.doc.brief)
six.ensure_str(tf_inspect.getdoc(TestClass)).split('\n')[0],
page_info.doc.brief)

# Make sure the method is present
self.assertEqual(TestClass.a_method, page_info.methods[0].obj)
@ -236,7 +240,7 @@ class ParserTest(googletest.TestCase):
# 'Alias for field number ##'. These props are returned sorted.

def sort_key(prop_info):
return int(prop_info.obj.__doc__.split(' ')[-1])
return int(six.ensure_str(prop_info.obj.__doc__).split(' ')[-1])

self.assertSequenceEqual(page_info.properties,
sorted(page_info.properties, key=sort_key))
@ -378,7 +382,8 @@ class ParserTest(googletest.TestCase):

# Make sure the brief docstring is present
self.assertEqual(
tf_inspect.getdoc(test_module).split('\n')[0], page_info.doc.brief)
six.ensure_str(tf_inspect.getdoc(test_module)).split('\n')[0],
page_info.doc.brief)

# Make sure that the members are there
funcs = {f_info.obj for f_info in page_info.functions}
@ -422,7 +427,8 @@ class ParserTest(googletest.TestCase):

# Make sure the brief docstring is present
self.assertEqual(
tf_inspect.getdoc(test_function).split('\n')[0], page_info.doc.brief)
six.ensure_str(tf_inspect.getdoc(test_function)).split('\n')[0],
page_info.doc.brief)

# Make sure the extracted signature is good.
self.assertEqual(['unused_arg', "unused_kwarg='default'"],
@ -461,7 +467,8 @@ class ParserTest(googletest.TestCase):

# Make sure the brief docstring is present
self.assertEqual(
tf_inspect.getdoc(test_function_with_args_kwargs).split('\n')[0],
six.ensure_str(
tf_inspect.getdoc(test_function_with_args_kwargs)).split('\n')[0],
page_info.doc.brief)

# Make sure the extracted signature is good.
@ -751,7 +758,8 @@ class TestParseFunctionDetails(googletest.TestCase):

self.assertEqual(
RELU_DOC,
docstring + ''.join(str(detail) for detail in function_details))
six.ensure_str(docstring) +
''.join(str(detail) for detail in function_details))


class TestGenerateSignature(googletest.TestCase):
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -28,6 +29,7 @@ from __future__ import division
from __future__ import print_function

import textwrap
import six


def build_md_page(page_info):
@ -83,7 +85,8 @@ def _build_class_page(page_info):
"""Given a ClassPageInfo object Return the page as an md string."""
parts = ['# {page_info.full_name}\n\n'.format(page_info=page_info)]

parts.append('## Class `%s`\n\n' % page_info.full_name.split('.')[-1])
parts.append('## Class `%s`\n\n' %
six.ensure_str(page_info.full_name).split('.')[-1])
if page_info.bases:
parts.append('Inherits From: ')

@ -222,7 +225,7 @@ def _build_module_page(page_info):
parts.append(template.format(**item._asdict()))

if item.doc.brief:
parts.append(': ' + item.doc.brief)
parts.append(': ' + six.ensure_str(item.doc.brief))

parts.append('\n\n')

@ -234,7 +237,7 @@ def _build_module_page(page_info):
parts.append(template.format(**item._asdict()))

if item.doc.brief:
parts.append(': ' + item.doc.brief)
parts.append(': ' + six.ensure_str(item.doc.brief))

parts.append('\n\n')

@ -246,7 +249,7 @@ def _build_module_page(page_info):
parts.append(template.format(**item._asdict()))

if item.doc.brief:
parts.append(': ' + item.doc.brief)
parts.append(': ' + six.ensure_str(item.doc.brief))

parts.append('\n\n')

@ -273,7 +276,7 @@ def _build_signature(obj_info, use_full_name=True):
'```\n\n')

parts = ['``` python']
parts.extend(['@' + dec for dec in obj_info.decorators])
parts.extend(['@' + six.ensure_str(dec) for dec in obj_info.decorators])
signature_template = '{name}({sig})'

if not obj_info.signature:
@ -313,7 +316,7 @@ def _build_function_details(function_details):
parts = []
for detail in function_details:
sub = []
sub.append('#### ' + detail.keyword + ':\n\n')
sub.append('#### ' + six.ensure_str(detail.keyword) + ':\n\n')
sub.append(textwrap.dedent(detail.header))
for key, value in detail.items:
sub.append('* <b>`%s`</b>: %s' % (key, value))
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,14 +22,16 @@ from __future__ import print_function

import os
import re
import six


def md_files_in_dir(py_guide_src_dir):
"""Returns a list of filename (full_path, base) pairs for guide files."""
all_in_dir = [(os.path.join(py_guide_src_dir, f), f)
for f in os.listdir(py_guide_src_dir)]
return [(full, f) for full, f in all_in_dir
if os.path.isfile(full) and f.endswith('.md')]
return [(full, f)
for full, f in all_in_dir
if os.path.isfile(full) and six.ensure_str(f).endswith('.md')]


class PyGuideParser(object):
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -20,6 +21,8 @@ from __future__ import print_function

import os

import six

from tensorflow.python.platform import test
from tensorflow.tools.docs import py_guide_parser

@ -38,7 +41,7 @@ class TestPyGuideParser(py_guide_parser.PyGuideParser):

def process_in_blockquote(self, line_number, line):
self.calls.append((line_number, 'b', line))
self.replace_line(line_number, line + ' BQ')
self.replace_line(line_number, six.ensure_str(line) + ' BQ')

def process_line(self, line_number, line):
self.calls.append((line_number, 'l', line))
@ -11,6 +11,7 @@ package(
py_binary(
name = "gen_git_source",
srcs = ["gen_git_source.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = ["@six_archive//:six"],
)
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -26,13 +27,16 @@ NOTE: this script is only used in opensource.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from builtins import bytes  # pylint: disable=redefined-builtin

import argparse
from builtins import bytes  # pylint: disable=redefined-builtin
import json
import os
import shutil
import subprocess

import six


def parse_branch_ref(filename):
"""Given a filename of a .git/HEAD file return ref path.
@ -161,10 +165,13 @@ def get_git_version(git_base_path, git_tag_override):
unknown_label = b"unknown"
try:
# Force to bytes so this works on python 2 and python 3
val = bytes(subprocess.check_output([
"git", str("--git-dir=%s/.git" % git_base_path),
str("--work-tree=" + git_base_path), "describe", "--long", "--tags"
]).strip())
val = bytes(
subprocess.check_output([
"git",
str("--git-dir=%s/.git" % git_base_path),
str("--work-tree=" + six.ensure_str(git_base_path)), "describe",
"--long", "--tags"
]).strip())
version_separator = b"-"
if git_tag_override and val:
split_val = val.split(version_separator)
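subprocess.check_output() returns bytes on Python 3, which is why the hunk above forces the result to bytes and later compares it against bytes literals such as b"-". A hedged sketch of the same normalization, assuming git is on PATH:

import subprocess
import six

out = subprocess.check_output(["git", "--version"]).strip()
# bytes on Python 3, str on Python 2; ensure_str yields native str on both.
print(six.ensure_str(out).split(" ")[-1])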
@ -3,7 +3,6 @@

load(
"//tensorflow:tensorflow.bzl",
"if_not_v2",
"if_not_windows",
"tf_cc_binary",
"tf_cc_test",
@ -163,6 +163,7 @@ genrule(
"@com_google_protobuf//:LICENSE",
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
"@six_archive//:LICENSE",
] + select({
"//tensorflow:android": [],
"//tensorflow:ios": [],
@ -235,6 +236,7 @@ genrule(
"@zlib_archive//:zlib.h",
"@grpc//:LICENSE",
"@grpc//third_party/address_sorting:LICENSE",
"@six_archive//:LICENSE",
] + select({
"//tensorflow:android": [],
"//tensorflow:ios": [],
@ -13,9 +13,12 @@ package(
],
)

licenses(["notice"])  # Apache 2.0

py_library(
name = "compat_checker",
srcs = ["compat_checker.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:platform",
"//tensorflow/python:util",
@ -29,6 +32,7 @@ py_test(
data = [
"//tensorflow/tools/tensorflow_builder/compat_checker:test_config",
],
python_version = "PY3",
tags = ["no_pip"],
deps = [
":compat_checker",
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -23,16 +24,11 @@ import re
import sys

import six
from six.moves import range
import six.moves.configparser
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect

# pylint: disable=g-import-not-at-top
if six.PY2:
import ConfigParser
else:
import configparser as ConfigParser
# pylint: enable=g-import-not-at-top

PATH_TO_DIR = "tensorflow/tools/tensorflow_builder/compat_checker"


@ -56,8 +52,8 @@ def _compare_versions(v1, v2):
raise RuntimeError("Cannot compare `inf` to `inf`.")

rtn_dict = {"smaller": None, "larger": None}
v1_list = v1.split(".")
v2_list = v2.split(".")
v1_list = six.ensure_str(v1).split(".")
v2_list = six.ensure_str(v2).split(".")
# Take care of cases with infinity (arg=`inf`).
if v1_list[0] == "inf":
v1_list[0] = str(int(v2_list[0]) + 1)
@ -380,7 +376,7 @@ class ConfigCompatChecker(object):
curr_status = True

# Initialize config parser for parsing version requirements file.
parser = ConfigParser.ConfigParser()
parser = six.moves.configparser.ConfigParser()
parser.read(self.req_file)

if not parser.sections():
@ -643,7 +639,7 @@ class ConfigCompatChecker(object):
if filtered[-1] == "]":
filtered = filtered[:-1]
elif "]" in filtered[-1]:
filtered[-1] = filtered[-1].replace("]", "")
filtered[-1] = six.ensure_str(filtered[-1]).replace("]", "")
# If `]` is missing, then it could be a formatting issue with
# config file (.ini.). Add to warning.
else:
@ -792,7 +788,7 @@ class ConfigCompatChecker(object):
Boolean that is a status of the compatibility check result.
"""
# Check if all `Required` configs are found in user configs.
usr_keys = self.usr_config.keys()
usr_keys = list(self.usr_config.keys())

for k in six.iterkeys(self.usr_config):
if k not in usr_keys:
@ -809,10 +805,10 @@ class ConfigCompatChecker(object):
for config_name, spec in six.iteritems(self.usr_config):
temp_status = True
# Check under which section the user config is defined.
in_required = config_name in self.required.keys()
in_optional = config_name in self.optional.keys()
in_unsupported = config_name in self.unsupported.keys()
in_dependency = config_name in self.dependency.keys()
in_required = config_name in list(self.required.keys())
in_optional = config_name in list(self.optional.keys())
in_unsupported = config_name in list(self.unsupported.keys())
in_dependency = config_name in list(self.dependency.keys())

# Add to warning if user config is not specified in the config file.
if not (in_required or in_optional or in_unsupported or in_dependency):
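The import hunk above replaces the hand-rolled six.PY2 branch with six.moves.configparser, which resolves to the ConfigParser module on Python 2 and configparser on Python 3. A minimal sketch of the resulting usage (the .ini filename is hypothetical):

import six.moves.configparser

parser = six.moves.configparser.ConfigParser()
parser.read("requirements.ini")  # hypothetical requirements file
print(parser.sections())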
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -88,7 +89,7 @@ class CompatCheckerTest(unittest.TestCase):
# Make sure no warning or error messages are recorded.
self.assertFalse(len(self.compat_checker.error_msg))
# Make sure total # of successes match total # of configs.
cnt = len(USER_CONFIG_IN_RANGE.keys())
cnt = len(list(USER_CONFIG_IN_RANGE.keys()))
self.assertEqual(len(self.compat_checker.successes), cnt)

def testWithUserConfigNotInRange(self):
@ -106,7 +107,7 @@ class CompatCheckerTest(unittest.TestCase):
err_msg_list = self.compat_checker.failures
self.assertTrue(len(err_msg_list))
# Make sure total # of failures match total # of configs.
cnt = len(USER_CONFIG_NOT_IN_RANGE.keys())
cnt = len(list(USER_CONFIG_NOT_IN_RANGE.keys()))
self.assertEqual(len(err_msg_list), cnt)

def testWithUserConfigMissing(self):
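The list(...) wrappers above guard against a PY3 semantic change: dict.keys() returns a list on Python 2 but a view object on Python 3, and a view supports len() but not indexing or concatenation. A small sketch of the difference:

d = {"cuda": "10.0", "cudnn": "7.4"}
keys = list(d.keys())      # a real list on both Python versions
print(len(keys), keys[0])  # indexing the raw PY3 view would raise TypeError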
@ -14,22 +14,24 @@ py_binary(
data = [
"//tensorflow/tools/tensorflow_builder/config_detector/data/golden:cuda_cc_golden",
],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":cuda_compute_capability",
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@six_archive//:six",
],
)

py_binary(
name = "cuda_compute_capability",
srcs = ["data/cuda_compute_capability.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@six_archive//:six",
],
)
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -70,6 +71,7 @@ import subprocess
import sys
from absl import app
from absl import flags
import six

from tensorflow.tools.tensorflow_builder.config_detector.data import cuda_compute_capability

@ -182,7 +184,7 @@ def get_cpu_type():
"""
key = "cpu_type"
out, err = run_shell_cmd(cmds_all[PLATFORM][key])
cpu_detected = out.split(":")[1].strip()
cpu_detected = out.split(b":")[1].strip()
if err and FLAGS.debug:
print("Error in detecting CPU type:\n %s" % str(err))

@ -201,7 +203,7 @@ def get_cpu_arch():
if err and FLAGS.debug:
print("Error in detecting CPU arch:\n %s" % str(err))

return out.strip("\n")
return out.strip(b"\n")


def get_distrib():
@ -216,7 +218,7 @@ def get_distrib():
if err and FLAGS.debug:
print("Error in detecting distribution:\n %s" % str(err))

return out.strip("\n")
return out.strip(b"\n")


def get_distrib_version():
@ -233,7 +235,7 @@ def get_distrib_version():
"Error in detecting distribution version:\n %s" % str(err)
)

return out.strip("\n")
return out.strip(b"\n")


def get_gpu_type():
@ -251,7 +253,7 @@ def get_gpu_type():
key = "gpu_type_no_sudo"
gpu_dict = cuda_compute_capability.retrieve_from_golden()
out, err = run_shell_cmd(cmds_all[PLATFORM][key])
ret_val = out.split(" ")
ret_val = out.split(b" ")
gpu_id = ret_val[0]
if err and FLAGS.debug:
print("Error in detecting GPU type:\n %s" % str(err))
@ -261,10 +263,10 @@ def get_gpu_type():
return gpu_id, GPU_TYPE
else:
if "[" or "]" in ret_val[1]:
gpu_release = ret_val[1].replace("[", "") + " "
gpu_release += ret_val[2].replace("]", "").strip("\n")
gpu_release = ret_val[1].replace(b"[", b"") + b" "
gpu_release += ret_val[2].replace(b"]", b"").strip(b"\n")
else:
gpu_release = ret_val[1].replace("\n", " ")
gpu_release = six.ensure_str(ret_val[1]).replace("\n", " ")

if gpu_release not in gpu_dict:
GPU_TYPE = "unknown"
@ -285,7 +287,7 @@ def get_gpu_count():
if err and FLAGS.debug:
print("Error in detecting GPU count:\n %s" % str(err))

return out.strip("\n")
return out.strip(b"\n")


def get_cuda_version_all():
@ -303,7 +305,7 @@ def get_cuda_version_all():
"""
key = "cuda_ver_all"
out, err = run_shell_cmd(cmds_all[PLATFORM.lower()][key])
ret_val = out.split("\n")
ret_val = out.split(b"\n")
filtered = []
for item in ret_val:
if item not in ["\n", ""]:
@ -311,9 +313,9 @@ def get_cuda_version_all():

all_vers = []
for item in filtered:
ver_re = re.search(r".*/cuda(\-[\d]+\.[\d]+)?", item)
ver_re = re.search(r".*/cuda(\-[\d]+\.[\d]+)?", item.decode("utf-8"))
if ver_re.group(1):
all_vers.append(ver_re.group(1).strip("-"))
all_vers.append(six.ensure_str(ver_re.group(1)).strip("-"))

if err and FLAGS.debug:
print("Error in detecting CUDA version:\n %s" % str(err))
@ -409,13 +411,13 @@ def get_cudnn_version():
if err and FLAGS.debug:
print("Error in finding `cudnn.h`:\n %s" % str(err))

if len(out.split(" ")) > 1:
if len(out.split(b" ")) > 1:
cmd = cmds[0] + " | " + cmds[1]
out_re, err_re = run_shell_cmd(cmd)
if err_re and FLAGS.debug:
print("Error in detecting cuDNN version:\n %s" % str(err_re))

return out_re.strip("\n")
return out_re.strip(b"\n")
else:
return

@ -432,7 +434,7 @@ def get_gcc_version():
if err and FLAGS.debug:
print("Error in detecting GCC version:\n %s" % str(err))

return out.strip("\n")
return out.strip(b"\n")


def get_glibc_version():
@ -447,7 +449,7 @@ def get_glibc_version():
if err and FLAGS.debug:
print("Error in detecting GCC version:\n %s" % str(err))

return out.strip("\n")
return out.strip(b"\n")


def get_libstdcpp_version():
@ -462,7 +464,7 @@ def get_libstdcpp_version():
if err and FLAGS.debug:
print("Error in detecting libstdc++ version:\n %s" % str(err))

ver = out.split("_")[-1].replace("\n", "")
ver = out.split(b"_")[-1].replace(b"\n", b"")
return ver


@ -485,7 +487,7 @@ def get_cpu_isa_version():
found = []
missing = []
for isa in required_isa:
for sys_isa in ret_val.split(" "):
for sys_isa in ret_val.split(b" "):
if isa == sys_isa:
if isa not in found:
found.append(isa)
@ -539,7 +541,7 @@ def get_all_configs():
json_data = {}
missing = []
warning = []
for config, call_func in all_functions.iteritems():
for config, call_func in six.iteritems(all_functions):
ret_val = call_func
if not ret_val:
configs_found.append([config, "\033[91m\033[1mMissing\033[0m"])
@ -557,10 +559,10 @@ def get_all_configs():
configs_found.append([config, ret_val[0]])
json_data[config] = ret_val[0]
else:
configs_found.append(
[config,
"\033[91m\033[1mMissing " + str(ret_val[1])[1:-1] + "\033[0m"]
)
configs_found.append([
config, "\033[91m\033[1mMissing " +
six.ensure_str(str(ret_val[1])[1:-1]) + "\033[0m"
])
missing.append(
[config,
"\n\t=> Found %s but missing %s"
@ -587,7 +589,7 @@ def print_all_configs(configs, missing, warning):
llen = 65  # line length
for i, row in enumerate(configs):
if i != 0:
print_text += "-"*llen + "\n"
print_text += six.ensure_str("-" * llen) + "\n"

if isinstance(row[1], list):
val = ", ".join(row[1])
@ -629,7 +631,7 @@ def save_to_file(json_data, filename):
print("filename: %s" % filename)
filename += ".json"

with open(PATH_TO_DIR + "/" + filename, "w") as f:
with open(PATH_TO_DIR + "/" + six.ensure_str(filename), "w") as f:
json.dump(json_data, f, sort_keys=True, indent=4)

print(" Successfully wrote configs to file `%s`.\n" % (filename))
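Most changes above swap str literals for bytes literals because the run_shell_cmd() output arrives as bytes under Python 3, and mixing bytes with str in split(), strip(), or replace() raises TypeError there. A small sketch of the failure mode, using made-up command output:

out = b"cpu family\t: GenuineIntel\n"  # bytes, as subprocess returns on PY3
cpu = out.split(b":")[1].strip()       # bytes pattern on bytes data works
print(cpu)                             # b'GenuineIntel'
# out.split(":") would raise on Python 3:
# TypeError: a bytes-like object is required, not 'str'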
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -42,6 +43,7 @@ import re
from absl import app
from absl import flags

import six
import six.moves.urllib.request as urllib

FLAGS = flags.FLAGS
@ -61,21 +63,18 @@ def retrieve_from_web(generate_csv=False):
NVIDIA page. Order goes from top to bottom of the webpage content (.html).
"""
url = "https://developer.nvidia.com/cuda-gpus"
source = urllib.urlopen(url)
source = urllib.request.urlopen(url)
matches = []
while True:
line = source.readline()
if "</html>" in line:
break
else:
gpu = re.search(
r"<a href=.*>([\w\S\s\d\[\]\,]+[^*])</a>(<a href=.*)?.*",
line
)
gpu = re.search(r"<a href=.*>([\w\S\s\d\[\]\,]+[^*])</a>(<a href=.*)?.*",
six.ensure_str(line))
capability = re.search(
r"([\d]+).([\d]+)(/)?([\d]+)?(.)?([\d]+)?.*</td>.*",
line
)
six.ensure_str(line))
if gpu:
matches.append(gpu.group(1))
elif capability:
@ -155,15 +154,15 @@ def create_gpu_capa_map(match_list,

gpu = ""
cnt += 1
if len(gpu_capa.keys()) < cnt:
if len(list(gpu_capa.keys())) < cnt:
mismatch_cnt += 1
cnt = len(gpu_capa.keys())
cnt = len(list(gpu_capa.keys()))

else:
gpu = match

if generate_csv:
f_name = filename + ".csv"
f_name = six.ensure_str(filename) + ".csv"
write_csv_from_dict(f_name, gpu_capa)

return gpu_capa
@ -179,8 +178,8 @@ def write_csv_from_dict(filename, input_dict):
filename: String that is the output file name.
input_dict: Dictionary that is to be written out to a `.csv` file.
"""
f = open(PATH_TO_DIR + "/data/" + filename, "w")
for k, v in input_dict.iteritems():
f = open(PATH_TO_DIR + "/data/" + six.ensure_str(filename), "w")
for k, v in six.iteritems(input_dict):
line = k
for item in v:
line += "," + item
@ -203,7 +202,7 @@ def check_with_golden(filename):
Args:
filename: String that is the name of the newly created file.
"""
path_to_file = PATH_TO_DIR + "/data/" + filename
path_to_file = PATH_TO_DIR + "/data/" + six.ensure_str(filename)
if os.path.isfile(path_to_file) and os.path.isfile(CUDA_CC_GOLDEN_DIR):
with open(path_to_file, "r") as f_new:
with open(CUDA_CC_GOLDEN_DIR, "r") as f_golden:
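Two portable idioms recur in the file above: six.moves.urllib papers over the urllib/urllib2 split, and six.iteritems() replaces the PY2-only dict.iteritems(). A minimal sketch of the iteritems() half, with made-up data:

import six

gpu_capa = {"Tesla V100": ["7.0"], "Tesla K80": ["3.7"]}
# dict.iteritems() is gone on Python 3; six.iteritems() dispatches to
# iteritems() on PY2 and items() on PY3.
for gpu, capas in six.iteritems(gpu_capa):
  print(gpu, ",".join(capas))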
@ -22,17 +22,19 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:errors",
"//tensorflow/python:platform",
"@six_archive//:six",
],
)

py_binary(
name = "system_info",
srcs = ["system_info.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":system_info_lib",
@ -50,16 +52,20 @@ py_library(
":system_info_lib",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
"@six_archive//:six",
],
)

py_binary(
name = "run_and_gather_logs",
srcs = ["run_and_gather_logs.py"],
python_version = "PY2",
python_version = "PY3",
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [":run_and_gather_logs_main_lib"],
deps = [
":run_and_gather_logs_main_lib",
"@six_archive//:six",
],
)

py_library(
@ -72,6 +78,7 @@ py_library(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
"@six_archive//:six",
],
)
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -21,6 +22,9 @@ from __future__ import print_function
import ctypes as ct
import platform

import six
from six.moves import range

from tensorflow.core.util import test_log_pb2
from tensorflow.python.framework import errors
from tensorflow.python.platform import gfile
@ -30,10 +34,11 @@ def _gather_gpu_devices_proc():
"""Try to gather NVidia GPU device information via /proc/driver."""
dev_info = []
for f in gfile.Glob("/proc/driver/nvidia/gpus/*/information"):
bus_id = f.split("/")[5]
key_values = dict(line.rstrip().replace("\t", "").split(":", 1)
for line in gfile.GFile(f, "r"))
key_values = dict((k.lower(), v.strip(" ").rstrip(" "))
bus_id = six.ensure_str(f).split("/")[5]
key_values = dict(
six.ensure_str(line.rstrip()).replace("\t", "").split(":", 1)
for line in gfile.GFile(f, "r"))
key_values = dict((k.lower(), six.ensure_str(v).strip(" ").rstrip(" "))
for (k, v) in key_values.items())
info = test_log_pb2.GPUInfo()
info.model = key_values.get("model", "Unknown")
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -25,9 +26,10 @@ from string import maketrans
import sys
import time

import six

from google.protobuf import json_format
from google.protobuf import text_format

from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
@ -83,8 +85,9 @@ def main(unused_args):
if FLAGS.test_log_output_filename:
file_name = FLAGS.test_log_output_filename
else:
file_name = (name.strip("/").translate(maketrans("/:", "__")) +
time.strftime("%Y%m%d%H%M%S", time.gmtime()))
file_name = (
six.ensure_str(name).strip("/").translate(maketrans("/:", "__")) +
time.strftime("%Y%m%d%H%M%S", time.gmtime()))
if FLAGS.test_log_output_use_tmpdir:
tmpdir = test.get_temp_dir()
output_path = os.path.join(tmpdir, FLAGS.test_log_output_dir, file_name)
@ -92,7 +95,8 @@ def main(unused_args):
output_path = os.path.join(
os.path.abspath(FLAGS.test_log_output_dir), file_name)
json_test_results = json_format.MessageToJson(test_results)
gfile.GFile(output_path + ".json", "w").write(json_test_results)
gfile.GFile(six.ensure_str(output_path) + ".json",
"w").write(json_test_results)
tf_logging.info("Test results written to: %s" % output_path)
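One caveat in the hunk above: the file still relies on `from string import maketrans`, which exists only on Python 2; a fully PY3-native spelling would use the str.maketrans builtin instead. A hedged sketch of that equivalent (Python 3 only, with a made-up target name):

name = "//tensorflow/tools:test"
# str.maketrans builds the same '/'->'_' and ':'->'_' mapping on Python 3.
print(name.strip("/").translate(str.maketrans("/:", "__")))
# -> tensorflow_tools_test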
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -25,6 +26,8 @@ import subprocess
import tempfile
import time

import six

from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import gfile
from tensorflow.tools.test import gpu_info_lib
@ -118,12 +121,15 @@ def run_and_gather_logs(name, test_name, test_args,
IOError: If there are problems gathering test log output from the test.
MissingLogsError: If we couldn't find benchmark logs.
"""
if not (test_name and test_name.startswith("//") and ".." not in test_name and
not test_name.endswith(":") and not test_name.endswith(":all") and
not test_name.endswith("...") and len(test_name.split(":")) == 2):
if not (test_name and six.ensure_str(test_name).startswith("//") and
".." not in test_name and not six.ensure_str(test_name).endswith(":")
and not six.ensure_str(test_name).endswith(":all") and
not six.ensure_str(test_name).endswith("...") and
len(six.ensure_str(test_name).split(":")) == 2):
raise ValueError("Expected test_name parameter with a unique test, e.g.: "
"--test_name=//path/to:test")
test_executable = test_name.rstrip().strip("/").replace(":", "/")
test_executable = six.ensure_str(test_name.rstrip()).strip("/").replace(
":", "/")

if gfile.Exists(os.path.join("bazel-bin", test_executable)):
# Running in standalone mode from core of the repository
@ -136,14 +142,17 @@ def run_and_gather_logs(name, test_name, test_args,
gpu_config = gpu_info_lib.gather_gpu_devices()
if gpu_config:
gpu_name = gpu_config[0].model
gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", gpu_name)
gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)",
six.ensure_str(gpu_name))
if gpu_short_name_match:
gpu_short_name = gpu_short_name_match.group(0)
test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_")
test_adjusted_name = six.ensure_str(name) + "|" + gpu_short_name.replace(
" ", "_")

temp_directory = tempfile.mkdtemp(prefix="run_and_gather_logs")
mangled_test_name = (test_adjusted_name.strip("/")
.replace("|", "_").replace("/", "_").replace(":", "_"))
mangled_test_name = (
six.ensure_str(test_adjusted_name).strip("/").replace("|", "_").replace(
"/", "_").replace(":", "_"))
test_file_prefix = os.path.join(temp_directory, mangled_test_name)
test_file_prefix = "%s." % test_file_prefix
@ -1,3 +1,4 @@
# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -29,6 +30,8 @@ import socket
# OSS tree. They are installable via pip.
import cpuinfo
import psutil

import six
# pylint: enable=g-bad-import-order

from tensorflow.core.util import test_log_pb2
@ -81,7 +84,8 @@ def gather_cpu_info():
# Gather num_cores_allowed
try:
with gfile.GFile('/proc/self/status', 'rb') as fh:
nc = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', fh.read().decode('utf-8'))
nc = re.search(r'(?m)^Cpus_allowed:\s*(.*)$',
six.ensure_text(fh.read(), 'utf-8'))
if nc:  # e.g. 'ff' => 8, 'fff' => 12
cpu_info.num_cores_allowed = (
bin(int(nc.group(1).replace(',', ''), 16)).count('1'))
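The last hunk trades fh.read().decode('utf-8') for six.ensure_text(), which returns text on both versions (unicode on PY2, str on PY3) and tolerates input that is already text. A minimal sketch of the call on bytes read from a file opened in 'rb' mode:

import re
import six

raw = b"Cpus_allowed:\tff\n"  # bytes, as read in binary mode
nc = re.search(r"(?m)^Cpus_allowed:\s*(.*)$", six.ensure_text(raw, "utf-8"))
print(bin(int(nc.group(1), 16)).count("1"))  # 'ff' => 8 allowed cores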