diff --git a/tensorflow/tools/api/lib/BUILD b/tensorflow/tools/api/lib/BUILD index 75cb9338d4a..ab48e6970b2 100644 --- a/tensorflow/tools/api/lib/BUILD +++ b/tensorflow/tools/api/lib/BUILD @@ -23,5 +23,6 @@ py_library( ":api_objects_proto_py", "//tensorflow/python:platform", "//tensorflow/python:util", + "@six_archive//:six", ], ) diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index 549a11588bb..283f53882c3 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,8 +20,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys import enum +import sys + +import six + from google.protobuf import message from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation @@ -191,7 +195,8 @@ class PythonObjectToProtoVisitor(object): if (_SkipMember(parent, member_name) or isinstance(member_obj, deprecation.HiddenTfApiAttribute)): return - if member_name == '__init__' or not member_name.startswith('_'): + if member_name == '__init__' or not six.ensure_str( + member_name).startswith('_'): if tf_inspect.isroutine(member_obj): new_method = proto.member_method.add() new_method.name = member_name diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index dc21877f060..5cc9fa67e6d 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -14,13 +14,16 @@ py_library( name = "public_api", srcs = ["public_api.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:util"], + deps = [ + "//tensorflow/python:util", + "@six_archive//:six", + ], ) py_test( name = "public_api_test", srcs = ["public_api_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":public_api", @@ -32,13 +35,16 @@ py_library( name = "traverse", srcs = ["traverse.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:util"], + deps = [ + "//tensorflow/python:util", + "@six_archive//:six", + ], ) py_test( name = "traverse_test", srcs = ["traverse_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":test_module1", diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py index f62f9bb54f1..0f788290b26 100644 --- a/tensorflow/tools/common/public_api.py +++ b/tensorflow/tools/common/public_api.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +21,8 @@ from __future__ import print_function import re +import six + from tensorflow.python.util import tf_inspect @@ -108,9 +111,9 @@ class PublicAPIVisitor(object): """Return whether a name is private.""" # TODO(wicke): Find out what names to exclude. del obj # Unused. - return ((path in self._private_map and - name in self._private_map[path]) or - (name.startswith('_') and not re.match('__.*__$', name) or + return ((path in self._private_map and name in self._private_map[path]) or + (six.ensure_str(name).startswith('_') and + not re.match('__.*__$', six.ensure_str(name)) or name in ['__base__', '__class__'])) def _do_not_descend(self, path, name): @@ -122,7 +125,8 @@ class PublicAPIVisitor(object): """Visitor interface, see `traverse` for details.""" # Avoid long waits in cases of pretty unambiguous failure. - if tf_inspect.ismodule(parent) and len(path.split('.')) > 10: + if tf_inspect.ismodule(parent) and len( + six.ensure_str(path).split('.')) > 10: raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a ' 'problem with an accidental public import.' % (self._root_name, path)) diff --git a/tensorflow/tools/common/traverse.py b/tensorflow/tools/common/traverse.py index b121a87062f..5efce450dcb 100644 --- a/tensorflow/tools/common/traverse.py +++ b/tensorflow/tools/common/traverse.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +22,8 @@ from __future__ import print_function import enum import sys +import six + from tensorflow.python.util import tf_inspect __all__ = ['traverse'] @@ -59,7 +62,8 @@ def _traverse_internal(root, visit, stack, path): if any(child is item for item in new_stack): # `in`, but using `is` continue - child_path = path + '.' + name if path else name + child_path = six.ensure_str(path) + '.' + six.ensure_str( + name) if path else name _traverse_internal(child, visit, new_stack, child_path) diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD index 5a50d77b010..d6d882b5749 100644 --- a/tensorflow/tools/compatibility/BUILD +++ b/tensorflow/tools/compatibility/BUILD @@ -14,6 +14,7 @@ py_library( name = "ipynb", srcs = ["ipynb.py"], srcs_version = "PY2AND3", + deps = ["@six_archive//:six"], ) py_library( @@ -29,7 +30,7 @@ py_library( py_test( name = "ast_edits_test", srcs = ["ast_edits_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":ast_edits", @@ -42,22 +43,28 @@ py_test( py_binary( name = "tf_upgrade", srcs = ["tf_upgrade.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", - deps = [":tf_upgrade_lib"], + deps = [ + ":tf_upgrade_lib", + "@six_archive//:six", + ], ) py_library( name = "tf_upgrade_lib", srcs = ["tf_upgrade.py"], srcs_version = "PY2AND3", - deps = [":ast_edits"], + deps = [ + ":ast_edits", + "@six_archive//:six", + ], ) py_test( name = "tf_upgrade_test", srcs = ["tf_upgrade_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", tags = [ "no_pip", @@ -96,6 +103,7 @@ py_library( py_test( name = "all_renames_v2_test", srcs = ["all_renames_v2_test.py"], + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":all_renames_v2", @@ -108,6 +116,7 @@ py_test( py_library( name = "module_deprecations_v2", srcs = ["module_deprecations_v2.py"], + srcs_version = "PY2AND3", deps = [":ast_edits"], ) @@ -145,13 +154,14 @@ py_binary( ":ipynb", ":tf_upgrade_v2_lib", ":tf_upgrade_v2_safety_lib", + "@six_archive//:six", ], ) py_test( name = "tf_upgrade_v2_test", srcs = ["tf_upgrade_v2_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", tags = ["v1only"], deps = [ @@ -169,6 +179,7 @@ py_test( py_test( name = "tf_upgrade_v2_safety_test", srcs = ["tf_upgrade_v2_safety_test.py"], + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":tf_upgrade_v2_safety_lib", @@ -208,7 +219,7 @@ py_test( name = "test_file_v1_0", size = "small", srcs = ["test_file_v1_0.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", @@ -235,7 +246,7 @@ py_test( name = "test_file_v1_12", size = "small", srcs = ["testdata/test_file_v1_12.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", tags = ["v1only"], deps = [ @@ -247,7 +258,7 @@ py_test( name = "test_file_v2_0", size = "small", srcs = ["test_file_v2_0.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py index d6129da71dd..71fb2aee770 100644 --- a/tensorflow/tools/compatibility/ast_edits.py +++ b/tensorflow/tools/compatibility/ast_edits.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,6 +30,7 @@ import traceback import pasta import six +from six.moves import range # Some regular expressions we will need for parsing FIND_OPEN = re.compile(r"^\s*(\[).*$") @@ -56,7 +58,7 @@ def full_name_node(name, ctx=ast.Load()): Returns: A Name or Attribute node. """ - names = name.split(".") + names = six.ensure_str(name).split(".") names.reverse() node = ast.Name(id=names.pop(), ctx=ast.Load()) while names: @@ -301,7 +303,7 @@ class _PastaEditVisitor(ast.NodeVisitor): function_transformers = getattr(self._api_change_spec, transformer_field, {}) - glob_name = "*." + name if name else None + glob_name = "*." + six.ensure_str(name) if name else None transformers = [] if full_name in function_transformers: transformers.append(function_transformers[full_name]) @@ -318,7 +320,7 @@ class _PastaEditVisitor(ast.NodeVisitor): function_transformers = getattr(self._api_change_spec, transformer_field, {}) - glob_name = "*." + name if name else None + glob_name = "*." + six.ensure_str(name) if name else None transformers = function_transformers.get("*", {}).copy() transformers.update(function_transformers.get(glob_name, {})) transformers.update(function_transformers.get(full_name, {})) @@ -351,7 +353,7 @@ class _PastaEditVisitor(ast.NodeVisitor): function_warnings = self._api_change_spec.function_warnings if full_name in function_warnings: level, message = function_warnings[full_name] - message = message.replace("", full_name) + message = six.ensure_str(message).replace("", full_name) self.add_log(level, node.lineno, node.col_offset, "%s requires manual check. %s" % (full_name, message)) return True @@ -363,7 +365,8 @@ class _PastaEditVisitor(ast.NodeVisitor): warnings = self._api_change_spec.module_deprecations if full_name in warnings: level, message = warnings[full_name] - message = message.replace("", whole_name) + message = six.ensure_str(message).replace("", + six.ensure_str(whole_name)) self.add_log(level, node.lineno, node.col_offset, "Using member %s in deprecated module %s. %s" % (whole_name, full_name, @@ -394,7 +397,7 @@ class _PastaEditVisitor(ast.NodeVisitor): # an attribute. warned = False if isinstance(node.func, ast.Attribute): - warned = self._maybe_add_warning(node, "*." + name) + warned = self._maybe_add_warning(node, "*." + six.ensure_str(name)) # All arg warnings are handled here, since only we have the args arg_warnings = self._get_applicable_dict("function_arg_warnings", @@ -406,7 +409,8 @@ class _PastaEditVisitor(ast.NodeVisitor): present, _ = get_arg_value(node, kwarg, arg) or variadic_args if present: warned = True - warning_message = warning.replace("", full_name or name) + warning_message = six.ensure_str(warning).replace( + "", six.ensure_str(full_name or name)) template = "%s called with %s argument, requires manual check: %s" if variadic_args: template = ("%s called with *args or **kwargs that may include %s, " @@ -625,13 +629,13 @@ class _PastaEditVisitor(ast.NodeVisitor): # This loop processes imports in the format # import foo as f, bar as b for import_alias in node.names: - all_import_components = import_alias.name.split(".") + all_import_components = six.ensure_str(import_alias.name).split(".") # Look for rename, starting with longest import levels. found_update = False - for i in reversed(range(1, max_submodule_depth + 1)): + for i in reversed(list(range(1, max_submodule_depth + 1))): import_component = all_import_components[0] for j in range(1, min(i, len(all_import_components))): - import_component += "." + all_import_components[j] + import_component += "." + six.ensure_str(all_import_components[j]) import_rename_spec = import_renames.get(import_component, None) if not import_rename_spec or excluded_from_module_rename( @@ -674,7 +678,8 @@ class _PastaEditVisitor(ast.NodeVisitor): if old_suffix is None: old_suffix = os.linesep if os.linesep not in old_suffix: - pasta.base.formatting.set(node, "suffix", old_suffix + os.linesep) + pasta.base.formatting.set(node, "suffix", + six.ensure_str(old_suffix) + os.linesep) # Apply indentation to new node. pasta.base.formatting.set(new_line_node, "prefix", @@ -720,7 +725,7 @@ class _PastaEditVisitor(ast.NodeVisitor): # Look for rename based on first component of from-import. # i.e. based on foo in foo.bar. - from_import_first_component = from_import.split(".")[0] + from_import_first_component = six.ensure_str(from_import).split(".")[0] import_renames = getattr(self._api_change_spec, "import_renames", {}) import_rename_spec = import_renames.get(from_import_first_component, None) if not import_rename_spec: @@ -918,7 +923,7 @@ class ASTCodeUpgrader(object): def format_log(self, log, in_filename): log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3]) if in_filename: - return in_filename + ":" + log_string + return six.ensure_str(in_filename) + ":" + log_string else: return log_string @@ -945,12 +950,12 @@ class ASTCodeUpgrader(object): return 1, pasta.dump(t), logs, errors def _format_log(self, log, in_filename, out_filename): - text = "-" * 80 + "\n" + text = six.ensure_str("-" * 80) + "\n" text += "Processing file %r\n outputting to %r\n" % (in_filename, out_filename) - text += "-" * 80 + "\n\n" + text += six.ensure_str("-" * 80) + "\n\n" text += "\n".join(log) + "\n" - text += "-" * 80 + "\n\n" + text += six.ensure_str("-" * 80) + "\n\n" return text def process_opened_file(self, in_filename, in_file, out_filename, out_file): @@ -1017,8 +1022,10 @@ class ASTCodeUpgrader(object): files_to_process = [] files_to_copy = [] for dir_name, _, file_list in os.walk(root_directory): - py_files = [f for f in file_list if f.endswith(".py")] - copy_files = [f for f in file_list if not f.endswith(".py")] + py_files = [f for f in file_list if six.ensure_str(f).endswith(".py")] + copy_files = [ + f for f in file_list if not six.ensure_str(f).endswith(".py") + ] for filename in py_files: fullpath = os.path.join(dir_name, filename) fullpath_output = os.path.join(output_root_directory, @@ -1036,9 +1043,9 @@ class ASTCodeUpgrader(object): file_count = 0 tree_errors = {} report = "" - report += ("=" * 80) + "\n" + report += six.ensure_str(("=" * 80)) + "\n" report += "Input tree: %r\n" % root_directory - report += ("=" * 80) + "\n" + report += six.ensure_str(("=" * 80)) + "\n" for input_path, output_path in files_to_process: output_directory = os.path.dirname(output_path) @@ -1074,16 +1081,19 @@ class ASTCodeUpgrader(object): """Process a directory of python files in place.""" files_to_process = [] for dir_name, _, file_list in os.walk(root_directory): - py_files = [os.path.join(dir_name, - f) for f in file_list if f.endswith(".py")] + py_files = [ + os.path.join(dir_name, f) + for f in file_list + if six.ensure_str(f).endswith(".py") + ] files_to_process += py_files file_count = 0 tree_errors = {} report = "" - report += ("=" * 80) + "\n" + report += six.ensure_str(("=" * 80)) + "\n" report += "Input tree: %r\n" % root_directory - report += ("=" * 80) + "\n" + report += six.ensure_str(("=" * 80)) + "\n" for path in files_to_process: if os.path.islink(path): diff --git a/tensorflow/tools/compatibility/ipynb.py b/tensorflow/tools/compatibility/ipynb.py index fed5f0f2bfc..77c8fedf709 100644 --- a/tensorflow/tools/compatibility/ipynb.py +++ b/tensorflow/tools/compatibility/ipynb.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,6 +25,7 @@ import json import re import shutil import tempfile +import six CodeLine = collections.namedtuple("CodeLine", ["cell_number", "code"]) @@ -31,7 +33,8 @@ def is_python(cell): """Checks if the cell consists of Python code.""" return (cell["cell_type"] == "code" # code cells only and cell["source"] # non-empty cells - and not cell["source"][0].startswith("%%")) # multiline eg: %%bash + and not six.ensure_str(cell["source"][0]).startswith("%%") + ) # multiline eg: %%bash def process_file(in_filename, out_filename, upgrader): @@ -47,8 +50,9 @@ def process_file(in_filename, out_filename, upgrader): upgrader.update_string_pasta("\n".join(raw_lines), in_filename)) if temp_file and processed_file: - new_notebook = _update_notebook(notebook, raw_code, - new_file_content.split("\n")) + new_notebook = _update_notebook( + notebook, raw_code, + six.ensure_str(new_file_content).split("\n")) json.dump(new_notebook, temp_file) else: raise SyntaxError( @@ -78,7 +82,7 @@ def skip_magic(code_line, magic_list): """ for magic in magic_list: - if code_line.startswith(magic): + if six.ensure_str(code_line).startswith(magic): return True return False @@ -120,7 +124,7 @@ def _get_code(input_file): # Idea is to comment these lines, for upgrade time if skip_magic(code_line, ["%", "!", "?"]) or is_line_split: # Found a special character, need to "encode" - code_line = "###!!!" + code_line + code_line = "###!!!" + six.ensure_str(code_line) # if this cell ends with `\` -> skip the next line is_line_split = check_line_split(code_line) @@ -131,14 +135,16 @@ def _get_code(input_file): # Sometimes, people leave \n at the end of cell # in order to migrate only related things, and make the diff # the smallest -> here is another hack - if (line_idx == len(cell_lines) - 1) and code_line.endswith("\n"): - code_line = code_line.replace("\n", "###===") + if (line_idx == len(cell_lines) - + 1) and six.ensure_str(code_line).endswith("\n"): + code_line = six.ensure_str(code_line).replace("\n", "###===") # sometimes a line would start with `\n` and content after # that's the hack for this raw_code.append( CodeLine(cell_index, - code_line.rstrip().replace("\n", "###==="))) + six.ensure_str(code_line.rstrip()).replace("\n", + "###==="))) cell_index += 1 diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index 5dd548c8214..988d11b1016 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +21,8 @@ from __future__ import print_function import argparse +import six + from tensorflow.tools.compatibility import ast_edits @@ -245,7 +248,7 @@ Simple usage: else: parser.print_help() if report_text: - open(report_filename, "w").write(report_text) + open(report_filename, "w").write(six.ensure_str(report_text)) print("TensorFlow 1.0 Upgrade Script") print("-----------------------------") print("Converted %d files\n" % files_processed) diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py index cf05575a9dd..ca0e80564ff 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_test.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -46,14 +47,16 @@ class TestUpgrade(test_util.TensorFlowTestCase): def testParseError(self): _, report, unused_errors, unused_new_text = self._upgrade( "import tensorflow as tf\na + \n") - self.assertTrue(report.find("Failed to parse") != -1) + self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1) def testReport(self): text = "tf.mul(a, b)\n" _, report, unused_errors, unused_new_text = self._upgrade(text) # This is not a complete test, but it is a sanity test that a report # is generating information. - self.assertTrue(report.find("Renamed function `tf.mul` to `tf.multiply`")) + self.assertTrue( + six.ensure_str(report).find( + "Renamed function `tf.mul` to `tf.multiply`")) def testRename(self): text = "tf.mul(a, tf.sub(b, c))\n" diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 221353d87cd..2bd5bb984c8 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,6 +25,7 @@ import functools import sys import pasta +import six from tensorflow.tools.compatibility import all_renames_v2 from tensorflow.tools.compatibility import ast_edits @@ -47,8 +49,9 @@ class VersionedTFImport(ast_edits.AnalysisResult): def __init__(self, version): self.log_level = ast_edits.INFO - self.log_message = ("Not upgrading symbols because `tensorflow." + version - + "` was directly imported as `tf`.") + self.log_message = ("Not upgrading symbols because `tensorflow." + + six.ensure_str(version) + + "` was directly imported as `tf`.") class TFAPIImportAnalysisSpec(ast_edits.APIAnalysisSpec): @@ -1687,7 +1690,7 @@ def _rename_if_arg_found_transformer(parent, node, full_name, name, logs, # All conditions met, insert v1 and log what we did. # We must have a full name, so the func is an attribute. - new_name = full_name.replace("tf.", "tf.compat.v1.", 1) + new_name = six.ensure_str(full_name).replace("tf.", "tf.compat.v1.", 1) node.func = ast_edits.full_name_node(new_name) logs.append(( ast_edits.INFO, node.lineno, node.col_offset, @@ -1715,8 +1718,8 @@ def _iterator_transformer(parent, node, full_name, name, logs): # (tf.compat.v1.data), or something which is handled in the rename # (tf.data). This transformer only handles the method call to function call # conversion. - if full_name and (full_name.startswith("tf.compat.v1.data") or - full_name.startswith("tf.data")): + if full_name and (six.ensure_str(full_name).startswith("tf.compat.v1.data") or + six.ensure_str(full_name).startswith("tf.data")): return # This should never happen, since we're only called for Attribute nodes. @@ -2460,7 +2463,7 @@ def _name_scope_transformer(parent, node, full_name, name, logs): def _rename_to_compat_v1(node, full_name, logs, reason): - new_name = full_name.replace("tf.", "tf.compat.v1.", 1) + new_name = six.ensure_str(full_name).replace("tf.", "tf.compat.v1.", 1) return _rename_func(node, full_name, new_name, logs, reason) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py index 7288e171d72..e41577926a5 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,10 +21,13 @@ from __future__ import print_function import argparse +import six + from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import ipynb from tensorflow.tools.compatibility import tf_upgrade_v2 from tensorflow.tools.compatibility import tf_upgrade_v2_safety + # Make straightforward changes to convert to 2.0. In harder cases, # use compat.v1. _DEFAULT_MODE = "DEFAULT" @@ -35,10 +39,10 @@ _SAFETY_MODE = "SAFETY" def process_file(in_filename, out_filename, upgrader): """Process a file of type `.py` or `.ipynb`.""" - if in_filename.endswith(".py"): + if six.ensure_str(in_filename).endswith(".py"): files_processed, report_text, errors = \ upgrader.process_file(in_filename, out_filename) - elif in_filename.endswith(".ipynb"): + elif six.ensure_str(in_filename).endswith(".ipynb"): files_processed, report_text, errors = \ ipynb.process_file(in_filename, out_filename, upgrader) else: @@ -157,24 +161,24 @@ Simple usage: for f in errors: if errors[f]: num_errors += len(errors[f]) - report.append("-" * 80 + "\n") + report.append(six.ensure_str("-" * 80) + "\n") report.append("File: %s\n" % f) - report.append("-" * 80 + "\n") + report.append(six.ensure_str("-" * 80) + "\n") report.append("\n".join(errors[f]) + "\n") report = ("TensorFlow 2.0 Upgrade Script\n" "-----------------------------\n" "Converted %d files\n" % files_processed + "Detected %d issues that require attention" % num_errors + "\n" + - "-" * 80 + "\n") + "".join(report) - detailed_report_header = "=" * 80 + "\n" + six.ensure_str("-" * 80) + "\n") + "".join(report) + detailed_report_header = six.ensure_str("=" * 80) + "\n" detailed_report_header += "Detailed log follows:\n\n" - detailed_report_header += "=" * 80 + "\n" + detailed_report_header += six.ensure_str("=" * 80) + "\n" with open(report_filename, "w") as report_file: report_file.write(report) report_file.write(detailed_report_header) - report_file.write(report_text) + report_file.write(six.ensure_str(report_text)) if args.print_all: print(report) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 1249ce7cc2f..ca915109659 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -39,7 +40,7 @@ from tensorflow.tools.compatibility import tf_upgrade_v2 def get_symbol_for_name(root, name): - name_parts = name.split(".") + name_parts = six.ensure_str(name).split(".") symbol = root # Iterate starting with second item since 1st item is "tf.". for part in name_parts[1:]: @@ -66,12 +67,13 @@ def get_func_and_args_from_str(call_str): Returns: (function_name, list of arg names) tuple. """ - open_paren_index = call_str.find("(") + open_paren_index = six.ensure_str(call_str).find("(") close_paren_index = call_str.rfind(")") - function_name = call_str[:call_str.find("(")] - args = call_str[open_paren_index+1:close_paren_index].split(",") - args = [arg.split("=")[0].strip() for arg in args] + function_name = call_str[:six.ensure_str(call_str).find("(")] + args = six.ensure_str(call_str[open_paren_index + + 1:close_paren_index]).split(",") + args = [six.ensure_str(arg).split("=")[0].strip() for arg in args] args = [arg for arg in args if arg] # filter out empty strings return function_name, args @@ -96,7 +98,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): _, attr = tf_decorator.unwrap(child[1]) api_names_v2 = tf_export.get_v2_names(attr) for name in api_names_v2: - cls.v2_symbols["tf." + name] = attr + cls.v2_symbols["tf." + six.ensure_str(name)] = attr visitor = public_api.PublicAPIVisitor(symbol_collector) visitor.private_map["tf.compat"] = ["v1"] @@ -109,7 +111,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): _, attr = tf_decorator.unwrap(child[1]) api_names_v1 = tf_export.get_v1_names(attr) for name in api_names_v1: - cls.v1_symbols["tf." + name] = attr + cls.v1_symbols["tf." + six.ensure_str(name)] = attr visitor = public_api.PublicAPIVisitor(symbol_collector_v1) traverse.traverse(tf.compat.v1, visitor) @@ -138,15 +140,16 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): def testParseError(self): _, report, unused_errors, unused_new_text = self._upgrade( "import tensorflow as tf\na + \n") - self.assertTrue(report.find("Failed to parse") != -1) + self.assertNotEqual(six.ensure_str(report).find("Failed to parse"), -1) def testReport(self): text = "tf.angle(a)\n" _, report, unused_errors, unused_new_text = self._upgrade(text) # This is not a complete test, but it is a sanity test that a report # is generating information. - self.assertTrue(report.find("Renamed function `tf.angle` to " - "`tf.math.angle`")) + self.assertTrue( + six.ensure_str(report).find("Renamed function `tf.angle` to " + "`tf.math.angle`")) def testRename(self): text = "tf.conj(a)\n" @@ -169,7 +172,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): _, attr = tf_decorator.unwrap(child[1]) api_names = tf_export.get_v1_names(attr) for name in api_names: - _, _, _, text = self._upgrade("tf." + name) + _, _, _, text = self._upgrade("tf." + six.ensure_str(name)) if (text and not text.startswith("tf.compat.v1") and not text.startswith("tf.compat.v2") and @@ -198,9 +201,9 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): api_names = tf_export.get_v1_names(attr) for name in api_names: if collect: - v1_symbols.add("tf." + name) + v1_symbols.add("tf." + six.ensure_str(name)) else: - _, _, _, text = self._upgrade("tf." + name) + _, _, _, text = self._upgrade("tf." + six.ensure_str(name)) if (text and not text.startswith("tf.compat.v1") and not text.startswith("tf.compat.v2") and @@ -337,16 +340,16 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): def testPositionsMatchArgGiven(self): full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings - method_names = full_dict.keys() + method_names = list(full_dict.keys()) for method_name in method_names: - args = full_dict[method_name].keys() + args = list(full_dict[method_name].keys()) if "contrib" in method_name: # Skip descending and fetching contrib methods during test. These are # not available in the repo anymore. continue - elif method_name.startswith("*."): + elif six.ensure_str(method_name).startswith("*."): # special case for optimizer methods - method = method_name.replace("*", "tf.train.Optimizer") + method = six.ensure_str(method_name).replace("*", "tf.train.Optimizer") else: method = method_name @@ -354,7 +357,7 @@ class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase): arg_spec = tf_inspect.getfullargspec(method) for (arg, pos) in args: # to deal with the self argument on methods on objects - if method_name.startswith("*."): + if six.ensure_str(method_name).startswith("*."): pos += 1 self.assertEqual(arg_spec[0][pos], arg) diff --git a/tensorflow/tools/compatibility/update/BUILD b/tensorflow/tools/compatibility/update/BUILD index 5a74271e882..5f40406c689 100644 --- a/tensorflow/tools/compatibility/update/BUILD +++ b/tensorflow/tools/compatibility/update/BUILD @@ -6,7 +6,7 @@ package( py_binary( name = "generate_v2_renames_map", srcs = ["generate_v2_renames_map.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", @@ -15,13 +15,14 @@ py_binary( "//tensorflow/tools/common:public_api", "//tensorflow/tools/common:traverse", "//tensorflow/tools/compatibility:all_renames_v2", + "@six_archive//:six", ], ) py_binary( name = "generate_v2_reorders_map", srcs = ["generate_v2_reorders_map.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", diff --git a/tensorflow/tools/compatibility/update/generate_v2_renames_map.py b/tensorflow/tools/compatibility/update/generate_v2_renames_map.py index 6761fa6ae3d..6cdc4972bd5 100644 --- a/tensorflow/tools/compatibility/update/generate_v2_renames_map.py +++ b/tensorflow/tools/compatibility/update/generate_v2_renames_map.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,9 +24,9 @@ To update renames_v2.py, run: # pylint: enable=line-too-long import sys +import six import tensorflow as tf -# This import is needed so that TensorFlow python modules are in sys.modules. from tensorflow import python as tf_python # pylint: disable=unused-import from tensorflow.python.lib.io import file_io from tensorflow.python.platform import app @@ -35,6 +36,7 @@ from tensorflow.tools.common import public_api from tensorflow.tools.common import traverse from tensorflow.tools.compatibility import all_renames_v2 +# This import is needed so that TensorFlow python modules are in sys.modules. _OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/renames_v2.py' _FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved. @@ -178,7 +180,8 @@ def update_renames_v2(output_file_path): rename_lines = [ get_rename_line(name, canonical_name) for name, canonical_name in all_renames - if 'tf.' + name not in manual_renames] + if 'tf.' + six.ensure_str(name) not in manual_renames + ] renames_file_text = '%srenames = {\n%s\n}\n' % ( _FILE_HEADER, ',\n'.join(sorted(rename_lines))) file_io.write_string_to_file(output_file_path, renames_file_text) diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index e4806027a91..68f04f20dc3 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -31,6 +31,7 @@ py_library( "doc_generator_visitor.py", ], srcs_version = "PY2AND3", + deps = ["@six_archive//:six"], ) py_test( @@ -39,12 +40,13 @@ py_test( srcs = [ "doc_generator_visitor_test.py", ], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":doc_generator_visitor", ":generate_lib", "//tensorflow/python:platform_test", + "@six_archive//:six", ], ) @@ -59,7 +61,7 @@ py_test( name = "doc_controls_test", size = "small", srcs = ["doc_controls_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":doc_controls", @@ -77,6 +79,7 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:util", "@astor_archive//:astor", + "@six_archive//:six", ], ) @@ -84,11 +87,12 @@ py_test( name = "parser_test", size = "small", srcs = ["parser_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":parser", "//tensorflow/python:platform_test", + "@six_archive//:six", ], ) @@ -96,6 +100,7 @@ py_library( name = "pretty_docs", srcs = ["pretty_docs.py"], srcs_version = "PY2AND3", + deps = ["@six_archive//:six"], ) py_library( @@ -120,7 +125,7 @@ py_test( name = "generate_lib_test", size = "small", srcs = ["generate_lib_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":generate_lib", @@ -132,7 +137,7 @@ py_test( py_binary( name = "generate", srcs = ["generate.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":generate_lib", @@ -165,9 +170,12 @@ py_test( py_binary( name = "generate2", srcs = ["generate2.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", - deps = [":generate2_lib"], + deps = [ + ":generate2_lib", + "@six_archive//:six", + ], ) py_library( @@ -184,16 +192,18 @@ py_library( name = "py_guide_parser", srcs = ["py_guide_parser.py"], srcs_version = "PY2AND3", + deps = ["@six_archive//:six"], ) py_test( name = "py_guide_parser_test", size = "small", srcs = ["py_guide_parser_test.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":py_guide_parser", "//tensorflow/python:client_testlib", + "@six_archive//:six", ], ) diff --git a/tensorflow/tools/docs/doc_generator_visitor.py b/tensorflow/tools/docs/doc_generator_visitor.py index 6157eb1b7fc..ec2102a5935 100644 --- a/tensorflow/tools/docs/doc_generator_visitor.py +++ b/tensorflow/tools/docs/doc_generator_visitor.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -48,7 +49,7 @@ class DocGeneratorVisitor(object): def set_root_name(self, root_name): """Sets the root name for subsequent __call__s.""" self._root_name = root_name or '' - self._prefix = (root_name + '.') if root_name else '' + self._prefix = (six.ensure_str(root_name) + '.') if root_name else '' @property def index(self): @@ -178,7 +179,7 @@ class DocGeneratorVisitor(object): A tuple of scores. When sorted the preferred name will have the lowest value. """ - parts = name.split('.') + parts = six.ensure_str(name).split('.') short_name = parts[-1] container = self._index['.'.join(parts[:-1])] diff --git a/tensorflow/tools/docs/doc_generator_visitor_test.py b/tensorflow/tools/docs/doc_generator_visitor_test.py index 1c2635d4a8c..29ec1f8437d 100644 --- a/tensorflow/tools/docs/doc_generator_visitor_test.py +++ b/tensorflow/tools/docs/doc_generator_visitor_test.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +21,8 @@ from __future__ import print_function import types +import six + from tensorflow.python.platform import googletest from tensorflow.tools.docs import doc_generator_visitor from tensorflow.tools.docs import generate_lib @@ -29,9 +32,9 @@ class NoDunderVisitor(doc_generator_visitor.DocGeneratorVisitor): def __call__(self, parent_name, parent, children): """Drop all the dunder methods to make testing easier.""" - children = [ - (name, obj) for (name, obj) in children if not name.startswith('_') - ] + children = [(name, obj) + for (name, obj) in children + if not six.ensure_str(name).startswith('_')] super(NoDunderVisitor, self).__call__(parent_name, parent, children) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 18d3a8349e8..77a685062ae 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -147,8 +148,10 @@ def write_docs(output_dir, duplicates = [item for item in duplicates if item != full_name] for dup in duplicates: - from_path = os.path.join(site_api_path, dup.replace('.', '/')) - to_path = os.path.join(site_api_path, full_name.replace('.', '/')) + from_path = os.path.join(site_api_path, + six.ensure_str(dup).replace('.', '/')) + to_path = os.path.join(site_api_path, + six.ensure_str(full_name).replace('.', '/')) redirects.append(( os.path.join('/', from_path), os.path.join('/', to_path))) @@ -167,7 +170,7 @@ def write_docs(output_dir, # Generate table of contents # Put modules in alphabetical order, case-insensitive - modules = sorted(module_children.keys(), key=lambda a: a.upper()) + modules = sorted(list(module_children.keys()), key=lambda a: a.upper()) leftnav_path = os.path.join(output_dir, '_toc.yaml') with open(leftnav_path, 'w') as f: @@ -183,16 +186,15 @@ def write_docs(output_dir, if indent_num > 1: # tf.contrib.baysflow.entropy will be under # tf.contrib->baysflow->entropy - title = module.split('.')[-1] + title = six.ensure_str(module).split('.')[-1] else: title = module header = [ - '- title: ' + title, - ' section:', - ' - title: Overview', - ' path: ' + os.path.join('/', site_api_path, - symbol_to_file[module])] + '- title: ' + six.ensure_str(title), ' section:', + ' - title: Overview', ' path: ' + + os.path.join('/', site_api_path, symbol_to_file[module]) + ] header = ''.join([indent+line+'\n' for line in header]) f.write(header) @@ -211,8 +213,9 @@ def write_docs(output_dir, # Write a global index containing all full names with links. with open(os.path.join(output_dir, 'index.md'), 'w') as f: f.write( - parser.generate_global_index(root_title, parser_config.index, - parser_config.reference_resolver)) + six.ensure_str( + parser.generate_global_index(root_title, parser_config.index, + parser_config.reference_resolver))) def add_dict_to_dict(add_from, add_to): @@ -345,7 +348,7 @@ def build_doc_index(src_dir): for dirpath, _, filenames in os.walk(src_dir): suffix = os.path.relpath(path=dirpath, start=src_dir) for base_name in filenames: - if not base_name.endswith('.md'): + if not six.ensure_str(base_name).endswith('.md'): continue title_parser = _GetMarkdownTitle() title_parser.process(os.path.join(dirpath, base_name)) @@ -353,7 +356,8 @@ def build_doc_index(src_dir): msg = ('`{}` has no markdown title (# title)'.format( os.path.join(dirpath, base_name))) raise ValueError(msg) - key_parts = os.path.join(suffix, base_name[:-3]).split('/') + key_parts = six.ensure_str(os.path.join(suffix, + base_name[:-3])).split('/') if key_parts[-1] == 'index': key_parts = key_parts[:-1] doc_info = _DocInfo(os.path.join(suffix, base_name), title_parser.title) @@ -367,8 +371,8 @@ def build_doc_index(src_dir): class _GuideRef(object): def __init__(self, base_name, title, section_title, section_tag): - self.url = 'api_guides/python/' + (('%s#%s' % (base_name, section_tag)) - if section_tag else base_name) + self.url = 'api_guides/python/' + six.ensure_str( + (('%s#%s' % (base_name, section_tag)) if section_tag else base_name)) self.link_text = (('%s > %s' % (title, section_title)) if section_title else title) @@ -447,7 +451,7 @@ def update_id_tags_inplace(src_dir): # modified file contents content = tag_updater.process(full_path) with open(full_path, 'w') as f: - f.write(content) + f.write(six.ensure_str(content)) EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt']) @@ -512,7 +516,7 @@ def replace_refs(src_dir, content = reference_resolver.replace_references(content, relative_path_to_root) with open(full_out_path, 'wb') as f: - f.write(content.encode('utf-8')) + f.write(six.ensure_binary(content, 'utf-8')) class DocGenerator(object): diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index d87f9585f20..d6426cfd6de 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +29,7 @@ import re import astor import six +from six.moves import zip from google.protobuf.message import Message as ProtoMessage from tensorflow.python.platform import tf_logging as logging @@ -50,7 +52,7 @@ def is_free_function(py_object, full_name, index): if not tf_inspect.isfunction(py_object): return False - parent_name = full_name.rsplit('.', 1)[0] + parent_name = six.ensure_str(full_name).rsplit('.', 1)[0] if tf_inspect.isclass(index[parent_name]): return False @@ -112,14 +114,14 @@ def documentation_path(full_name, is_fragment=False): Returns: The file path to which to write the documentation for `full_name`. """ - parts = full_name.split('.') + parts = six.ensure_str(full_name).split('.') if is_fragment: parts, fragment = parts[:-1], parts[-1] - result = os.path.join(*parts) + '.md' + result = six.ensure_str(os.path.join(*parts)) + '.md' if is_fragment: - result = result + '#' + fragment + result = six.ensure_str(result) + '#' + six.ensure_str(fragment) return result @@ -288,7 +290,7 @@ class ReferenceResolver(object): self.add_error(e.message) return 'BAD_LINK' - string = re.sub(SYMBOL_REFERENCE_RE, strict_one_ref, string) + string = re.sub(SYMBOL_REFERENCE_RE, strict_one_ref, six.ensure_str(string)) def sloppy_one_ref(match): try: @@ -333,7 +335,7 @@ class ReferenceResolver(object): @staticmethod def _link_text_to_html(link_text): code_re = '`(.*?)`' - return re.sub(code_re, r'\1', link_text) + return re.sub(code_re, r'\1', six.ensure_str(link_text)) def py_master_name(self, full_name): """Return the master name for a Python symbol name.""" @@ -389,11 +391,11 @@ class ReferenceResolver(object): manual_link_text = False # Handle different types of references. - if string.startswith('$'): # Doc reference + if six.ensure_str(string).startswith('$'): # Doc reference return self._doc_link(string, link_text, manual_link_text, relative_path_to_root) - elif string.startswith('tensorflow::'): + elif six.ensure_str(string).startswith('tensorflow::'): # C++ symbol return self._cc_link(string, link_text, manual_link_text, relative_path_to_root) @@ -401,7 +403,8 @@ class ReferenceResolver(object): else: is_python = False for py_module_name in self._py_module_names: - if string == py_module_name or string.startswith(py_module_name + '.'): + if string == py_module_name or string.startswith( + six.ensure_str(py_module_name) + '.'): is_python = True break if is_python: # Python symbol @@ -421,7 +424,7 @@ class ReferenceResolver(object): string = string[1:] # remove leading $ # If string has a #, split that part into `hash_tag` - hash_pos = string.find('#') + hash_pos = six.ensure_str(string).find('#') if hash_pos > -1: hash_tag = string[hash_pos:] string = string[:hash_pos] @@ -520,10 +523,10 @@ class _FunctionDetail( def __str__(self): """Return the original string that represents the function detail.""" - parts = [self.keyword + ':\n'] + parts = [six.ensure_str(self.keyword) + ':\n'] parts.append(self.header) for key, value in self.items: - parts.append(' ' + key + ': ') + parts.append(' ' + six.ensure_str(key) + ': ') parts.append(value) return ''.join(parts) @@ -587,7 +590,7 @@ def _parse_function_details(docstring): item_re = re.compile(r'^ ? ?(\*?\*?\w[\w.]*?\s*):\s', re.MULTILINE) for keyword, content in pairs: - content = item_re.split(content) + content = item_re.split(six.ensure_str(content)) header = content[0] items = list(_gen_pairs(content[1:])) @@ -634,7 +637,8 @@ def _parse_md_docstring(py_object, relative_path_to_root, reference_resolver): atat_re = re.compile(r' *@@[a-zA-Z_.0-9]+ *$') raw_docstring = '\n'.join( - line for line in raw_docstring.split('\n') if not atat_re.match(line)) + line for line in six.ensure_str(raw_docstring).split('\n') + if not atat_re.match(six.ensure_str(line))) docstring, compatibility = _handle_compatibility(raw_docstring) docstring, function_details = _parse_function_details(docstring) @@ -698,8 +702,9 @@ def _get_arg_spec(func): def _remove_first_line_indent(string): - indent = len(re.match(r'^\s*', string).group(0)) - return '\n'.join([line[indent:] for line in string.split('\n')]) + indent = len(re.match(r'^\s*', six.ensure_str(string)).group(0)) + return '\n'.join( + [line[indent:] for line in six.ensure_str(string).split('\n')]) PAREN_NUMBER_RE = re.compile(r'^\(([0-9.e-]+)\)') @@ -761,9 +766,9 @@ def _generate_signature(func, reverse_index): default_text = reverse_index[id(default)] elif ast_default is not None: default_text = ( - astor.to_source(ast_default).rstrip('\n').replace('\t', '\\t') - .replace('\n', '\\n').replace('"""', "'")) - default_text = PAREN_NUMBER_RE.sub('\\1', default_text) + six.ensure_str(astor.to_source(ast_default)).rstrip('\n').replace( + '\t', '\\t').replace('\n', '\\n').replace('"""', "'")) + default_text = PAREN_NUMBER_RE.sub('\\1', six.ensure_str(default_text)) if default_text != repr(default): # This may be an internal name. If so, handle the ones we know about. @@ -797,9 +802,9 @@ def _generate_signature(func, reverse_index): # Add *args and *kwargs. if argspec.varargs: - args_list.append('*' + argspec.varargs) + args_list.append('*' + six.ensure_str(argspec.varargs)) if argspec.varkw: - args_list.append('**' + argspec.varkw) + args_list.append('**' + six.ensure_str(argspec.varkw)) return args_list @@ -879,7 +884,7 @@ class _FunctionPageInfo(object): @property def short_name(self): - return self._full_name.split('.')[-1] + return six.ensure_str(self._full_name).split('.')[-1] @property def defined_in(self): @@ -998,7 +1003,7 @@ class _ClassPageInfo(object): @property def short_name(self): """Returns the documented object's short name.""" - return self._full_name.split('.')[-1] + return six.ensure_str(self._full_name).split('.')[-1] @property def defined_in(self): @@ -1091,9 +1096,12 @@ class _ClassPageInfo(object): base_url = parser_config.reference_resolver.reference_to_url( base_full_name, relative_path) - link_info = _LinkInfo(short_name=base_full_name.split('.')[-1], - full_name=base_full_name, obj=base, - doc=base_doc, url=base_url) + link_info = _LinkInfo( + short_name=six.ensure_str(base_full_name).split('.')[-1], + full_name=base_full_name, + obj=base, + doc=base_doc, + url=base_url) bases.append(link_info) self._bases = bases @@ -1121,7 +1129,7 @@ class _ClassPageInfo(object): doc: The property's parsed docstring, a `_DocstringInfo`. """ # Hide useless namedtuple docs-trings - if re.match('Alias for field number [0-9]+', doc.docstring): + if re.match('Alias for field number [0-9]+', six.ensure_str(doc.docstring)): doc = doc._replace(docstring='', brief='') property_info = _PropertyInfo(short_name, full_name, obj, doc) self._properties.append(property_info) @@ -1255,8 +1263,8 @@ class _ClassPageInfo(object): # Omit methods defined by namedtuple. original_method = defining_class.__dict__[short_name] - if (hasattr(original_method, '__module__') and - (original_method.__module__ or '').startswith('namedtuple')): + if (hasattr(original_method, '__module__') and six.ensure_str( + (original_method.__module__ or '')).startswith('namedtuple')): continue # Some methods are often overridden without documentation. Because it's @@ -1294,7 +1302,7 @@ class _ClassPageInfo(object): else: # Exclude members defined by protobuf that are useless if issubclass(py_class, ProtoMessage): - if (short_name.endswith('_FIELD_NUMBER') or + if (six.ensure_str(short_name).endswith('_FIELD_NUMBER') or short_name in ['__slots__', 'DESCRIPTOR']): continue @@ -1332,7 +1340,7 @@ class _ModulePageInfo(object): @property def short_name(self): - return self._full_name.split('.')[-1] + return six.ensure_str(self._full_name).split('.')[-1] @property def defined_in(self): @@ -1425,7 +1433,8 @@ class _ModulePageInfo(object): '__cached__', '__loader__', '__spec__']: continue - member_full_name = self.full_name + '.' + name if self.full_name else name + member_full_name = six.ensure_str(self.full_name) + '.' + six.ensure_str( + name) if self.full_name else name member = parser_config.py_name_to_object(member_full_name) member_doc = _parse_md_docstring(member, relative_path, @@ -1680,20 +1689,21 @@ def _get_defined_in(py_object, parser_config): # TODO(wicke): And make their source file predictable from the file name. # In case this is compiled, point to the original - if path.endswith('.pyc'): + if six.ensure_str(path).endswith('.pyc'): path = path[:-1] # Never include links outside this code base. - if path.startswith('..') or re.search(r'\b_api\b', path): + if six.ensure_str(path).startswith('..') or re.search(r'\b_api\b', + six.ensure_str(path)): return None - if re.match(r'.*/gen_[^/]*\.py$', path): + if re.match(r'.*/gen_[^/]*\.py$', six.ensure_str(path)): return _GeneratedFile(path, parser_config) if 'genfiles' in path or 'tools/api/generator' in path: return _GeneratedFile(path, parser_config) - elif re.match(r'.*_pb2\.py$', path): + elif re.match(r'.*_pb2\.py$', six.ensure_str(path)): # The _pb2.py files all appear right next to their defining .proto file. - return _ProtoFile(path[:-7] + '.proto', parser_config) + return _ProtoFile(six.ensure_str(path[:-7]) + '.proto', parser_config) else: return _PythonFile(path, parser_config) diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 6ecdf521cdc..15d4cad89cc 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,6 +24,8 @@ import functools import os import sys +import six + from tensorflow.python.platform import googletest from tensorflow.python.util import tf_inspect from tensorflow.tools.docs import doc_controls @@ -180,7 +183,8 @@ class ParserTest(googletest.TestCase): # Make sure the brief docstring is present self.assertEqual( - tf_inspect.getdoc(TestClass).split('\n')[0], page_info.doc.brief) + six.ensure_str(tf_inspect.getdoc(TestClass)).split('\n')[0], + page_info.doc.brief) # Make sure the method is present self.assertEqual(TestClass.a_method, page_info.methods[0].obj) @@ -236,7 +240,7 @@ class ParserTest(googletest.TestCase): # 'Alias for field number ##'. These props are returned sorted. def sort_key(prop_info): - return int(prop_info.obj.__doc__.split(' ')[-1]) + return int(six.ensure_str(prop_info.obj.__doc__).split(' ')[-1]) self.assertSequenceEqual(page_info.properties, sorted(page_info.properties, key=sort_key)) @@ -378,7 +382,8 @@ class ParserTest(googletest.TestCase): # Make sure the brief docstring is present self.assertEqual( - tf_inspect.getdoc(test_module).split('\n')[0], page_info.doc.brief) + six.ensure_str(tf_inspect.getdoc(test_module)).split('\n')[0], + page_info.doc.brief) # Make sure that the members are there funcs = {f_info.obj for f_info in page_info.functions} @@ -422,7 +427,8 @@ class ParserTest(googletest.TestCase): # Make sure the brief docstring is present self.assertEqual( - tf_inspect.getdoc(test_function).split('\n')[0], page_info.doc.brief) + six.ensure_str(tf_inspect.getdoc(test_function)).split('\n')[0], + page_info.doc.brief) # Make sure the extracted signature is good. self.assertEqual(['unused_arg', "unused_kwarg='default'"], @@ -461,7 +467,8 @@ class ParserTest(googletest.TestCase): # Make sure the brief docstring is present self.assertEqual( - tf_inspect.getdoc(test_function_with_args_kwargs).split('\n')[0], + six.ensure_str( + tf_inspect.getdoc(test_function_with_args_kwargs)).split('\n')[0], page_info.doc.brief) # Make sure the extracted signature is good. @@ -751,7 +758,8 @@ class TestParseFunctionDetails(googletest.TestCase): self.assertEqual( RELU_DOC, - docstring + ''.join(str(detail) for detail in function_details)) + six.ensure_str(docstring) + + ''.join(str(detail) for detail in function_details)) class TestGenerateSignature(googletest.TestCase): diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py index 1a3e79621f8..d7237b1a39a 100644 --- a/tensorflow/tools/docs/pretty_docs.py +++ b/tensorflow/tools/docs/pretty_docs.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +29,7 @@ from __future__ import division from __future__ import print_function import textwrap +import six def build_md_page(page_info): @@ -83,7 +85,8 @@ def _build_class_page(page_info): """Given a ClassPageInfo object Return the page as an md string.""" parts = ['# {page_info.full_name}\n\n'.format(page_info=page_info)] - parts.append('## Class `%s`\n\n' % page_info.full_name.split('.')[-1]) + parts.append('## Class `%s`\n\n' % + six.ensure_str(page_info.full_name).split('.')[-1]) if page_info.bases: parts.append('Inherits From: ') @@ -222,7 +225,7 @@ def _build_module_page(page_info): parts.append(template.format(**item._asdict())) if item.doc.brief: - parts.append(': ' + item.doc.brief) + parts.append(': ' + six.ensure_str(item.doc.brief)) parts.append('\n\n') @@ -234,7 +237,7 @@ def _build_module_page(page_info): parts.append(template.format(**item._asdict())) if item.doc.brief: - parts.append(': ' + item.doc.brief) + parts.append(': ' + six.ensure_str(item.doc.brief)) parts.append('\n\n') @@ -246,7 +249,7 @@ def _build_module_page(page_info): parts.append(template.format(**item._asdict())) if item.doc.brief: - parts.append(': ' + item.doc.brief) + parts.append(': ' + six.ensure_str(item.doc.brief)) parts.append('\n\n') @@ -273,7 +276,7 @@ def _build_signature(obj_info, use_full_name=True): '```\n\n') parts = ['``` python'] - parts.extend(['@' + dec for dec in obj_info.decorators]) + parts.extend(['@' + six.ensure_str(dec) for dec in obj_info.decorators]) signature_template = '{name}({sig})' if not obj_info.signature: @@ -313,7 +316,7 @@ def _build_function_details(function_details): parts = [] for detail in function_details: sub = [] - sub.append('#### ' + detail.keyword + ':\n\n') + sub.append('#### ' + six.ensure_str(detail.keyword) + ':\n\n') sub.append(textwrap.dedent(detail.header)) for key, value in detail.items: sub.append('* `%s`: %s' % (key, value)) diff --git a/tensorflow/tools/docs/py_guide_parser.py b/tensorflow/tools/docs/py_guide_parser.py index b00694dc403..70149c4dd9e 100644 --- a/tensorflow/tools/docs/py_guide_parser.py +++ b/tensorflow/tools/docs/py_guide_parser.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,14 +22,16 @@ from __future__ import print_function import os import re +import six def md_files_in_dir(py_guide_src_dir): """Returns a list of filename (full_path, base) pairs for guide files.""" all_in_dir = [(os.path.join(py_guide_src_dir, f), f) for f in os.listdir(py_guide_src_dir)] - return [(full, f) for full, f in all_in_dir - if os.path.isfile(full) and f.endswith('.md')] + return [(full, f) + for full, f in all_in_dir + if os.path.isfile(full) and six.ensure_str(f).endswith('.md')] class PyGuideParser(object): diff --git a/tensorflow/tools/docs/py_guide_parser_test.py b/tensorflow/tools/docs/py_guide_parser_test.py index 168b0535a94..2975a1a6575 100644 --- a/tensorflow/tools/docs/py_guide_parser_test.py +++ b/tensorflow/tools/docs/py_guide_parser_test.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +21,8 @@ from __future__ import print_function import os +import six + from tensorflow.python.platform import test from tensorflow.tools.docs import py_guide_parser @@ -38,7 +41,7 @@ class TestPyGuideParser(py_guide_parser.PyGuideParser): def process_in_blockquote(self, line_number, line): self.calls.append((line_number, 'b', line)) - self.replace_line(line_number, line + ' BQ') + self.replace_line(line_number, six.ensure_str(line) + ' BQ') def process_line(self, line_number, line): self.calls.append((line_number, 'l', line)) diff --git a/tensorflow/tools/git/BUILD b/tensorflow/tools/git/BUILD index fb6a07133e7..8a47f4c4c2d 100644 --- a/tensorflow/tools/git/BUILD +++ b/tensorflow/tools/git/BUILD @@ -11,6 +11,7 @@ package( py_binary( name = "gen_git_source", srcs = ["gen_git_source.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", + deps = ["@six_archive//:six"], ) diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py index d29b535ae30..011406e2288 100755 --- a/tensorflow/tools/git/gen_git_source.py +++ b/tensorflow/tools/git/gen_git_source.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,13 +27,16 @@ NOTE: this script is only used in opensource. from __future__ import absolute_import from __future__ import division from __future__ import print_function -from builtins import bytes # pylint: disable=redefined-builtin + import argparse +from builtins import bytes # pylint: disable=redefined-builtin import json import os import shutil import subprocess +import six + def parse_branch_ref(filename): """Given a filename of a .git/HEAD file return ref path. @@ -161,10 +165,13 @@ def get_git_version(git_base_path, git_tag_override): unknown_label = b"unknown" try: # Force to bytes so this works on python 2 and python 3 - val = bytes(subprocess.check_output([ - "git", str("--git-dir=%s/.git" % git_base_path), - str("--work-tree=" + git_base_path), "describe", "--long", "--tags" - ]).strip()) + val = bytes( + subprocess.check_output([ + "git", + str("--git-dir=%s/.git" % git_base_path), + str("--work-tree=" + six.ensure_str(git_base_path)), "describe", + "--long", "--tags" + ]).strip()) version_separator = b"-" if git_tag_override and val: split_val = val.split(version_separator) diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index adafe2aca12..0f5c298b48b 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -3,7 +3,6 @@ load( "//tensorflow:tensorflow.bzl", - "if_not_v2", "if_not_windows", "tf_cc_binary", "tf_cc_test", diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 174f21610dc..1162ad5de47 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -163,6 +163,7 @@ genrule( "@com_google_protobuf//:LICENSE", "@snappy//:COPYING", "@zlib_archive//:zlib.h", + "@six_archive//:LICENSE", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], @@ -235,6 +236,7 @@ genrule( "@zlib_archive//:zlib.h", "@grpc//:LICENSE", "@grpc//third_party/address_sorting:LICENSE", + "@six_archive//:LICENSE", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], diff --git a/tensorflow/tools/tensorflow_builder/compat_checker/BUILD b/tensorflow/tools/tensorflow_builder/compat_checker/BUILD index b60a7df1b76..d2119dd1e63 100644 --- a/tensorflow/tools/tensorflow_builder/compat_checker/BUILD +++ b/tensorflow/tools/tensorflow_builder/compat_checker/BUILD @@ -13,9 +13,12 @@ package( ], ) +licenses(["notice"]) # Apache 2.0 + py_library( name = "compat_checker", srcs = ["compat_checker.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow/python:platform", "//tensorflow/python:util", @@ -29,6 +32,7 @@ py_test( data = [ "//tensorflow/tools/tensorflow_builder/compat_checker:test_config", ], + python_version = "PY3", tags = ["no_pip"], deps = [ ":compat_checker", diff --git a/tensorflow/tools/tensorflow_builder/compat_checker/compat_checker.py b/tensorflow/tools/tensorflow_builder/compat_checker/compat_checker.py index d65d7727ffa..ec8a0ba6f96 100644 --- a/tensorflow/tools/tensorflow_builder/compat_checker/compat_checker.py +++ b/tensorflow/tools/tensorflow_builder/compat_checker/compat_checker.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,16 +24,11 @@ import re import sys import six +from six.moves import range +import six.moves.configparser from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect -# pylint: disable=g-import-not-at-top -if six.PY2: - import ConfigParser -else: - import configparser as ConfigParser -# pylint: enable=g-import-not-at-top - PATH_TO_DIR = "tensorflow/tools/tensorflow_builder/compat_checker" @@ -56,8 +52,8 @@ def _compare_versions(v1, v2): raise RuntimeError("Cannot compare `inf` to `inf`.") rtn_dict = {"smaller": None, "larger": None} - v1_list = v1.split(".") - v2_list = v2.split(".") + v1_list = six.ensure_str(v1).split(".") + v2_list = six.ensure_str(v2).split(".") # Take care of cases with infinity (arg=`inf`). if v1_list[0] == "inf": v1_list[0] = str(int(v2_list[0]) + 1) @@ -380,7 +376,7 @@ class ConfigCompatChecker(object): curr_status = True # Initialize config parser for parsing version requirements file. - parser = ConfigParser.ConfigParser() + parser = six.moves.configparser.ConfigParser() parser.read(self.req_file) if not parser.sections(): @@ -643,7 +639,7 @@ class ConfigCompatChecker(object): if filtered[-1] == "]": filtered = filtered[:-1] elif "]" in filtered[-1]: - filtered[-1] = filtered[-1].replace("]", "") + filtered[-1] = six.ensure_str(filtered[-1]).replace("]", "") # If `]` is missing, then it could be a formatting issue with # config file (.ini.). Add to warning. else: @@ -792,7 +788,7 @@ class ConfigCompatChecker(object): Boolean that is a status of the compatibility check result. """ # Check if all `Required` configs are found in user configs. - usr_keys = self.usr_config.keys() + usr_keys = list(self.usr_config.keys()) for k in six.iterkeys(self.usr_config): if k not in usr_keys: @@ -809,10 +805,10 @@ class ConfigCompatChecker(object): for config_name, spec in six.iteritems(self.usr_config): temp_status = True # Check under which section the user config is defined. - in_required = config_name in self.required.keys() - in_optional = config_name in self.optional.keys() - in_unsupported = config_name in self.unsupported.keys() - in_dependency = config_name in self.dependency.keys() + in_required = config_name in list(self.required.keys()) + in_optional = config_name in list(self.optional.keys()) + in_unsupported = config_name in list(self.unsupported.keys()) + in_dependency = config_name in list(self.dependency.keys()) # Add to warning if user config is not specified in the config file. if not (in_required or in_optional or in_unsupported or in_dependency): diff --git a/tensorflow/tools/tensorflow_builder/compat_checker/compat_checker_test.py b/tensorflow/tools/tensorflow_builder/compat_checker/compat_checker_test.py index bd0d50a9f99..b12815b555d 100644 --- a/tensorflow/tools/tensorflow_builder/compat_checker/compat_checker_test.py +++ b/tensorflow/tools/tensorflow_builder/compat_checker/compat_checker_test.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -88,7 +89,7 @@ class CompatCheckerTest(unittest.TestCase): # Make sure no warning or error messages are recorded. self.assertFalse(len(self.compat_checker.error_msg)) # Make sure total # of successes match total # of configs. - cnt = len(USER_CONFIG_IN_RANGE.keys()) + cnt = len(list(USER_CONFIG_IN_RANGE.keys())) self.assertEqual(len(self.compat_checker.successes), cnt) def testWithUserConfigNotInRange(self): @@ -106,7 +107,7 @@ class CompatCheckerTest(unittest.TestCase): err_msg_list = self.compat_checker.failures self.assertTrue(len(err_msg_list)) # Make sure total # of failures match total # of configs. - cnt = len(USER_CONFIG_NOT_IN_RANGE.keys()) + cnt = len(list(USER_CONFIG_NOT_IN_RANGE.keys())) self.assertEqual(len(err_msg_list), cnt) def testWithUserConfigMissing(self): diff --git a/tensorflow/tools/tensorflow_builder/config_detector/BUILD b/tensorflow/tools/tensorflow_builder/config_detector/BUILD index 6227366fec1..ab52eb33fdc 100644 --- a/tensorflow/tools/tensorflow_builder/config_detector/BUILD +++ b/tensorflow/tools/tensorflow_builder/config_detector/BUILD @@ -14,22 +14,24 @@ py_binary( data = [ "//tensorflow/tools/tensorflow_builder/config_detector/data/golden:cuda_cc_golden", ], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":cuda_compute_capability", "@absl_py//absl:app", "@absl_py//absl/flags", + "@six_archive//:six", ], ) py_binary( name = "cuda_compute_capability", srcs = ["data/cuda_compute_capability.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ "@absl_py//absl:app", "@absl_py//absl/flags", + "@six_archive//:six", ], ) diff --git a/tensorflow/tools/tensorflow_builder/config_detector/config_detector.py b/tensorflow/tools/tensorflow_builder/config_detector/config_detector.py index 680247d19b8..2c24780bcfd 100755 --- a/tensorflow/tools/tensorflow_builder/config_detector/config_detector.py +++ b/tensorflow/tools/tensorflow_builder/config_detector/config_detector.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -70,6 +71,7 @@ import subprocess import sys from absl import app from absl import flags +import six from tensorflow.tools.tensorflow_builder.config_detector.data import cuda_compute_capability @@ -182,7 +184,7 @@ def get_cpu_type(): """ key = "cpu_type" out, err = run_shell_cmd(cmds_all[PLATFORM][key]) - cpu_detected = out.split(":")[1].strip() + cpu_detected = out.split(b":")[1].strip() if err and FLAGS.debug: print("Error in detecting CPU type:\n %s" % str(err)) @@ -201,7 +203,7 @@ def get_cpu_arch(): if err and FLAGS.debug: print("Error in detecting CPU arch:\n %s" % str(err)) - return out.strip("\n") + return out.strip(b"\n") def get_distrib(): @@ -216,7 +218,7 @@ def get_distrib(): if err and FLAGS.debug: print("Error in detecting distribution:\n %s" % str(err)) - return out.strip("\n") + return out.strip(b"\n") def get_distrib_version(): @@ -233,7 +235,7 @@ def get_distrib_version(): "Error in detecting distribution version:\n %s" % str(err) ) - return out.strip("\n") + return out.strip(b"\n") def get_gpu_type(): @@ -251,7 +253,7 @@ def get_gpu_type(): key = "gpu_type_no_sudo" gpu_dict = cuda_compute_capability.retrieve_from_golden() out, err = run_shell_cmd(cmds_all[PLATFORM][key]) - ret_val = out.split(" ") + ret_val = out.split(b" ") gpu_id = ret_val[0] if err and FLAGS.debug: print("Error in detecting GPU type:\n %s" % str(err)) @@ -261,10 +263,10 @@ def get_gpu_type(): return gpu_id, GPU_TYPE else: if "[" or "]" in ret_val[1]: - gpu_release = ret_val[1].replace("[", "") + " " - gpu_release += ret_val[2].replace("]", "").strip("\n") + gpu_release = ret_val[1].replace(b"[", b"") + b" " + gpu_release += ret_val[2].replace(b"]", b"").strip(b"\n") else: - gpu_release = ret_val[1].replace("\n", " ") + gpu_release = six.ensure_str(ret_val[1]).replace("\n", " ") if gpu_release not in gpu_dict: GPU_TYPE = "unknown" @@ -285,7 +287,7 @@ def get_gpu_count(): if err and FLAGS.debug: print("Error in detecting GPU count:\n %s" % str(err)) - return out.strip("\n") + return out.strip(b"\n") def get_cuda_version_all(): @@ -303,7 +305,7 @@ def get_cuda_version_all(): """ key = "cuda_ver_all" out, err = run_shell_cmd(cmds_all[PLATFORM.lower()][key]) - ret_val = out.split("\n") + ret_val = out.split(b"\n") filtered = [] for item in ret_val: if item not in ["\n", ""]: @@ -311,9 +313,9 @@ def get_cuda_version_all(): all_vers = [] for item in filtered: - ver_re = re.search(r".*/cuda(\-[\d]+\.[\d]+)?", item) + ver_re = re.search(r".*/cuda(\-[\d]+\.[\d]+)?", item.decode("utf-8")) if ver_re.group(1): - all_vers.append(ver_re.group(1).strip("-")) + all_vers.append(six.ensure_str(ver_re.group(1)).strip("-")) if err and FLAGS.debug: print("Error in detecting CUDA version:\n %s" % str(err)) @@ -409,13 +411,13 @@ def get_cudnn_version(): if err and FLAGS.debug: print("Error in finding `cudnn.h`:\n %s" % str(err)) - if len(out.split(" ")) > 1: + if len(out.split(b" ")) > 1: cmd = cmds[0] + " | " + cmds[1] out_re, err_re = run_shell_cmd(cmd) if err_re and FLAGS.debug: print("Error in detecting cuDNN version:\n %s" % str(err_re)) - return out_re.strip("\n") + return out_re.strip(b"\n") else: return @@ -432,7 +434,7 @@ def get_gcc_version(): if err and FLAGS.debug: print("Error in detecting GCC version:\n %s" % str(err)) - return out.strip("\n") + return out.strip(b"\n") def get_glibc_version(): @@ -447,7 +449,7 @@ def get_glibc_version(): if err and FLAGS.debug: print("Error in detecting GCC version:\n %s" % str(err)) - return out.strip("\n") + return out.strip(b"\n") def get_libstdcpp_version(): @@ -462,7 +464,7 @@ def get_libstdcpp_version(): if err and FLAGS.debug: print("Error in detecting libstdc++ version:\n %s" % str(err)) - ver = out.split("_")[-1].replace("\n", "") + ver = out.split(b"_")[-1].replace(b"\n", b"") return ver @@ -485,7 +487,7 @@ def get_cpu_isa_version(): found = [] missing = [] for isa in required_isa: - for sys_isa in ret_val.split(" "): + for sys_isa in ret_val.split(b" "): if isa == sys_isa: if isa not in found: found.append(isa) @@ -539,7 +541,7 @@ def get_all_configs(): json_data = {} missing = [] warning = [] - for config, call_func in all_functions.iteritems(): + for config, call_func in six.iteritems(all_functions): ret_val = call_func if not ret_val: configs_found.append([config, "\033[91m\033[1mMissing\033[0m"]) @@ -557,10 +559,10 @@ def get_all_configs(): configs_found.append([config, ret_val[0]]) json_data[config] = ret_val[0] else: - configs_found.append( - [config, - "\033[91m\033[1mMissing " + str(ret_val[1])[1:-1] + "\033[0m"] - ) + configs_found.append([ + config, "\033[91m\033[1mMissing " + + six.ensure_str(str(ret_val[1])[1:-1]) + "\033[0m" + ]) missing.append( [config, "\n\t=> Found %s but missing %s" @@ -587,7 +589,7 @@ def print_all_configs(configs, missing, warning): llen = 65 # line length for i, row in enumerate(configs): if i != 0: - print_text += "-"*llen + "\n" + print_text += six.ensure_str("-" * llen) + "\n" if isinstance(row[1], list): val = ", ".join(row[1]) @@ -629,7 +631,7 @@ def save_to_file(json_data, filename): print("filename: %s" % filename) filename += ".json" - with open(PATH_TO_DIR + "/" + filename, "w") as f: + with open(PATH_TO_DIR + "/" + six.ensure_str(filename), "w") as f: json.dump(json_data, f, sort_keys=True, indent=4) print(" Successfully wrote configs to file `%s`.\n" % (filename)) diff --git a/tensorflow/tools/tensorflow_builder/config_detector/data/cuda_compute_capability.py b/tensorflow/tools/tensorflow_builder/config_detector/data/cuda_compute_capability.py index 4dba961b037..9d17dbc9178 100644 --- a/tensorflow/tools/tensorflow_builder/config_detector/data/cuda_compute_capability.py +++ b/tensorflow/tools/tensorflow_builder/config_detector/data/cuda_compute_capability.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -42,6 +43,7 @@ import re from absl import app from absl import flags +import six import six.moves.urllib.request as urllib FLAGS = flags.FLAGS @@ -61,21 +63,18 @@ def retrieve_from_web(generate_csv=False): NVIDIA page. Order goes from top to bottom of the webpage content (.html). """ url = "https://developer.nvidia.com/cuda-gpus" - source = urllib.urlopen(url) + source = urllib.request.urlopen(url) matches = [] while True: line = source.readline() if "" in line: break else: - gpu = re.search( - r"([\w\S\s\d\[\]\,]+[^*])(([\w\S\s\d\[\]\,]+[^*])(.*", - line - ) + six.ensure_str(line)) if gpu: matches.append(gpu.group(1)) elif capability: @@ -155,15 +154,15 @@ def create_gpu_capa_map(match_list, gpu = "" cnt += 1 - if len(gpu_capa.keys()) < cnt: + if len(list(gpu_capa.keys())) < cnt: mismatch_cnt += 1 - cnt = len(gpu_capa.keys()) + cnt = len(list(gpu_capa.keys())) else: gpu = match if generate_csv: - f_name = filename + ".csv" + f_name = six.ensure_str(filename) + ".csv" write_csv_from_dict(f_name, gpu_capa) return gpu_capa @@ -179,8 +178,8 @@ def write_csv_from_dict(filename, input_dict): filename: String that is the output file name. input_dict: Dictionary that is to be written out to a `.csv` file. """ - f = open(PATH_TO_DIR + "/data/" + filename, "w") - for k, v in input_dict.iteritems(): + f = open(PATH_TO_DIR + "/data/" + six.ensure_str(filename), "w") + for k, v in six.iteritems(input_dict): line = k for item in v: line += "," + item @@ -203,7 +202,7 @@ def check_with_golden(filename): Args: filename: String that is the name of the newly created file. """ - path_to_file = PATH_TO_DIR + "/data/" + filename + path_to_file = PATH_TO_DIR + "/data/" + six.ensure_str(filename) if os.path.isfile(path_to_file) and os.path.isfile(CUDA_CC_GOLDEN_DIR): with open(path_to_file, "r") as f_new: with open(CUDA_CC_GOLDEN_DIR, "r") as f_golden: diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD index 5607fe73361..b09acad5036 100644 --- a/tensorflow/tools/test/BUILD +++ b/tensorflow/tools/test/BUILD @@ -22,17 +22,19 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:client", "//tensorflow/python:errors", "//tensorflow/python:platform", + "@six_archive//:six", ], ) py_binary( name = "system_info", srcs = ["system_info.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", deps = [ ":system_info_lib", @@ -50,16 +52,20 @@ py_library( ":system_info_lib", "//tensorflow/core:protos_all_py", "//tensorflow/python:platform", + "@six_archive//:six", ], ) py_binary( name = "run_and_gather_logs", srcs = ["run_and_gather_logs.py"], - python_version = "PY2", + python_version = "PY3", srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [":run_and_gather_logs_main_lib"], + deps = [ + ":run_and_gather_logs_main_lib", + "@six_archive//:six", + ], ) py_library( @@ -72,6 +78,7 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", + "@six_archive//:six", ], ) diff --git a/tensorflow/tools/test/gpu_info_lib.py b/tensorflow/tools/test/gpu_info_lib.py index 3a4ff4fdff4..e9fb227baf1 100644 --- a/tensorflow/tools/test/gpu_info_lib.py +++ b/tensorflow/tools/test/gpu_info_lib.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +22,9 @@ from __future__ import print_function import ctypes as ct import platform +import six +from six.moves import range + from tensorflow.core.util import test_log_pb2 from tensorflow.python.framework import errors from tensorflow.python.platform import gfile @@ -30,10 +34,11 @@ def _gather_gpu_devices_proc(): """Try to gather NVidia GPU device information via /proc/driver.""" dev_info = [] for f in gfile.Glob("/proc/driver/nvidia/gpus/*/information"): - bus_id = f.split("/")[5] - key_values = dict(line.rstrip().replace("\t", "").split(":", 1) - for line in gfile.GFile(f, "r")) - key_values = dict((k.lower(), v.strip(" ").rstrip(" ")) + bus_id = six.ensure_str(f).split("/")[5] + key_values = dict( + six.ensure_str(line.rstrip()).replace("\t", "").split(":", 1) + for line in gfile.GFile(f, "r")) + key_values = dict((k.lower(), six.ensure_str(v).strip(" ").rstrip(" ")) for (k, v) in key_values.items()) info = test_log_pb2.GPUInfo() info.model = key_values.get("model", "Unknown") diff --git a/tensorflow/tools/test/run_and_gather_logs.py b/tensorflow/tools/test/run_and_gather_logs.py index f6b25bbfccb..a1486826615 100644 --- a/tensorflow/tools/test/run_and_gather_logs.py +++ b/tensorflow/tools/test/run_and_gather_logs.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,9 +26,10 @@ from string import maketrans import sys import time +import six + from google.protobuf import json_format from google.protobuf import text_format - from tensorflow.core.util import test_log_pb2 from tensorflow.python.platform import app from tensorflow.python.platform import gfile @@ -83,8 +85,9 @@ def main(unused_args): if FLAGS.test_log_output_filename: file_name = FLAGS.test_log_output_filename else: - file_name = (name.strip("/").translate(maketrans("/:", "__")) + - time.strftime("%Y%m%d%H%M%S", time.gmtime())) + file_name = ( + six.ensure_str(name).strip("/").translate(maketrans("/:", "__")) + + time.strftime("%Y%m%d%H%M%S", time.gmtime())) if FLAGS.test_log_output_use_tmpdir: tmpdir = test.get_temp_dir() output_path = os.path.join(tmpdir, FLAGS.test_log_output_dir, file_name) @@ -92,7 +95,8 @@ def main(unused_args): output_path = os.path.join( os.path.abspath(FLAGS.test_log_output_dir), file_name) json_test_results = json_format.MessageToJson(test_results) - gfile.GFile(output_path + ".json", "w").write(json_test_results) + gfile.GFile(six.ensure_str(output_path) + ".json", + "w").write(json_test_results) tf_logging.info("Test results written to: %s" % output_path) diff --git a/tensorflow/tools/test/run_and_gather_logs_lib.py b/tensorflow/tools/test/run_and_gather_logs_lib.py index 3b4921bb983..f629e3a10b6 100644 --- a/tensorflow/tools/test/run_and_gather_logs_lib.py +++ b/tensorflow/tools/test/run_and_gather_logs_lib.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,6 +26,8 @@ import subprocess import tempfile import time +import six + from tensorflow.core.util import test_log_pb2 from tensorflow.python.platform import gfile from tensorflow.tools.test import gpu_info_lib @@ -118,12 +121,15 @@ def run_and_gather_logs(name, test_name, test_args, IOError: If there are problems gathering test log output from the test. MissingLogsError: If we couldn't find benchmark logs. """ - if not (test_name and test_name.startswith("//") and ".." not in test_name and - not test_name.endswith(":") and not test_name.endswith(":all") and - not test_name.endswith("...") and len(test_name.split(":")) == 2): + if not (test_name and six.ensure_str(test_name).startswith("//") and + ".." not in test_name and not six.ensure_str(test_name).endswith(":") + and not six.ensure_str(test_name).endswith(":all") and + not six.ensure_str(test_name).endswith("...") and + len(six.ensure_str(test_name).split(":")) == 2): raise ValueError("Expected test_name parameter with a unique test, e.g.: " "--test_name=//path/to:test") - test_executable = test_name.rstrip().strip("/").replace(":", "/") + test_executable = six.ensure_str(test_name.rstrip()).strip("/").replace( + ":", "/") if gfile.Exists(os.path.join("bazel-bin", test_executable)): # Running in standalone mode from core of the repository @@ -136,14 +142,17 @@ def run_and_gather_logs(name, test_name, test_args, gpu_config = gpu_info_lib.gather_gpu_devices() if gpu_config: gpu_name = gpu_config[0].model - gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", gpu_name) + gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", + six.ensure_str(gpu_name)) if gpu_short_name_match: gpu_short_name = gpu_short_name_match.group(0) - test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_") + test_adjusted_name = six.ensure_str(name) + "|" + gpu_short_name.replace( + " ", "_") temp_directory = tempfile.mkdtemp(prefix="run_and_gather_logs") - mangled_test_name = (test_adjusted_name.strip("/") - .replace("|", "_").replace("/", "_").replace(":", "_")) + mangled_test_name = ( + six.ensure_str(test_adjusted_name).strip("/").replace("|", "_").replace( + "/", "_").replace(":", "_")) test_file_prefix = os.path.join(temp_directory, mangled_test_name) test_file_prefix = "%s." % test_file_prefix diff --git a/tensorflow/tools/test/system_info_lib.py b/tensorflow/tools/test/system_info_lib.py index 59a30f9a368..ccdd8f3bf2a 100644 --- a/tensorflow/tools/test/system_info_lib.py +++ b/tensorflow/tools/test/system_info_lib.py @@ -1,3 +1,4 @@ +# Lint as: python2, python3 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,6 +30,8 @@ import socket # OSS tree. They are installable via pip. import cpuinfo import psutil + +import six # pylint: enable=g-bad-import-order from tensorflow.core.util import test_log_pb2 @@ -81,7 +84,8 @@ def gather_cpu_info(): # Gather num_cores_allowed try: with gfile.GFile('/proc/self/status', 'rb') as fh: - nc = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', fh.read().decode('utf-8')) + nc = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', + six.ensure_text(fh.read(), 'utf-8')) if nc: # e.g. 'ff' => 8, 'fff' => 12 cpu_info.num_cores_allowed = ( bin(int(nc.group(1).replace(',', ''), 16)).count('1'))