From 76e8f7b7fdf89b131e0406022129d5dde6b89e40 Mon Sep 17 00:00:00 2001 From: Anna R <annarev@google.com> Date: Tue, 24 Jul 2018 14:02:04 -0700 Subject: [PATCH] Initial API compatibility script for TF2.0. I am pretty much reusing 1.0 conversion script but passing V2 data. Also, remove code from tf_update.py which is also in ast_edits.py. PiperOrigin-RevId: 205887317 --- tensorflow/tools/compatibility/BUILD | 65 ++- tensorflow/tools/compatibility/renames_v2.py | 134 +++++ .../compatibility/testdata/test_file_v1_10.py | 34 ++ tensorflow/tools/compatibility/tf_upgrade.py | 486 +----------------- .../tools/compatibility/tf_upgrade_test.py | 5 +- .../tools/compatibility/tf_upgrade_v2.py | 115 +++++ .../tools/compatibility/tf_upgrade_v2_test.py | 83 +++ tensorflow/tools/compatibility/update/BUILD | 15 + .../update/generate_v2_renames_map.py | 103 ++++ 9 files changed, 551 insertions(+), 489 deletions(-) create mode 100644 tensorflow/tools/compatibility/renames_v2.py create mode 100644 tensorflow/tools/compatibility/testdata/test_file_v1_10.py create mode 100644 tensorflow/tools/compatibility/tf_upgrade_v2.py create mode 100644 tensorflow/tools/compatibility/tf_upgrade_v2_test.py create mode 100644 tensorflow/tools/compatibility/update/BUILD create mode 100644 tensorflow/tools/compatibility/update/generate_v2_renames_map.py diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD index b7bfb29aae4..55792c51fe8 100644 --- a/tensorflow/tools/compatibility/BUILD +++ b/tensorflow/tools/compatibility/BUILD @@ -8,10 +8,17 @@ load( "tf_cc_test", # @unused ) +py_library( + name = "ast_edits", + srcs = ["ast_edits.py"], + srcs_version = "PY2AND3", +) + py_binary( name = "tf_upgrade", srcs = ["tf_upgrade.py"], srcs_version = "PY2AND3", + deps = [":ast_edits"], ) py_test( @@ -26,6 +33,28 @@ py_test( ], ) +py_binary( + name = "tf_upgrade_v2", + srcs = [ + "renames_v2.py", + "tf_upgrade_v2.py", + ], + srcs_version = "PY2AND3", + deps = [":ast_edits"], +) + +py_test( + name = "tf_upgrade_v2_test", + srcs = ["tf_upgrade_v2_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":tf_upgrade_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "@six_archive//:six", + ], +) + # Keep for reference, this test will succeed in 0.11 but fail in 1.0 # py_test( # name = "test_file_v0_11", @@ -62,9 +91,37 @@ py_test( ], ) -exports_files( - [ - "tf_upgrade.py", - "testdata/test_file_v0_11.py", +genrule( + name = "generate_upgraded_file_v2", + testonly = 1, + srcs = ["testdata/test_file_v1_10.py"], + outs = [ + "test_file_v2_0.py", + "report_v2.txt", + ], + cmd = ("$(location :tf_upgrade_v2)" + + " --infile $(location testdata/test_file_v1_10.py)" + + " --outfile $(location test_file_v2_0.py)" + + " --reportfile $(location report_v2.txt)"), + tools = [":tf_upgrade_v2"], +) + +py_test( + name = "test_file_v2_0", + size = "small", + srcs = ["test_file_v2_0.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +exports_files( + [ + "ast_edits.py", + "tf_upgrade.py", + "renames_v2.py", + "testdata/test_file_v0_11.py", + "testdata/test_file_v1_10.py", ], ) diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py new file mode 100644 index 00000000000..216aa41b60e --- /dev/null +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=line-too-long +"""List of renames to apply when converting from TF 1.0 to TF 2.0. + +THIS FILE IS AUTOGENERATED: To update, please run: + bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map + bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map +This file should be updated whenever endpoints are deprecated. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +renames = { + 'tf.acos': 'tf.math.acos', + 'tf.acosh': 'tf.math.acosh', + 'tf.add': 'tf.math.add', + 'tf.as_string': 'tf.dtypes.as_string', + 'tf.asin': 'tf.math.asin', + 'tf.asinh': 'tf.math.asinh', + 'tf.atan': 'tf.math.atan', + 'tf.atan2': 'tf.math.atan2', + 'tf.atanh': 'tf.math.atanh', + 'tf.batch_to_space_nd': 'tf.manip.batch_to_space_nd', + 'tf.betainc': 'tf.math.betainc', + 'tf.ceil': 'tf.math.ceil', + 'tf.check_numerics': 'tf.debugging.check_numerics', + 'tf.cholesky': 'tf.linalg.cholesky', + 'tf.cos': 'tf.math.cos', + 'tf.cosh': 'tf.math.cosh', + 'tf.cross': 'tf.linalg.cross', + 'tf.decode_base64': 'tf.io.decode_base64', + 'tf.decode_compressed': 'tf.io.decode_compressed', + 'tf.decode_json_example': 'tf.io.decode_json_example', + 'tf.decode_raw': 'tf.io.decode_raw', + 'tf.dequantize': 'tf.quantization.dequantize', + 'tf.diag': 'tf.linalg.tensor_diag', + 'tf.diag_part': 'tf.linalg.tensor_diag_part', + 'tf.digamma': 'tf.math.digamma', + 'tf.encode_base64': 'tf.io.encode_base64', + 'tf.equal': 'tf.math.equal', + 'tf.erfc': 'tf.math.erfc', + 'tf.exp': 'tf.math.exp', + 'tf.expm1': 'tf.math.expm1', + 'tf.extract_image_patches': 'tf.image.extract_image_patches', + 'tf.fake_quant_with_min_max_args': 'tf.quantization.fake_quant_with_min_max_args', + 'tf.fake_quant_with_min_max_args_gradient': 'tf.quantization.fake_quant_with_min_max_args_gradient', + 'tf.fake_quant_with_min_max_vars': 'tf.quantization.fake_quant_with_min_max_vars', + 'tf.fake_quant_with_min_max_vars_gradient': 'tf.quantization.fake_quant_with_min_max_vars_gradient', + 'tf.fake_quant_with_min_max_vars_per_channel': 'tf.quantization.fake_quant_with_min_max_vars_per_channel', + 'tf.fake_quant_with_min_max_vars_per_channel_gradient': 'tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient', + 'tf.fft': 'tf.spectral.fft', + 'tf.floor': 'tf.math.floor', + 'tf.gather_nd': 'tf.manip.gather_nd', + 'tf.greater': 'tf.math.greater', + 'tf.greater_equal': 'tf.math.greater_equal', + 'tf.ifft': 'tf.spectral.ifft', + 'tf.igamma': 'tf.math.igamma', + 'tf.igammac': 'tf.math.igammac', + 'tf.invert_permutation': 'tf.math.invert_permutation', + 'tf.is_finite': 'tf.debugging.is_finite', + 'tf.is_inf': 'tf.debugging.is_inf', + 'tf.is_nan': 'tf.debugging.is_nan', + 'tf.less': 'tf.math.less', + 'tf.less_equal': 'tf.math.less_equal', + 'tf.lgamma': 'tf.math.lgamma', + 'tf.log': 'tf.math.log', + 'tf.log1p': 'tf.math.log1p', + 'tf.logical_and': 'tf.math.logical_and', + 'tf.logical_not': 'tf.math.logical_not', + 'tf.logical_or': 'tf.math.logical_or', + 'tf.matching_files': 'tf.io.matching_files', + 'tf.matrix_band_part': 'tf.linalg.band_part', + 'tf.matrix_determinant': 'tf.linalg.det', + 'tf.matrix_diag': 'tf.linalg.diag', + 'tf.matrix_diag_part': 'tf.linalg.diag_part', + 'tf.matrix_inverse': 'tf.linalg.inv', + 'tf.matrix_set_diag': 'tf.linalg.set_diag', + 'tf.matrix_solve': 'tf.linalg.solve', + 'tf.matrix_triangular_solve': 'tf.linalg.triangular_solve', + 'tf.maximum': 'tf.math.maximum', + 'tf.minimum': 'tf.math.minimum', + 'tf.not_equal': 'tf.math.not_equal', + 'tf.parse_tensor': 'tf.io.parse_tensor', + 'tf.polygamma': 'tf.math.polygamma', + 'tf.qr': 'tf.linalg.qr', + 'tf.quantized_concat': 'tf.quantization.quantized_concat', + 'tf.read_file': 'tf.io.read_file', + 'tf.reciprocal': 'tf.math.reciprocal', + 'tf.regex_replace': 'tf.strings.regex_replace', + 'tf.reshape': 'tf.manip.reshape', + 'tf.reverse': 'tf.manip.reverse', + 'tf.reverse_v2': 'tf.manip.reverse', + 'tf.rint': 'tf.math.rint', + 'tf.rsqrt': 'tf.math.rsqrt', + 'tf.scatter_nd': 'tf.manip.scatter_nd', + 'tf.segment_max': 'tf.math.segment_max', + 'tf.segment_mean': 'tf.math.segment_mean', + 'tf.segment_min': 'tf.math.segment_min', + 'tf.segment_prod': 'tf.math.segment_prod', + 'tf.segment_sum': 'tf.math.segment_sum', + 'tf.sin': 'tf.math.sin', + 'tf.sinh': 'tf.math.sinh', + 'tf.space_to_batch_nd': 'tf.manip.space_to_batch_nd', + 'tf.squared_difference': 'tf.math.squared_difference', + 'tf.string_join': 'tf.strings.join', + 'tf.string_strip': 'tf.strings.strip', + 'tf.string_to_hash_bucket': 'tf.strings.to_hash_bucket', + 'tf.string_to_hash_bucket_fast': 'tf.strings.to_hash_bucket_fast', + 'tf.string_to_hash_bucket_strong': 'tf.strings.to_hash_bucket_strong', + 'tf.string_to_number': 'tf.strings.to_number', + 'tf.substr': 'tf.strings.substr', + 'tf.tan': 'tf.math.tan', + 'tf.tile': 'tf.manip.tile', + 'tf.unsorted_segment_max': 'tf.math.unsorted_segment_max', + 'tf.unsorted_segment_min': 'tf.math.unsorted_segment_min', + 'tf.unsorted_segment_prod': 'tf.math.unsorted_segment_prod', + 'tf.unsorted_segment_sum': 'tf.math.unsorted_segment_sum', + 'tf.write_file': 'tf.io.write_file', + 'tf.zeta': 'tf.math.zeta' +} diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py new file mode 100644 index 00000000000..a49035a1a09 --- /dev/null +++ b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py @@ -0,0 +1,34 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf upgrader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test as test_lib + + +class TestUpgrade(test_util.TensorFlowTestCase): + """Test various APIs that have been changed in 2.0.""" + + def testRenames(self): + with self.test_session(): + self.assertAllClose(1.04719755, tf.acos(0.5).eval()) + self.assertAllClose(0.5, tf.rsqrt(4.0).eval()) + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index 1f8833582af..96705b1a4c2 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -19,491 +19,11 @@ from __future__ import division from __future__ import print_function import argparse -import ast -import collections -import os -import shutil -import sys -import tempfile -import traceback +from tensorflow.tools.compatibility import ast_edits -class APIChangeSpec(object): - """This class defines the transformations that need to happen. - This class must provide the following fields: - - * `function_keyword_renames`: maps function names to a map of old -> new - argument names - * `function_renames`: maps function names to new function names - * `change_to_function`: a set of function names that have changed (for - notifications) - * `function_reorders`: maps functions whose argument order has changed to the - list of arguments in the new order - * `function_handle`: maps function names to custom handlers for the function - - For an example, see `TFAPIChangeSpec`. - """ - - -class _FileEditTuple( - collections.namedtuple("_FileEditTuple", - ["comment", "line", "start", "old", "new"])): - """Each edit that is recorded by a _FileEditRecorder. - - Fields: - comment: A description of the edit and why it was made. - line: The line number in the file where the edit occurs (1-indexed). - start: The line number in the file where the edit occurs (0-indexed). - old: text string to remove (this must match what was in file). - new: text string to add in place of `old`. - """ - - __slots__ = () - - -class _FileEditRecorder(object): - """Record changes that need to be done to the file.""" - - def __init__(self, filename): - # all edits are lists of chars - self._filename = filename - - self._line_to_edit = collections.defaultdict(list) - self._errors = [] - - def process(self, text): - """Process a list of strings, each corresponding to the recorded changes. - - Args: - text: A list of lines of text (assumed to contain newlines) - Returns: - A tuple of the modified text and a textual description of what is done. - Raises: - ValueError: if substitution source location does not have expected text. - """ - - change_report = "" - - # Iterate of each line - for line, edits in self._line_to_edit.items(): - offset = 0 - # sort by column so that edits are processed in order in order to make - # indexing adjustments cumulative for changes that change the string - # length - edits.sort(key=lambda x: x.start) - - # Extract each line to a list of characters, because mutable lists - # are editable, unlike immutable strings. - char_array = list(text[line - 1]) - - # Record a description of the change - change_report += "%r Line %d\n" % (self._filename, line) - change_report += "-" * 80 + "\n\n" - for e in edits: - change_report += "%s\n" % e.comment - change_report += "\n Old: %s" % (text[line - 1]) - - # Make underscore buffers for underlining where in the line the edit was - change_list = [" "] * len(text[line - 1]) - change_list_new = [" "] * len(text[line - 1]) - - # Iterate for each edit - for e in edits: - # Create effective start, end by accounting for change in length due - # to previous edits - start_eff = e.start + offset - end_eff = start_eff + len(e.old) - - # Make sure the edit is changing what it should be changing - old_actual = "".join(char_array[start_eff:end_eff]) - if old_actual != e.old: - raise ValueError("Expected text %r but got %r" % - ("".join(e.old), "".join(old_actual))) - # Make the edit - char_array[start_eff:end_eff] = list(e.new) - - # Create the underline highlighting of the before and after - change_list[e.start:e.start + len(e.old)] = "~" * len(e.old) - change_list_new[start_eff:end_eff] = "~" * len(e.new) - - # Keep track of how to generate effective ranges - offset += len(e.new) - len(e.old) - - # Finish the report comment - change_report += " %s\n" % "".join(change_list) - text[line - 1] = "".join(char_array) - change_report += " New: %s" % (text[line - 1]) - change_report += " %s\n\n" % "".join(change_list_new) - return "".join(text), change_report, self._errors - - def add(self, comment, line, start, old, new, error=None): - """Add a new change that is needed. - - Args: - comment: A description of what was changed - line: Line number (1 indexed) - start: Column offset (0 indexed) - old: old text - new: new text - error: this "edit" is something that cannot be fixed automatically - Returns: - None - """ - - self._line_to_edit[line].append( - _FileEditTuple(comment, line, start, old, new)) - if error: - self._errors.append("%s:%d: %s" % (self._filename, line, error)) - - -class _ASTCallVisitor(ast.NodeVisitor): - """AST Visitor that processes function calls. - - Updates function calls from old API version to new API version using a given - change spec. - """ - - def __init__(self, filename, lines, api_change_spec): - self._filename = filename - self._file_edit = _FileEditRecorder(filename) - self._lines = lines - self._api_change_spec = api_change_spec - - def process(self, lines): - return self._file_edit.process(lines) - - def generic_visit(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def _rename_functions(self, node, full_name): - function_renames = self._api_change_spec.function_renames - try: - new_name = function_renames[full_name] - self._file_edit.add("Renamed function %r to %r" % (full_name, new_name), - node.lineno, node.col_offset, full_name, new_name) - except KeyError: - pass - - def _get_attribute_full_path(self, node): - """Traverse an attribute to generate a full name e.g. tf.foo.bar. - - Args: - node: A Node of type Attribute. - - Returns: - a '.'-delimited full-name or None if the tree was not a simple form. - i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". - """ - curr = node - items = [] - while not isinstance(curr, ast.Name): - if not isinstance(curr, ast.Attribute): - return None - items.append(curr.attr) - curr = curr.value - items.append(curr.id) - return ".".join(reversed(items)) - - def _find_true_position(self, node): - """Return correct line number and column offset for a given node. - - This is necessary mainly because ListComp's location reporting reports - the next token after the list comprehension list opening. - - Args: - node: Node for which we wish to know the lineno and col_offset - """ - import re - find_open = re.compile("^\s*(\\[).*$") - find_string_chars = re.compile("['\"]") - - if isinstance(node, ast.ListComp): - # Strangely, ast.ListComp returns the col_offset of the first token - # after the '[' token which appears to be a bug. Workaround by - # explicitly finding the real start of the list comprehension. - line = node.lineno - col = node.col_offset - # loop over lines - while 1: - # Reverse the text to and regular expression search for whitespace - text = self._lines[line - 1] - reversed_preceding_text = text[:col][::-1] - # First find if a [ can be found with only whitespace between it and - # col. - m = find_open.match(reversed_preceding_text) - if m: - new_col_offset = col - m.start(1) - 1 - return line, new_col_offset - else: - if (reversed_preceding_text == "" or - reversed_preceding_text.isspace()): - line = line - 1 - prev_line = self._lines[line - 1] - # TODO(aselle): - # this is poor comment detection, but it is good enough for - # cases where the comment does not contain string literal starting/ - # ending characters. If ast gave us start and end locations of the - # ast nodes rather than just start, we could use string literal - # node ranges to filter out spurious #'s that appear in string - # literals. - comment_start = prev_line.find("#") - if comment_start == -1: - col = len(prev_line) - 1 - elif find_string_chars.search(prev_line[comment_start:]) is None: - col = comment_start - else: - return None, None - else: - return None, None - # Most other nodes return proper locations (with notably does not), but - # it is not possible to use that in an argument. - return node.lineno, node.col_offset - - def visit_Call(self, node): # pylint: disable=invalid-name - """Handle visiting a call node in the AST. - - Args: - node: Current Node - """ - - # Find a simple attribute name path e.g. "tf.foo.bar" - full_name = self._get_attribute_full_path(node.func) - - # Make sure the func is marked as being part of a call - node.func.is_function_for_call = True - - if full_name: - # Call special handlers - function_handles = self._api_change_spec.function_handle - if full_name in function_handles: - function_handles[full_name](self._file_edit, node) - - # Examine any non-keyword argument and make it into a keyword argument - # if reordering required. - function_reorders = self._api_change_spec.function_reorders - function_keyword_renames = ( - self._api_change_spec.function_keyword_renames) - - if full_name in function_reorders: - reordered = function_reorders[full_name] - for idx, arg in enumerate(node.args): - lineno, col_offset = self._find_true_position(arg) - if lineno is None or col_offset is None: - self._file_edit.add( - "Failed to add keyword %r to reordered function %r" % - (reordered[idx], full_name), - arg.lineno, - arg.col_offset, - "", - "", - error="A necessary keyword argument failed to be inserted.") - else: - keyword_arg = reordered[idx] - if (full_name in function_keyword_renames and - keyword_arg in function_keyword_renames[full_name]): - keyword_arg = function_keyword_renames[full_name][keyword_arg] - self._file_edit.add("Added keyword %r to reordered function %r" % - (reordered[idx], full_name), lineno, col_offset, - "", keyword_arg + "=") - - # Examine each keyword argument and convert it to the final renamed form - renamed_keywords = ({} if full_name not in function_keyword_renames else - function_keyword_renames[full_name]) - for keyword in node.keywords: - argkey = keyword.arg - argval = keyword.value - - if argkey in renamed_keywords: - argval_lineno, argval_col_offset = self._find_true_position(argval) - if argval_lineno is not None and argval_col_offset is not None: - # TODO(aselle): We should scan backward to find the start of the - # keyword key. Unfortunately ast does not give you the location of - # keyword keys, so we are forced to infer it from the keyword arg - # value. - key_start = argval_col_offset - len(argkey) - 1 - key_end = key_start + len(argkey) + 1 - if (self._lines[argval_lineno - 1][key_start:key_end] == argkey + - "="): - self._file_edit.add("Renamed keyword argument from %r to %r" % - (argkey, - renamed_keywords[argkey]), argval_lineno, - argval_col_offset - len(argkey) - 1, - argkey + "=", renamed_keywords[argkey] + "=") - continue - self._file_edit.add( - "Failed to rename keyword argument from %r to %r" % - (argkey, renamed_keywords[argkey]), - argval.lineno, - argval.col_offset - len(argkey) - 1, - "", - "", - error="Failed to find keyword lexographically. Fix manually.") - - ast.NodeVisitor.generic_visit(self, node) - - def visit_Attribute(self, node): # pylint: disable=invalid-name - """Handle bare Attributes i.e. [tf.foo, tf.bar]. - - Args: - node: Node that is of type ast.Attribute - """ - full_name = self._get_attribute_full_path(node) - if full_name: - self._rename_functions(node, full_name) - if full_name in self._api_change_spec.change_to_function: - if not hasattr(node, "is_function_for_call"): - new_text = full_name + "()" - self._file_edit.add("Changed %r to %r" % (full_name, new_text), - node.lineno, node.col_offset, full_name, new_text) - - ast.NodeVisitor.generic_visit(self, node) - - -class ASTCodeUpgrader(object): - """Handles upgrading a set of Python files using a given API change spec.""" - - def __init__(self, api_change_spec): - if not isinstance(api_change_spec, APIChangeSpec): - raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" % - type(api_change_spec)) - self._api_change_spec = api_change_spec - - def process_file(self, in_filename, out_filename): - """Process the given python file for incompatible changes. - - Args: - in_filename: filename to parse - out_filename: output file to write to - Returns: - A tuple representing number of files processed, log of actions, errors - """ - - # Write to a temporary file, just in case we are doing an implace modify. - with open(in_filename, "r") as in_file, \ - tempfile.NamedTemporaryFile("w", delete=False) as temp_file: - ret = self.process_opened_file(in_filename, in_file, out_filename, - temp_file) - - shutil.move(temp_file.name, out_filename) - return ret - - # Broad exceptions are required here because ast throws whatever it wants. - # pylint: disable=broad-except - def process_opened_file(self, in_filename, in_file, out_filename, out_file): - """Process the given python file for incompatible changes. - - This function is split out to facilitate StringIO testing from - tf_upgrade_test.py. - - Args: - in_filename: filename to parse - in_file: opened file (or StringIO) - out_filename: output file to write to - out_file: opened file (or StringIO) - Returns: - A tuple representing number of files processed, log of actions, errors - """ - process_errors = [] - text = "-" * 80 + "\n" - text += "Processing file %r\n outputting to %r\n" % (in_filename, - out_filename) - text += "-" * 80 + "\n\n" - - parsed_ast = None - lines = in_file.readlines() - try: - parsed_ast = ast.parse("".join(lines)) - except Exception: - text += "Failed to parse %r\n\n" % in_filename - text += traceback.format_exc() - if parsed_ast: - visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec) - visitor.visit(parsed_ast) - out_text, new_text, process_errors = visitor.process(lines) - text += new_text - if out_file: - out_file.write(out_text) - text += "\n" - return 1, text, process_errors - - # pylint: enable=broad-except - - def process_tree(self, root_directory, output_root_directory, - copy_other_files): - """Processes upgrades on an entire tree of python files in place. - - Note that only Python files. If you have custom code in other languages, - you will need to manually upgrade those. - - Args: - root_directory: Directory to walk and process. - output_root_directory: Directory to use as base. - copy_other_files: Copy files that are not touched by this converter. - - Returns: - A tuple of files processed, the report string ofr all files, and errors - """ - - # make sure output directory doesn't exist - if output_root_directory and os.path.exists(output_root_directory): - print("Output directory %r must not already exist." % - (output_root_directory)) - sys.exit(1) - - # make sure output directory does not overlap with root_directory - norm_root = os.path.split(os.path.normpath(root_directory)) - norm_output = os.path.split(os.path.normpath(output_root_directory)) - if norm_root == norm_output: - print("Output directory %r same as input directory %r" % - (root_directory, output_root_directory)) - sys.exit(1) - - # Collect list of files to process (we do this to correctly handle if the - # user puts the output directory in some sub directory of the input dir) - files_to_process = [] - files_to_copy = [] - for dir_name, _, file_list in os.walk(root_directory): - py_files = [f for f in file_list if f.endswith(".py")] - copy_files = [f for f in file_list if not f.endswith(".py")] - for filename in py_files: - fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join(output_root_directory, - os.path.relpath(fullpath, - root_directory)) - files_to_process.append((fullpath, fullpath_output)) - if copy_other_files: - for filename in copy_files: - fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join(output_root_directory, - os.path.relpath( - fullpath, root_directory)) - files_to_copy.append((fullpath, fullpath_output)) - - file_count = 0 - tree_errors = [] - report = "" - report += ("=" * 80) + "\n" - report += "Input tree: %r\n" % root_directory - report += ("=" * 80) + "\n" - - for input_path, output_path in files_to_process: - output_directory = os.path.dirname(output_path) - if not os.path.isdir(output_directory): - os.makedirs(output_directory) - file_count += 1 - _, l_report, l_errors = self.process_file(input_path, output_path) - tree_errors += l_errors - report += l_report - for input_path, output_path in files_to_copy: - output_directory = os.path.dirname(output_path) - if not os.path.isdir(output_directory): - os.makedirs(output_directory) - shutil.copy(input_path, output_path) - return file_count, report, tree_errors - - -class TFAPIChangeSpec(APIChangeSpec): +class TFAPIChangeSpec(ast_edits.APIChangeSpec): """List of maps that describe what changed in the API.""" def __init__(self): @@ -718,7 +238,7 @@ Simple usage: default="report.txt") args = parser.parse_args() - upgrade = ASTCodeUpgrader(TFAPIChangeSpec()) + upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec()) report_text = None report_filename = args.report_filename files_processed = 0 diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py index 3d02eacba6e..66325ea2ad3 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_test.py @@ -22,6 +22,7 @@ import tempfile import six from tensorflow.python.framework import test_util from tensorflow.python.platform import test as test_lib +from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import tf_upgrade @@ -36,7 +37,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): def _upgrade(self, old_file_text): in_file = six.StringIO(old_file_text) out_file = six.StringIO() - upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) count, report, errors = ( upgrader.process_opened_file("test.py", in_file, "test_out.py", out_file)) @@ -139,7 +140,7 @@ class TestUpgradeFiles(test_util.TensorFlowTestCase): upgraded = "tf.multiply(a, b)\n" temp_file.write(original) temp_file.close() - upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) upgrader.process_file(temp_file.name, temp_file.name) self.assertAllEqual(open(temp_file.name).read(), upgraded) os.unlink(temp_file.name) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py new file mode 100644 index 00000000000..9702430a121 --- /dev/null +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Upgrader for Python scripts from 1.* TensorFlow to 2.0 TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse + +from tensorflow.tools.compatibility import ast_edits +from tensorflow.tools.compatibility import renames_v2 + + +class TFAPIChangeSpec(ast_edits.APIChangeSpec): + """List of maps that describe what changed in the API.""" + + def __init__(self): + # Maps from a function name to a dictionary that describes how to + # map from an old argument keyword to the new argument keyword. + self.function_keyword_renames = {} + + # Mapping from function to the new name of the function + self.function_renames = renames_v2.renames + + # Variables that should be changed to functions. + self.change_to_function = {} + + # Functions that were reordered should be changed to the new keyword args + # for safety, if positional arguments are used. If you have reversed the + # positional arguments yourself, this could do the wrong thing. + self.function_reorders = {} + + # Specially handled functions. + self.function_handle = {} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""Convert a TensorFlow Python file to 2.0 + +Simple usage: + tf_convert_v2.py --infile foo.py --outfile bar.py + tf_convert_v2.py --intree ~/code/old --outtree ~/code/new +""") + parser.add_argument( + "--infile", + dest="input_file", + help="If converting a single file, the name of the file " + "to convert") + parser.add_argument( + "--outfile", + dest="output_file", + help="If converting a single file, the output filename.") + parser.add_argument( + "--intree", + dest="input_tree", + help="If converting a whole tree of files, the directory " + "to read from (relative or absolute).") + parser.add_argument( + "--outtree", + dest="output_tree", + help="If converting a whole tree of files, the output " + "directory (relative or absolute).") + parser.add_argument( + "--copyotherfiles", + dest="copy_other_files", + help=("If converting a whole tree of files, whether to " + "copy the other files."), + type=bool, + default=False) + parser.add_argument( + "--reportfile", + dest="report_filename", + help=("The name of the file where the report log is " + "stored." + "(default: %(default)s)"), + default="report.txt") + args = parser.parse_args() + + upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec()) + report_text = None + report_filename = args.report_filename + files_processed = 0 + if args.input_file: + files_processed, report_text, errors = upgrade.process_file( + args.input_file, args.output_file) + files_processed = 1 + elif args.input_tree: + files_processed, report_text, errors = upgrade.process_tree( + args.input_tree, args.output_tree, args.copy_other_files) + else: + parser.print_help() + if report_text: + open(report_filename, "w").write(report_text) + print("TensorFlow 2.0 Upgrade Script") + print("-----------------------------") + print("Converted %d files\n" % files_processed) + print("Detected %d errors that require attention" % len(errors)) + print("-" * 80) + print("\n".join(errors)) + print("\nMake sure to read the detailed log %r\n" % report_filename) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py new file mode 100644 index 00000000000..57ac04de066 --- /dev/null +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf 2.0 upgrader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import tempfile +import six +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test as test_lib +from tensorflow.tools.compatibility import ast_edits +from tensorflow.tools.compatibility import tf_upgrade_v2 + + +class TestUpgrade(test_util.TensorFlowTestCase): + """Test various APIs that have been changed in 2.0. + + We also test whether a converted file is executable. test_file_v1_10.py + aims to exhaustively test that API changes are convertible and actually + work when run with current TensorFlow. + """ + + def _upgrade(self, old_file_text): + in_file = six.StringIO(old_file_text) + out_file = six.StringIO() + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec()) + count, report, errors = ( + upgrader.process_opened_file("test.py", in_file, + "test_out.py", out_file)) + return count, report, errors, out_file.getvalue() + + def testParseError(self): + _, report, unused_errors, unused_new_text = self._upgrade( + "import tensorflow as tf\na + \n") + self.assertTrue(report.find("Failed to parse") != -1) + + def testReport(self): + text = "tf.acos(a)\n" + _, report, unused_errors, unused_new_text = self._upgrade(text) + # This is not a complete test, but it is a sanity test that a report + # is generating information. + self.assertTrue(report.find("Renamed function `tf.acos` to `tf.math.acos`")) + + def testRename(self): + text = "tf.acos(a)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.math.acos(a)\n") + text = "tf.rsqrt(tf.log(3.8))\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log(3.8))\n") + + +class TestUpgradeFiles(test_util.TensorFlowTestCase): + + def testInplace(self): + """Check to make sure we don't have a file system race.""" + temp_file = tempfile.NamedTemporaryFile("w", delete=False) + original = "tf.acos(a, b)\n" + upgraded = "tf.math.acos(a, b)\n" + temp_file.write(original) + temp_file.close() + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec()) + upgrader.process_file(temp_file.name, temp_file.name) + self.assertAllEqual(open(temp_file.name).read(), upgraded) + os.unlink(temp_file.name) + + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/tools/compatibility/update/BUILD b/tensorflow/tools/compatibility/update/BUILD new file mode 100644 index 00000000000..feb37c902ec --- /dev/null +++ b/tensorflow/tools/compatibility/update/BUILD @@ -0,0 +1,15 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +py_binary( + name = "generate_v2_renames_map", + srcs = ["generate_v2_renames_map.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:lib", + "//tensorflow/tools/common:public_api", + "//tensorflow/tools/common:traverse", + ], +) diff --git a/tensorflow/tools/compatibility/update/generate_v2_renames_map.py b/tensorflow/tools/compatibility/update/generate_v2_renames_map.py new file mode 100644 index 00000000000..567eceb0b65 --- /dev/null +++ b/tensorflow/tools/compatibility/update/generate_v2_renames_map.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=line-too-long +"""Script for updating tensorflow/tools/compatibility/renames_v2.py. + +To update renames_v2.py, run: + bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map + bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map +""" +# pylint: enable=line-too-long + +import tensorflow as tf + +from tensorflow.python.lib.io import file_io +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_export +from tensorflow.tools.common import public_api +from tensorflow.tools.common import traverse + + +_OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/renames_v2.py' +_FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=line-too-long +\"\"\"List of renames to apply when converting from TF 1.0 to TF 2.0. + +THIS FILE IS AUTOGENERATED: To update, please run: + bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map + bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map +This file should be updated whenever endpoints are deprecated. +\"\"\" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +""" + + +def update_renames_v2(output_file_path): + """Writes a Python dictionary mapping deprecated to canonical API names. + + Args: + output_file_path: File path to write output to. Any existing contents + would be replaced. + """ + # Set of rename lines to write to output file in the form: + # 'tf.deprecated_name': 'tf.canonical_name' + rename_line_set = set() + # _tf_api_names attribute name + tensorflow_api_attr = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names + + def visit(unused_path, unused_parent, children): + """Visitor that collects rename strings to add to rename_line_set.""" + for child in children: + _, attr = tf_decorator.unwrap(child[1]) + if not hasattr(attr, '__dict__'): + continue + api_names = attr.__dict__.get(tensorflow_api_attr, []) + deprecated_api_names = attr.__dict__.get('_tf_deprecated_api_names', []) + canonical_name = tf_export.get_canonical_name( + api_names, deprecated_api_names) + for name in deprecated_api_names: + rename_line_set.add(' \'tf.%s\': \'tf.%s\'' % (name, canonical_name)) + + visitor = public_api.PublicAPIVisitor(visit) + visitor.do_not_descend_map['tf'].append('contrib') + traverse.traverse(tf, visitor) + + renames_file_text = '%srenames = {\n%s\n}\n' % ( + _FILE_HEADER, ',\n'.join(sorted(rename_line_set))) + file_io.write_string_to_file(output_file_path, renames_file_text) + + +def main(unused_argv): + update_renames_v2(_OUTPUT_FILE_PATH) + + +if __name__ == '__main__': + tf.app.run(main=main)