diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index ef04a6a88ed..aa29b45c1ee 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -207,6 +207,7 @@ filegroup(
         "//tensorflow/tensorboard/lib/python:all_files",
         "//tensorflow/tensorboard/scripts:all_files",
         "//tensorflow/tools/common:all_files",
+        "//tensorflow/tools/compatibility:all_files",
         "//tensorflow/tools/dist_test/server:all_files",
         "//tensorflow/tools/docker:all_files",
         "//tensorflow/tools/docker/notebooks:all_files",
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
new file mode 100644
index 00000000000..0f3de10a0ad
--- /dev/null
+++ b/tensorflow/tools/compatibility/BUILD
@@ -0,0 +1,83 @@
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_copts",  # @unused
+    "tf_cc_test",  # @unused
+)
+
+py_binary(
+    name = "tf_upgrade",
+    srcs = ["tf_upgrade.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_test(
+    name = "tf_upgrade_test",
+    srcs = ["tf_upgrade_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "tf_upgrade",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+# Keep for reference: this test will succeed in 0.11 but fail in 1.0.
+# py_test(
+#     name = "test_file_v0_11",
+#     size = "small",
+#     srcs = ["testdata/test_file_v0_11.py"],
+#     srcs_version = "PY2AND3",
+#     deps = [
+#         "//tensorflow:tensorflow_py",
+#     ],
+# )
+
+genrule(
+    name = "generate_upgraded_file",
+    testonly = 1,
+    srcs = ["testdata/test_file_v0_11.py"],
+    outs = [
+        "test_file_v1_0.py",
+        "report.txt",
+    ],
+    cmd = ("$(location tf_upgrade)" +
+           " --infile $(location testdata/test_file_v0_11.py)" +
+           " --outfile $(location test_file_v1_0.py)" +
+           " --reportfile $(location report.txt)"),
+    tools = ["tf_upgrade"],
+)
+
+py_test(
+    name = "test_file_v1_0",
+    size = "small",
+    srcs = ["test_file_v1_0.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+exports_files(
+    [
+        "tf_upgrade.py",
+        "testdata/test_file_v0_11.py",
+    ],
+)
+
+# -----------------------------------------------------------------------------
+# Google-internal targets.  These must be at the end for syncrepo.
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/tools/compatibility/README.md b/tensorflow/tools/compatibility/README.md
new file mode 100644
index 00000000000..3b66e73f9a2
--- /dev/null
+++ b/tensorflow/tools/compatibility/README.md
@@ -0,0 +1,48 @@
+# TensorFlow Python API Upgrade Utility
+
+This tool allows you to upgrade your existing TensorFlow Python scripts.
+The script can be run on a single Python file:
+
+```
+tf_upgrade.py --infile foo.py --outfile foo-upgraded.py
+```
+
+It will print a list of the errors it finds that it can't fix. You can also
+run it on a directory tree:
+
+```
+tf_upgrade.py --intree coolcode --outtree coolcode-upgraded
+```
+
+In either case, it will also dump out a report detailing the changes it made,
+e.g.:
+
+```
+third_party/tensorflow/tools/compatibility/test_file_v0.11.py Line 125
+
+Renamed keyword argument from `dim` to `axis`
+Renamed keyword argument from `squeeze_dims` to `axis`
+
+    Old: [[1, 2, 3]], dim=1), squeeze_dims=[1]).eval(),
+                      ~~~~    ~~~~~~~~~~~~~
+    New: [[1, 2, 3]], axis=1), axis=[1]).eval(),
+                      ~~~~~    ~~~~~
+```
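+
+The upgrader can also be driven from Python. The following is a minimal
+sketch (the file names here are placeholders) that uses the
+`TensorFlowCodeUpgrader` class defined in `tf_upgrade.py`:
+
+```
+from tensorflow.tools.compatibility import tf_upgrade
+
+upgrader = tf_upgrade.TensorFlowCodeUpgrader()
+# process_file returns (files_processed, report_text, errors).
+count, report, errors = upgrader.process_file("foo.py", "foo-upgraded.py")
+print(report)
+```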
In +particular, functions that have had reordered arguments like `tf.concat`, +`tf.split` will cause the script to incorrectly add keyword arguments that +mismap arguments. + +- This script is not able to upgrade all functions. One notable example is +`tf.reverse()` which has been changed to take a list of indices rather than +a tensor of bools. If the script detects this, it will report this to stdout +(and in the report), and you can fix it manually. For example if you have +`tf.reverse(a, [False, True, True])` you will need to manually change it to +`tf.reverse(a, [1, 2])`. + + + + diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py new file mode 100644 index 00000000000..37d914c6486 --- /dev/null +++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py @@ -0,0 +1,208 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf upgrader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import shutil +import tempfile +import numpy as np +import tensorflow as tf +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test as test_lib + + +class TestUpgrade(test_util.TensorFlowTestCase): + """Test various APIs that have been changed in 1.0. + + This test will not run in current TensorFlow, but did run in 0.11. + This file is intended to be converted by a genrule() that uses the converter + so that a 1.0 compatible version of this file is generated. That is run as + a unit test if the converter is successful. 
+ """ + + def testArgRenames(self): + with self.test_session(): + + a = [[1., 2., 3.], [4., 5., 6.]] + b = [[True, False, False], [False, True, True]] + dim0 = [1] + dim1 = [1] + + self.assertAllEqual( + tf.reduce_any( + b, reduction_indices=dim0).eval(), [True, True]) + self.assertAllEqual( + tf.reduce_all( + b, reduction_indices=[0]).eval(), [False, False, False]) + self.assertAllEqual( + tf.reduce_all( + b, reduction_indices=dim1).eval(), [False, False]) + self.assertAllEqual( + tf.reduce_sum( + a, reduction_indices=[1]).eval(), [6., 15.]) + self.assertAllEqual( + tf.reduce_sum( + a, reduction_indices=[0, 1]).eval(), 21.0) + self.assertAllEqual(tf.reduce_sum(a, [0, 1]).eval(), 21.0) + self.assertAllEqual( + tf.reduce_prod( + a, reduction_indices=[1]).eval(), [6., 120.]) + self.assertAllEqual( + tf.reduce_prod( + a, reduction_indices=[0, 1]).eval(), 720.0) + self.assertAllEqual(tf.reduce_prod(a, [0, 1]).eval(), 720.0) + self.assertAllEqual( + tf.reduce_mean( + a, reduction_indices=[1]).eval(), [2., 5.]) + self.assertAllEqual( + tf.reduce_mean( + a, reduction_indices=[0, 1]).eval(), 3.5) + self.assertAllEqual(tf.reduce_mean(a, [0, 1]).eval(), 3.5) + self.assertAllEqual( + tf.reduce_min( + a, reduction_indices=[1]).eval(), [1., 4.]) + self.assertAllEqual( + tf.reduce_min( + a, reduction_indices=[0, 1]).eval(), 1.0) + self.assertAllEqual(tf.reduce_min(a, [0, 1]).eval(), 1.0) + self.assertAllEqual( + tf.reduce_max( + a, reduction_indices=[1]).eval(), [3., 6.]) + self.assertAllEqual( + tf.reduce_max( + a, reduction_indices=[0, 1]).eval(), 6.0) + self.assertAllEqual(tf.reduce_max(a, [0, 1]).eval(), 6.0) + self.assertAllClose(tf.reduce_logsumexp(a, reduction_indices=[1]).eval(), + [3.40760589, 6.40760612]) + self.assertAllClose( + tf.reduce_logsumexp(a, reduction_indices=[0, 1]).eval(), + 6.45619344711) + self.assertAllClose( + tf.reduce_logsumexp(a, [0, 1]).eval(), 6.45619344711) + self.assertAllEqual( + tf.expand_dims([[1, 2], [3, 4]], dim=1).eval(), + [[[1, 2]], [[3, 4]]]) + + def testArgMinMax(self): + with self.test_session(): + self.assertAllEqual( + tf.argmin([[1, 2, 3], [4, 1, 0]], dimension=1).eval(), + [0, 2]) + self.assertAllEqual( + tf.argmin([[1, 2, 3], [4, 1, 0]], dimension=0).eval(), + [0, 1, 1]) + self.assertAllEqual( + tf.argmax([[1, 2, 3], [4, 1, 0]], dimension=1).eval(), + [2, 0]) + self.assertAllEqual( + tf.argmax([[1, 2, 3], [4, 1, 0]], dimension=0).eval(), + [1, 0, 0]) + + def testExpandAndSqueeze(self): + with self.test_session(): + + # TODO(aselle): sparse_split, sparse_reduce_sum, + # sparse_reduce_sum_sparse, reduce_join + a = [[1, 2, 3]] + self.assertAllEqual(tf.expand_dims(tf.squeeze(a, [0]), 0).eval(), + a) + self.assertAllEqual(tf.squeeze(tf.expand_dims(a, 1), [1]).eval(), + a) + self.assertAllEqual( + tf.expand_dims( + tf.squeeze( + [[1, 2, 3]], squeeze_dims=[0]), dim=0).eval(), + a) + self.assertAllEqual( + tf.squeeze( + tf.expand_dims( + [[1, 2, 3]], dim=1), squeeze_dims=[1]).eval(), + a) + + self.assertAllEqual( + tf.squeeze( + tf.expand_dims( + [[1, 2, 3]], dim=1), squeeze_dims=[1]).eval(), + a) + + def testArithmeticRenames(self): + with self.test_session() as s: + stuff = tf.split(1, 2, [[1, 2, 3, 4], [4, 5, 6, 7]]) + vals = s.run(stuff) + self.assertAllEqual(vals, + [[[1, 2], [4, 5]], [[3, 4], [6, 7]]]) + self.assertAllEqual( + tf.neg(tf.mul(tf.add(1, 2), tf.sub(5, 3))).eval(), + -6) + self.assertAllEqual( + s.run(tf.listdiff([1, 2, 3], [3, 3, 4]))[0], [1, 2]) + self.assertAllEqual( + tf.list_diff([1, 2, 3], [3, 3, 4])[0].eval(), [1, 2]) + a = [[1., 2., 
3.], [4., 5., 6.]] + foo = np.where(np.less(a, 2), np.negative(a), a) + self.assertAllEqual( + tf.select(tf.less(a, 2), tf.neg(a), a).eval(), + foo) + self.assertAllEqual( + tf.complex_abs(tf.constant(3 + 4.j)).eval(), + 5) + # # TODO(aselle): (tf.batch_*) + # ] + + def testVariables(self): + with self.test_session() as s: + + # make some variables + _ = [tf.Variable([1, 2, 3], dtype=tf.float32), + tf.Variable([1, 2, 3], dtype=tf.int32)] + s.run(tf.initialize_all_variables()) + _ = [v.name for v in tf.all_variables()] + _ = [v.name for v in tf.local_variables()] + + def testSummaries(self): + with self.test_session() as s: + var = tf.Variable([1, 2, 3], dtype=tf.float32) + s.run(tf.initialize_all_variables()) + x, y = np.meshgrid(np.linspace(-10, 10, 256), np.linspace(-10, 10, 256)) + image = np.sin(x**2 + y**2) / np.sqrt(x**2 + y**2) * .5 + .5 + image = image[None, :, :, None] + + # make a dummy sound + freq = 440 # A = 440Hz + sampling_frequency = 11000 + audio = np.sin(2 * np.pi * np.linspace(0, 1, sampling_frequency) * freq) + audio = audio[None, :, None] + test_dir = tempfile.mkdtemp() + # test summaries + writer = tf.train.SummaryWriter(test_dir) + summaries = [ + tf.scalar_summary("scalar_var", var[0]), + tf.scalar_summary("scalar_reduce_var", tf.reduce_sum(var)), + tf.histogram_summary("var_histogram", var), + tf.image_summary("sin_image", image), + tf.audio_summary("sin_wave", audio, sampling_frequency), + ] + run_summaries = s.run(summaries) + writer.add_summary(s.run(tf.merge_summary(inputs=run_summaries))) + # This is redundant, but we want to be able to rewrite the command + writer.add_summary(s.run(tf.merge_all_summaries())) + writer.close() + shutil.rmtree(test_dir) + + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py new file mode 100644 index 00000000000..223f8cd5f5d --- /dev/null +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -0,0 +1,550 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Upgrader for Python scripts from pre-1.0 TensorFlow to 1.0 TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import argparse +import ast +import collections +import os +import sys +import traceback + +# TODO(aselle): Add SVD, Concat +# TODO(aselle): summary merge all (can we detect this?) +# TODO(aselle): batch_matmul +# TODO(wicke): tf.nn.{softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits? + + +class APIChangeSpec(object): + """List of maps that describe what changed in the API.""" + + def __init__(self): + # Maps from a function name to a dictionary that describes how to + # map from an old argument keyword to the new argument keyword. 
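+    # For example, the `tf.argmin` entry below rewrites a call such as
+    #     tf.argmin(input, dimension=0)
+    # into
+    #     tf.argmin(input, axis=0)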
+    self.function_keyword_renames = {
+        "tf.count_nonzero": {
+            "reduction_indices": "axis"
+        },
+        "tf.reduce_all": {
+            "reduction_indices": "axis"
+        },
+        "tf.reduce_any": {
+            "reduction_indices": "axis"
+        },
+        "tf.reduce_max": {
+            "reduction_indices": "axis"
+        },
+        "tf.reduce_mean": {
+            "reduction_indices": "axis"
+        },
+        "tf.reduce_min": {
+            "reduction_indices": "axis"
+        },
+        "tf.reduce_prod": {
+            "reduction_indices": "axis"
+        },
+        "tf.reduce_sum": {
+            "reduction_indices": "axis"
+        },
+        "tf.reduce_logsumexp": {
+            "reduction_indices": "axis"
+        },
+        "tf.expand_dims": {
+            "dim": "axis"
+        },
+        "tf.argmax": {
+            "dimension": "axis"
+        },
+        "tf.argmin": {
+            "dimension": "axis"
+        },
+        "tf.reduce_join": {
+            "reduction_indices": "axis"
+        },
+        "tf.sparse_concat": {
+            "concat_dim": "axis"
+        },
+        "tf.sparse_split": {
+            "split_dim": "axis"
+        },
+        "tf.sparse_reduce_sum": {
+            "reduction_axes": "axis"
+        },
+        "tf.reverse_sequence": {
+            "seq_dim": "seq_axis",
+            "batch_dim": "batch_axis"
+        },
+        "tf.sparse_reduce_sum_sparse": {
+            "reduction_axes": "axis"
+        },
+        "tf.squeeze": {
+            "squeeze_dims": "axis"
+        },
+        "tf.split": {
+            "split_dim": "axis",
+            "num_split": "num_or_size_splits"
+        }
+    }
+
+    # Mapping from a function to the new name of the function.
+    self.function_renames = {
+        "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
+        "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
+        "tf.listdiff": "tf.setdiff1d",
+        "tf.list_diff": "tf.setdiff1d",
+        "tf.mul": "tf.multiply",
+        "tf.neg": "tf.negative",
+        "tf.sub": "tf.subtract",
+        "tf.train.SummaryWriter": "tf.summary.FileWriter",
+        "tf.scalar_summary": "tf.summary.scalar",
+        "tf.histogram_summary": "tf.summary.histogram",
+        "tf.audio_summary": "tf.summary.audio",
+        "tf.image_summary": "tf.summary.image",
+        "tf.merge_summary": "tf.summary.merge",
+        "tf.merge_all_summaries": "tf.summary.merge_all",
+        "tf.image.per_image_whitening": "tf.image.per_image_standardization",
+        "tf.all_variables": "tf.global_variables",
+        "tf.VARIABLES": "tf.GLOBAL_VARIABLES",
+        "tf.initialize_all_variables": "tf.global_variables_initializer",
+        "tf.initialize_variables": "tf.variables_initializer",
+        "tf.initialize_local_variables": "tf.local_variables_initializer",
+        "tf.batch_matrix_diag": "tf.matrix_diag",
+        "tf.batch_matrix_band_part": "tf.matrix_band_part",
+        "tf.batch_matrix_set_diag": "tf.matrix_set_diag",
+        "tf.batch_matrix_transpose": "tf.matrix_transpose",
+        "tf.batch_matrix_determinant": "tf.matrix_determinant",
+        "tf.batch_matrix_inverse": "tf.matrix_inverse",
+        "tf.batch_cholesky": "tf.cholesky",
+        "tf.batch_cholesky_solve": "tf.cholesky_solve",
+        "tf.batch_matrix_solve": "tf.matrix_solve",
+        "tf.batch_matrix_triangular_solve": "tf.matrix_triangular_solve",
+        "tf.batch_matrix_solve_ls": "tf.matrix_solve_ls",
+        "tf.batch_self_adjoint_eig": "tf.self_adjoint_eig",
+        "tf.batch_self_adjoint_eigvals": "tf.self_adjoint_eigvals",
+        "tf.batch_svd": "tf.svd",
+        "tf.batch_fft": "tf.fft",
+        "tf.batch_ifft": "tf.ifft",
+        "tf.batch_fft2d": "tf.fft2d",
+        "tf.batch_ifft2d": "tf.ifft2d",
+        "tf.batch_fft3d": "tf.fft3d",
+        "tf.batch_ifft3d": "tf.ifft3d",
+        "tf.select": "tf.where",
+        "tf.complex_abs": "tf.abs"
+    }
+
+    # Functions whose positional arguments were reordered. For safety, any
+    # positional arguments passed to them are rewritten as keyword arguments.
+    # If you have already reordered the positional arguments yourself, this
+    # could do the wrong thing.
+    self.function_reorders = {
+        "tf.split": ["axis", "num_or_size_splits", "value", "name"],
+        "tf.concat": ["concat_dim", "values", "name"]
+    }
+
+    # Specially handled functions.
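+    # Each handler is called with the FileEditRecorder and the offending
+    # ast.Call node; it can record an edit and/or an error that flags the
+    # call for manual inspection (see _reverse_handler below).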
+    self.function_handle = {"tf.reverse": self._reverse_handler}
+
+  @staticmethod
+  def _reverse_handler(file_edit_recorder, node):
+    # TODO(aselle): Could check for a literal list of bools and try to convert
+    # them to indices.
+    comment = ("ERROR: tf.reverse has had its argument semantics changed\n"
+               "significantly. The converter cannot detect this reliably, so\n"
+               "you need to inspect this usage manually.\n")
+    file_edit_recorder.add(comment,
+                           node.lineno,
+                           node.col_offset,
+                           "tf.reverse",
+                           "tf.reverse",
+                           error="tf.reverse requires manual check.")
+
+
+class FileEditTuple(collections.namedtuple(
+    "FileEditTuple", ["comment", "line", "start", "old", "new"])):
+  """Each edit that is recorded by a FileEditRecorder.
+
+  Fields:
+    comment: A description of the edit and why it was made.
+    line: The line number in the file where the edit occurs (1-indexed).
+    start: The column offset in the line where the edit begins (0-indexed).
+    old: Text string to remove (this must match what is in the file).
+    new: Text string to add in place of `old`.
+  """
+
+  __slots__ = ()
+
+
+class FileEditRecorder(object):
+  """Record changes that need to be made to the file."""
+
+  def __init__(self, filename):
+    self._filename = filename
+    # Maps a 1-indexed line number to the list of edits recorded for it.
+    self._line_to_edit = collections.defaultdict(list)
+    self._errors = []
+
+  def process(self, text):
+    """Apply the recorded edits to a list of lines of text.
+
+    Args:
+      text: A list of lines of text (assumed to contain newlines)
+    Returns:
+      A tuple of the modified text, a report describing each change, and a
+      list of error strings.
+    Raises:
+      ValueError: if substitution source location does not have expected text.
+    """
+
+    change_report = ""
+
+    # Iterate over each line that has recorded edits
+    for line, edits in self._line_to_edit.items():
+      offset = 0
+      # Sort by column so that edits are processed left to right, making the
+      # indexing adjustments cumulative for edits that change the string
+      # length
+      edits.sort(key=lambda x: x.start)
+
+      # Extract each line to a list of characters, because mutable lists
+      # are editable, unlike immutable strings.
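+      # Edits may change the line length; e.g. replacing "mul" (3 chars) with
+      # "multiply" (8 chars) shifts every later edit on the same line right
+      # by 5 columns. The `offset` variable tracks that cumulative shift.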
+      char_array = list(text[line - 1])
+
+      # Record a description of the change
+      change_report += "%s Line %d\n" % (self._filename, line)
+      change_report += "-" * 80 + "\n\n"
+      for e in edits:
+        change_report += "%s\n" % e.comment
+      change_report += "\n    Old: %s" % (text[line - 1])
+
+      # Make underline buffers that mark where in the line each edit was made
+      change_list = [" "] * len(text[line - 1])
+      change_list_new = [" "] * len(text[line - 1])
+
+      # Apply each edit in turn
+      for e in edits:
+        # Create the effective start and end by accounting for the change in
+        # length due to previous edits
+        start_eff = e.start + offset
+        end_eff = start_eff + len(e.old)
+
+        # Make sure the edit is changing what it should be changing
+        old_actual = "".join(char_array[start_eff:end_eff])
+        if old_actual != e.old:
+          raise ValueError("Expected text '%s' but got '%s'" %
+                           (e.old, old_actual))
+        # Make the edit
+        char_array[start_eff:end_eff] = list(e.new)
+
+        # Create the underline highlighting of the before and after
+        change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
+        change_list_new[start_eff:end_eff] = "~" * len(e.new)
+
+        # Keep track of the cumulative change in line length
+        offset += len(e.new) - len(e.old)
+
+      # Finish the report comment
+      change_report += "    %s\n" % "".join(change_list)
+      text[line - 1] = "".join(char_array)
+      change_report += "    New: %s" % (text[line - 1])
+      change_report += "    %s\n\n" % "".join(change_list_new)
+    return "".join(text), change_report, self._errors
+
+  def add(self, comment, line, start, old, new, error=None):
+    """Add a new change that is needed.
+
+    Args:
+      comment: A description of what was changed.
+      line: Line number (1-indexed).
+      start: Column offset (0-indexed).
+      old: The old text.
+      new: The new text.
+      error: If non-None, this edit cannot be made automatically and the
+        message is recorded as an error for manual follow-up.
+    Returns:
+      None
+    """
+
+    self._line_to_edit[line].append(
+        FileEditTuple(comment, line, start, old, new))
+    if error is not None:
+      self._errors.append("%s:%d: %s" % (self._filename, line, error))
+
+
+class TensorFlowCallVisitor(ast.NodeVisitor):
+  """AST visitor that finds TensorFlow function calls.
+
+  Updates function calls from the old API version to the new one.
+  """
+
+  def __init__(self, filename, lines):
+    self._filename = filename
+    self._file_edit = FileEditRecorder(filename)
+    self._lines = lines
+    self._api_change_spec = APIChangeSpec()
+
+  def process(self, lines):
+    return self._file_edit.process(lines)
+
+  def generic_visit(self, node):
+    ast.NodeVisitor.generic_visit(self, node)
+
+  def _rename_functions(self, node, full_name):
+    function_renames = self._api_change_spec.function_renames
+    if full_name in function_renames:
+      new_name = function_renames[full_name]
+      self._file_edit.add("Renamed function `%s` to `%s`" % (full_name,
+                                                             new_name),
+                          node.lineno, node.col_offset, full_name, new_name)
+
+  def visit_Call(self, node):  # pylint: disable=invalid-name
+    """Handle visiting a call node in the AST.
+
+    Args:
+      node: Current ast.Call node being visited.
+    """
+
+    # Find the call string (this is not perfectly accurate,
+    # but should cover the tf.x* cases we care about)
+    curr = node.func
+    items = []
+    valid = True
+    while not isinstance(curr, ast.Name):
+      if isinstance(curr, ast.Attribute):
+        items.append(curr.attr)
+      else:
+        # We cannot just return, because we need to keep walking.
+        # TODO(aselle): Would it be cleaner to use an exception here with else?
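+        # For example, a call like (foo + bar)[a].word() produces an
+        # ast.Subscript node here, which cannot be part of a dotted
+        # "tf.*" name, so the whole call is skipped.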
+        valid = False
+        break
+      curr = curr.value
+    if valid:
+      items.append(curr.id)
+
+    if valid:
+      # Conversion logic
+      full_name = ".".join(items[::-1])
+      if full_name.startswith("tf."):
+        # Call special handlers
+        function_handles = self._api_change_spec.function_handle
+        if full_name in function_handles:
+          function_handles[full_name](self._file_edit, node)
+
+        # Check for renames
+        self._rename_functions(node, full_name)
+
+        # Examine any non-keyword argument and make it into a keyword argument
+        # if reordering is required.
+        function_reorders = self._api_change_spec.function_reorders
+        if full_name in function_reorders:
+          reordered = function_reorders[full_name]
+          for idx, arg in enumerate(node.args):
+            self._file_edit.add("Added keyword `%s` to reordered function `%s`"
+                                % (reordered[idx], full_name), arg.lineno,
+                                arg.col_offset, "", reordered[idx] + "=")
+
+        # Examine each keyword argument and convert it to its renamed form
+        function_keyword_renames = (
+            self._api_change_spec.function_keyword_renames)
+        renamed_keywords = ({} if full_name not in function_keyword_renames else
+                            function_keyword_renames[full_name])
+        for keyword in node.keywords:
+          argkey = keyword.arg
+          argval = keyword.value
+          if argkey in renamed_keywords:
+            self._file_edit.add("Renamed keyword argument from `%s` to `%s`" %
+                                (argkey, renamed_keywords[argkey]),
+                                argval.lineno,
+                                argval.col_offset - len(argkey) - 1,
+                                argkey + "=", renamed_keywords[argkey] + "=")
+
+    ast.NodeVisitor.generic_visit(self, node)
+
+
+class TensorFlowCodeUpgrader(object):
+  """Class that handles upgrading a set of Python files to TensorFlow 1.0."""
+
+  def __init__(self):
+    pass
+
+  def process_file(self, in_filename, out_filename):
+    """Process the given Python file for incompatible changes.
+
+    Args:
+      in_filename: filename to parse
+      out_filename: output file to write to
+    Returns:
+      A tuple of (number of files processed, report text, list of errors).
+    """
+    in_file = open(in_filename, "r")
+    out_file = open(out_filename, "w") if out_filename else None
+
+    return self.process_opened_file(
+        in_filename, in_file, out_filename, out_file)
+
+  # Broad exceptions are required here because ast throws whatever it wants.
+  # pylint: disable=broad-except
+  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
+    """Process the given Python file for incompatible changes.
+
+    This function is split out to facilitate StringIO testing from
+    tf_upgrade_test.py.
+
+    Args:
+      in_filename: filename to parse
+      in_file: opened file (or StringIO)
+      out_filename: output file to write to
+      out_file: opened file (or StringIO)
+    Returns:
+      A tuple of (number of files processed, report text, list of errors).
+    """
+    process_errors = []
+    text = "-" * 80 + "\n"
+    text += "Processing file %s\n outputting to %s\n" % (in_filename,
+                                                         out_filename)
+    text += "-" * 80 + "\n\n"
+
+    parsed_ast = None
+    lines = in_file.readlines()
+    try:
+      parsed_ast = ast.parse("".join(lines))
+    except Exception:
+      text += "Failed to parse %s\n\n" % in_filename
+      text += traceback.format_exc()
+    if parsed_ast:
+      visitor = TensorFlowCallVisitor(in_filename, lines)
+      visitor.visit(parsed_ast)
+      out_text, new_text, process_errors = visitor.process(lines)
+      text += new_text
+      if out_file:
+        out_file.write(out_text)
+    text += "\n"
+    return 1, text, process_errors
+  # pylint: enable=broad-except
+
+  def process_tree(self, root_directory, output_root_directory):
+    """Processes upgrades on an entire tree of Python files.
+
+    Note that only Python files are converted.
+    If you have custom code in other languages, you will need to upgrade it
+    manually.
+
+    Args:
+      root_directory: Directory to walk and process.
+      output_root_directory: Directory to use as the base of the output tree.
+    Returns:
+      A tuple of the number of files processed, the report string for all
+      files, and a list of errors.
+    """
+
+    # Make sure the output directory doesn't already exist
+    if output_root_directory and os.path.exists(output_root_directory):
+      print("Output directory '%s' must not already exist." % (
+          output_root_directory))
+      sys.exit(1)
+
+    # Make sure the output directory does not overlap with root_directory
+    norm_root = os.path.split(os.path.normpath(root_directory))
+    norm_output = os.path.split(os.path.normpath(output_root_directory))
+    if norm_root == norm_output:
+      print("Output directory '%s' is the same as input directory '%s'." % (
+          output_root_directory, root_directory))
+      sys.exit(1)
+
+    # Collect the list of files to process (we do this to correctly handle
+    # the case where the user puts the output directory in some subdirectory
+    # of the input directory)
+    files_to_process = []
+    for dir_name, _, file_list in os.walk(root_directory):
+      py_files = [f for f in file_list if f.endswith(".py")]
+      for filename in py_files:
+        fullpath = os.path.join(dir_name, filename)
+        fullpath_output = os.path.join(
+            output_root_directory, os.path.relpath(fullpath, root_directory))
+        files_to_process.append((fullpath, fullpath_output))
+
+    file_count = 0
+    tree_errors = []
+    report = ""
+    report += ("=" * 80) + "\n"
+    report += "Input tree: %s\n" % root_directory
+    report += ("=" * 80) + "\n"
+
+    for input_path, output_path in files_to_process:
+      output_directory = os.path.dirname(output_path)
+      if not os.path.isdir(output_directory):
+        os.makedirs(output_directory)
+      file_count += 1
+      _, l_report, l_errors = self.process_file(input_path, output_path)
+      tree_errors += l_errors
+      report += l_report
+    return file_count, report, tree_errors
+
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(
+      formatter_class=argparse.RawDescriptionHelpFormatter,
+      description="""Convert a TensorFlow Python file to 1.0
+
+Simple usage:
+  tf_upgrade.py --infile foo.py --outfile bar.py
+  tf_upgrade.py --intree ~/code/old --outtree ~/code/new
+""")
+  parser.add_argument(
+      "--infile",
+      dest="input_file",
+      help="If converting a single file, the name of the file "
+      "to convert.")
+  parser.add_argument(
+      "--outfile",
+      dest="output_file",
+      help="If converting a single file, the output filename.")
+  parser.add_argument(
+      "--intree",
+      dest="input_tree",
+      help="If converting a whole tree of files, the directory "
+      "to read from (relative or absolute).")
+  parser.add_argument(
+      "--outtree",
+      dest="output_tree",
+      help="If converting a whole tree of files, the output "
+      "directory (relative or absolute).")
+  parser.add_argument(
+      "--reportfile",
+      dest="report_filename",
+      help=("The name of the file where the report log is "
+            "stored."
+            " (default: %(default)s)"),
+      default="report.txt")
+  args = parser.parse_args()
+
+  upgrade = TensorFlowCodeUpgrader()
+  report_text = None
+  report_filename = args.report_filename
+  files_processed = 0
+  errors = []
+  if args.input_file:
+    files_processed, report_text, errors = upgrade.process_file(
+        args.input_file, args.output_file)
+  elif args.input_tree:
+    files_processed, report_text, errors = upgrade.process_tree(
+        args.input_tree, args.output_tree)
+  else:
+    parser.print_help()
+  if report_text:
+    open(report_filename, "w").write(report_text)
+    print("TensorFlow 1.0 Upgrade Script")
+    print("-----------------------------")
+    print("Converted %d files\n" % files_processed)
+    print("Detected %d errors that require attention" % len(errors))
+    print("-" * 80)
+    print("\n".join(errors))
+    print("\nMake sure to read the detailed log %s\n" % report_filename)
diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py
new file mode 100644
index 00000000000..69a85d8bda9
--- /dev/null
+++ b/tensorflow/tools/compatibility/tf_upgrade_test.py
@@ -0,0 +1,85 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf upgrader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import six  # six.StringIO works under both Python 2 and 3.
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test as test_lib
+from tensorflow.tools.compatibility import tf_upgrade
+
+
+class TestUpgrade(test_util.TensorFlowTestCase):
+  """Test various APIs that have been changed in 1.0.
+
+  We also test whether a converted file is executable. test_file_v0_11.py
+  aims to exhaustively test that API changes are convertible and actually
+  work when run with current TensorFlow.
+  """
+
+  def _upgrade(self, old_file_text):
+    in_file = six.StringIO(old_file_text)
+    out_file = six.StringIO()
+    upgrader = tf_upgrade.TensorFlowCodeUpgrader()
+    count, report, errors = (
+        upgrader.process_opened_file("test.py", in_file,
+                                     "test_out.py", out_file))
+    return count, report, errors, out_file.getvalue()
+
+  def testParseError(self):
+    _, report, unused_errors, unused_new_text = self._upgrade(
+        "import tensorflow as tf\na + \n")
+    self.assertTrue(report.find("Failed to parse") != -1)
+
+  def testReport(self):
+    text = "tf.mul(a, b)\n"
+    _, report, unused_errors, unused_new_text = self._upgrade(text)
+    # This is not a complete test, but it is a sanity check that the report
+    # contains the expected information.
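+    # str.find returns -1 when the substring is absent, hence the explicit
+    # comparison below.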
+    self.assertTrue(
+        report.find("Renamed function `tf.mul` to `tf.multiply`") != -1)
+
+  def testRename(self):
+    text = "tf.mul(a, tf.sub(b, c))\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n")
+
+  def testReorder(self):
+    text = "tf.concat(a, b)\ntf.split(a, b, c)\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.concat(concat_dim=a, values=b)\n"
+                     "tf.split(axis=a, num_or_size_splits=b, value=c)\n")
+
+  def testKeyword(self):
+    text = "tf.reduce_any(a, reduction_indices=[1, 2])\n"
+    _, unused_report, unused_errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, "tf.reduce_any(a, axis=[1, 2])\n")
+
+  def testComplexExpression(self):
+    text = "(foo + bar)[a].word()"
+    _ = self._upgrade(text)
+
+  def testReverse(self):
+    text = "tf.reverse(a, b)\n"
+    _, unused_report, errors, new_text = self._upgrade(text)
+    self.assertEqual(new_text, text)
+    self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."])
+
+  # TODO(aselle): Explicitly not testing the command line interface and
+  # process_tree for now, since this is a one-off utility.
+
+if __name__ == "__main__":
+  test_lib.main()