diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 5968f2684bb..93d2838f6a2 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -33,8 +33,7 @@ from tensorflow.python.util.deprecation import deprecated
 
 
 class Variable(object):
-  """See the @{$variables$Variables How To} for a high
-  level overview.
+  """See the @{$variables$Variables How To} for a high level overview.
 
   A variable maintains state in the graph across calls to `run()`. You add a
   variable to the graph by constructing an instance of the class `Variable`.
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
index 2ca71e7ab87..fb40cf0833f 100644
--- a/tensorflow/tools/compatibility/BUILD
+++ b/tensorflow/tools/compatibility/BUILD
@@ -10,7 +10,10 @@ load(
 
 py_binary(
     name = "tf_upgrade",
-    srcs = ["tf_upgrade.py"],
+    srcs = [
+        "ast_edits.py",
+        "tf_upgrade.py",
+    ],
     srcs_version = "PY2AND3",
 )
 
diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py
new file mode 100644
index 00000000000..e7e4c916921
--- /dev/null
+++ b/tensorflow/tools/compatibility/ast_edits.py
@@ -0,0 +1,497 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Upgrader for Python scripts according to an API change specification."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import collections
+import os
+import re
+import shutil
+import sys
+import tempfile
+import traceback
+
+
+class APIChangeSpec(object):
+  """This class defines the transformations that need to happen.
+
+  A concrete spec must provide the following fields:
+
+  * `function_keyword_renames`: maps function names to a map of old -> new
+    argument names
+  * `function_renames`: maps function names to new function names
+  * `change_to_function`: a set of function names that have changed (for
+    notifications)
+  * `function_reorders`: maps functions whose argument order has changed to the
+    list of arguments in the new order
+  * `function_handle`: maps function names to custom handlers for the function
+
+  For an example, see `TFAPIChangeSpec`.
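+
+  A minimal spec might look like this (a sketch; the entries below are
+  illustrative, not a real upgrade specification):
+
+    class MyAPIChangeSpec(APIChangeSpec):
+      def __init__(self):
+        self.function_keyword_renames = {"tf.foo": {"old_arg": "new_arg"}}
+        self.function_renames = {"tf.mul": "tf.multiply"}
+        self.change_to_function = set()
+        self.function_reorders = {"tf.foo": ["new_arg", "other_arg"]}
+        self.function_handle = {}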
+  """
+
+
+class _FileEditTuple(collections.namedtuple(
+    "_FileEditTuple", ["comment", "line", "start", "old", "new"])):
+  """Each edit that is recorded by a _FileEditRecorder.
+
+  Fields:
+    comment: A description of the edit and why it was made.
+    line: The line number in the file where the edit occurs (1-indexed).
+    start: The column in the line where the edit begins (0-indexed).
+    old: Text string to remove (this must match what is in the file).
+    new: Text string to add in place of `old`.
+  """
+
+  __slots__ = ()
+
+
+class _FileEditRecorder(object):
+  """Record changes that need to be done to the file."""
+
+  def __init__(self, filename):
+    # all edits are lists of chars
+    self._filename = filename
+
+    self._line_to_edit = collections.defaultdict(list)
+    self._errors = []
+
+  def process(self, text):
+    """Process a list of strings, each corresponding to the recorded changes.
+
+    Args:
+      text: A list of lines of text (assumed to contain newlines)
+    Returns:
+      A tuple of the modified text and a textual description of what is done.
+    Raises:
+      ValueError: if substitution source location does not have expected text.
+    """
+
+    change_report = ""
+
+    # Iterate over each line that has recorded edits
+    for line, edits in self._line_to_edit.items():
+      offset = 0
+      # Sort by column so that edits are processed left to right, which keeps
+      # the cumulative index adjustments correct for edits that change the
+      # string length.
+      edits.sort(key=lambda x: x.start)
+
+      # Extract each line to a list of characters, because mutable lists
+      # are editable, unlike immutable strings.
+      char_array = list(text[line - 1])
+
+      # Record a description of the change
+      change_report += "%r Line %d\n" % (self._filename, line)
+      change_report += "-" * 80 + "\n\n"
+      for e in edits:
+        change_report += "%s\n" % e.comment
+      change_report += "\n    Old: %s" % (text[line - 1])
+
+      # Make marker buffers for underlining where in the line the edits were
+      change_list = [" "] * len(text[line - 1])
+      change_list_new = [" "] * len(text[line - 1])
+
+      # Iterate for each edit
+      for e in edits:
+        # Create effective start, end by accounting for change in length due
+        # to previous edits
+        start_eff = e.start + offset
+        end_eff = start_eff + len(e.old)
+
+        # Make sure the edit is changing what it should be changing
+        old_actual = "".join(char_array[start_eff:end_eff])
+        if old_actual != e.old:
+          raise ValueError("Expected text %r but got %r" %
+                           ("".join(e.old), "".join(old_actual)))
+        # Make the edit
+        char_array[start_eff:end_eff] = list(e.new)
+
+        # Create the underline highlighting of the before and after
+        change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
+        change_list_new[start_eff:end_eff] = "~" * len(e.new)
+
+        # Keep track of how to generate effective ranges
+        offset += len(e.new) - len(e.old)
+
+      # Finish the report comment
+      change_report += "         %s\n" % "".join(change_list)
+      text[line - 1] = "".join(char_array)
+      change_report += "    New: %s" % (text[line - 1])
+      change_report += "         %s\n\n" % "".join(change_list_new)
+    return "".join(text), change_report, self._errors
+
+  def add(self, comment, line, start, old, new, error=None):
+    """Add a new change that is needed.
+
+    Args:
+      comment: A description of what was changed
+      line: Line number (1 indexed)
+      start: Column offset (0 indexed)
+      old: old text
+      new: new text
+      error: this "edit" is something that cannot be fixed automatically
+    Returns:
+      None
+    """
+
+    self._line_to_edit[line].append(
+        _FileEditTuple(comment, line, start, old, new))
+    if error:
+      self._errors.append("%s:%d: %s" % (self._filename, line, error))
+
+
+class _ASTCallVisitor(ast.NodeVisitor):
+  """AST Visitor that processes function calls.
+
+  Updates function calls from old API version to new API version using a given
+  change spec.
+  """
+
+  def __init__(self, filename, lines, api_change_spec):
+    self._filename = filename
+    self._file_edit = _FileEditRecorder(filename)
+    self._lines = lines
+    self._api_change_spec = api_change_spec
+
+  def process(self, lines):
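+    """Apply the recorded edits to `lines` via the underlying recorder."""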
+    return self._file_edit.process(lines)
+
+  def generic_visit(self, node):
+    ast.NodeVisitor.generic_visit(self, node)
+
+  def _rename_functions(self, node, full_name):
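+    """Record a rename edit if the change spec renames `full_name`."""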
+    function_renames = self._api_change_spec.function_renames
+    try:
+      new_name = function_renames[full_name]
+      self._file_edit.add("Renamed function %r to %r" % (full_name,
+                                                         new_name),
+                          node.lineno, node.col_offset, full_name, new_name)
+    except KeyError:
+      pass
+
+  def _get_attribute_full_path(self, node):
+    """Traverse an attribute to generate a full name e.g. tf.foo.bar.
+
+    Args:
+      node: A Node of type Attribute.
+
+    Returns:
+      A '.'-delimited full name, or None if the tree was not a simple
+      attribute chain, e.g. `(foo()+b).bar` returns None, while `a.b.c`
+      returns "a.b.c".
+    """
+    curr = node
+    items = []
+    while not isinstance(curr, ast.Name):
+      if not isinstance(curr, ast.Attribute):
+        return None
+      items.append(curr.attr)
+      curr = curr.value
+    items.append(curr.id)
+    return ".".join(reversed(items))
+
+  def _find_true_position(self, node):
+    """Return correct line number and column offset for a given node.
+
+    This is necessary mainly because ast.ListComp reports the location of the
+    first token after the opening '[' rather than the '[' itself.
+
+    Args:
+      node: Node for which we wish to know the lineno and col_offset
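+
+    Returns:
+      A (line, col) tuple, or (None, None) if a reliable position could not
+      be determined.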
+    """
+    find_open = re.compile(r"^\s*(\[).*$")
+    find_string_chars = re.compile("['\"]")
+
+    if isinstance(node, ast.ListComp):
+      # Strangely, ast.ListComp returns the col_offset of the first token
+      # after the '[' token which appears to be a bug. Workaround by
+      # explicitly finding the real start of the list comprehension.
+      line = node.lineno
+      col = node.col_offset
+      # loop over lines
+      while True:
+        # Reverse the preceding text so that a regular expression can search
+        # backwards through the whitespace for the opening '['.
+        text = self._lines[line - 1]
+        reversed_preceding_text = text[:col][::-1]
+        # First check whether a '[' can be found with only whitespace between
+        # it and col.
+        m = find_open.match(reversed_preceding_text)
+        if m:
+          new_col_offset = col - m.start(1) - 1
+          return line, new_col_offset
+        else:
+          if (reversed_preceding_text=="" or
+             reversed_preceding_text.isspace()):
+            line = line - 1
+            prev_line = self._lines[line - 1]
+            # TODO(aselle):
+            # this is poor comment detection, but it is good enough for
+            # cases where the comment does not contain string literal starting/
+            # ending characters. If ast gave us start and end locations of the
+            # ast nodes rather than just start, we could use string literal
+            # node ranges to filter out spurious #'s that appear in string
+            # literals.
+            comment_start = prev_line.find("#")
+            if comment_start == -1:
+              col = len(prev_line) - 1
+            elif find_string_chars.search(prev_line[comment_start:]) is None:
+              col = comment_start
+            else:
+              return None, None
+          else:
+            return None, None
+    # Most other nodes return proper locations (ast.With notably does not),
+    # but a `with` statement cannot appear inside a call argument, so that
+    # case does not matter here.
+    return node.lineno, node.col_offset
+
+  def visit_Call(self, node):  # pylint: disable=invalid-name
+    """Handle visiting a call node in the AST.
+
+    Args:
+      node: Current Node
+    """
+
+    # Find a simple attribute name path e.g. "tf.foo.bar"
+    full_name = self._get_attribute_full_path(node.func)
+
+    # Make sure the func is marked as being part of a call
+    node.func.is_function_for_call = True
+
+    if full_name:
+      # Call special handlers
+      function_handles = self._api_change_spec.function_handle
+      if full_name in function_handles:
+        function_handles[full_name](self._file_edit, node)
+
+      # Examine any non-keyword argument and make it into a keyword argument
+      # if reordering is required.
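+      # For example, with reorder list ["a", "b"], `tf.foo(1, 2)` becomes
+      # `tf.foo(a=1, b=2)` (a keyword is inserted before each positional
+      # value, after applying any keyword renames).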
+      function_reorders = self._api_change_spec.function_reorders
+      function_keyword_renames = (
+          self._api_change_spec.function_keyword_renames)
+
+      if full_name in function_reorders:
+        reordered = function_reorders[full_name]
+        for idx, arg in enumerate(node.args):
+          lineno, col_offset = self._find_true_position(arg)
+          if lineno is None or col_offset is None:
+            self._file_edit.add(
+                "Failed to add keyword %r to reordered function %r"
+                % (reordered[idx], full_name), arg.lineno, arg.col_offset,
+                "", "",
+                error="A necessary keyword argument failed to be inserted.")
+          else:
+            keyword_arg = reordered[idx]
+            if (full_name in function_keyword_renames and
+                keyword_arg in function_keyword_renames[full_name]):
+              keyword_arg = function_keyword_renames[full_name][keyword_arg]
+            self._file_edit.add("Added keyword %r to reordered function %r"
+                                % (reordered[idx], full_name), lineno,
+                                col_offset, "", keyword_arg + "=")
+
+      # Examine each keyword argument and convert it to the final renamed form
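+      # For example, `tf.foo(old_arg=1)` with the rename
+      # {"old_arg": "new_arg"} becomes `tf.foo(new_arg=1)`.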
+      renamed_keywords = function_keyword_renames.get(full_name, {})
+      for keyword in node.keywords:
+        argkey = keyword.arg
+        argval = keyword.value
+
+        if argkey in renamed_keywords:
+          argval_lineno, argval_col_offset = self._find_true_position(argval)
+          if argval_lineno is not None and argval_col_offset is not None:
+            # TODO(aselle): We should scan backward to find the start of the
+            # keyword key. Unfortunately ast does not give you the location of
+            # keyword keys, so we are forced to infer it from the keyword arg
+            # value.
+            key_start = argval_col_offset - len(argkey) - 1
+            key_end = key_start + len(argkey) + 1
+            if (self._lines[argval_lineno - 1][key_start:key_end] ==
+                argkey + "="):
+              self._file_edit.add("Renamed keyword argument from %r to %r" %
+                                  (argkey, renamed_keywords[argkey]),
+                                  argval_lineno,
+                                  argval_col_offset - len(argkey) - 1,
+                                  argkey + "=", renamed_keywords[argkey] + "=")
+              continue
+          self._file_edit.add(
+              "Failed to rename keyword argument from %r to %r" %
+              (argkey, renamed_keywords[argkey]),
+              argval.lineno,
+              argval.col_offset - len(argkey) - 1,
+              "", "",
+              error="Failed to find keyword lexographically. Fix manually.")
+
+    ast.NodeVisitor.generic_visit(self, node)
+
+  def visit_Attribute(self, node):  # pylint: disable=invalid-name
+    """Handle bare Attributes i.e. [tf.foo, tf.bar].
+
+    Args:
+      node: Node that is of type ast.Attribute
+    """
+    full_name = self._get_attribute_full_path(node)
+    if full_name:
+      self._rename_functions(node, full_name)
+    if full_name in self._api_change_spec.change_to_function:
+      if not hasattr(node, "is_function_for_call"):
+        new_text = full_name + "()"
+        self._file_edit.add("Changed %r to %r"%(full_name, new_text),
+                            node.lineno, node.col_offset, full_name, new_text)
+
+    ast.NodeVisitor.generic_visit(self, node)
+
+
+class ASTCodeUpgrader(object):
+  """Handles upgrading a set of Python files using a given API change spec."""
+
+  def __init__(self, api_change_spec):
+    if not isinstance(api_change_spec, APIChangeSpec):
+      raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
+                      type(api_change_spec))
+    self._api_change_spec = api_change_spec
+
+  def process_file(self, in_filename, out_filename):
+    """Process the given python file for incompatible changes.
+
+    Args:
+      in_filename: filename to parse
+      out_filename: output file to write to
+    Returns:
+      A tuple of the number of files processed, a log of actions, and a list
+      of errors.
+    """
+
+    # Write to a temporary file, in case we are modifying the file in place.
+    with open(in_filename, "r") as in_file, \
+        tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
+      ret = self.process_opened_file(
+          in_filename, in_file, out_filename, temp_file)
+
+    shutil.move(temp_file.name, out_filename)
+    return ret
+
+  # Broad exceptions are required here because ast throws whatever it wants.
+  # pylint: disable=broad-except
+  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
+    """Process the given python file for incompatible changes.
+
+    This function is split out to facilitate StringIO testing from
+    tf_upgrade_test.py.
+
+    Args:
+      in_filename: filename to parse
+      in_file: opened file (or StringIO)
+      out_filename: output file to write to
+      out_file: opened file (or StringIO)
+    Returns:
+      A tuple of the number of files processed, a log of actions, and a list
+      of errors.
+    """
+    process_errors = []
+    text = "-" * 80 + "\n"
+    text += "Processing file %r\n outputting to %r\n" % (in_filename,
+                                                         out_filename)
+    text += "-" * 80 + "\n\n"
+
+    parsed_ast = None
+    lines = in_file.readlines()
+    try:
+      parsed_ast = ast.parse("".join(lines))
+    except Exception:
+      text += "Failed to parse %r\n\n" % in_filename
+      text += traceback.format_exc()
+    if parsed_ast:
+      visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
+      visitor.visit(parsed_ast)
+      out_text, new_text, process_errors = visitor.process(lines)
+      text += new_text
+      if out_file:
+        out_file.write(out_text)
+    text += "\n"
+    return 1, text, process_errors
+  # pylint: enable=broad-except
+
+  def process_tree(self, root_directory, output_root_directory,
+                   copy_other_files):
+    """Processes upgrades on an entire tree of python files in place.
+
+    Note that only Python files are processed. If you have custom code in
+    other languages, you will need to upgrade it manually.
+
+    Args:
+      root_directory: Directory to walk and process.
+      output_root_directory: Directory to use as base.
+      copy_other_files: Copy files that are not touched by this converter.
+
+    Returns:
+      A tuple of the number of files processed, the report string for all
+      files, and a list of errors.
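+
+    For example (a sketch):
+
+      count, report, errors = upgrader.process_tree("in_dir", "out_dir", True)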
+    """
+
+    # make sure output directory doesn't exist
+    if output_root_directory and os.path.exists(output_root_directory):
+      print("Output directory %r must not already exist." % (
+          output_root_directory))
+      sys.exit(1)
+
+    # make sure output directory does not overlap with root_directory
+    norm_root = os.path.split(os.path.normpath(root_directory))
+    norm_output = os.path.split(os.path.normpath(output_root_directory))
+    if norm_root == norm_output:
+      print("Output directory %r same as input directory %r" % (
+          root_directory, output_root_directory))
+      sys.exit(1)
+
+    # Collect the list of files to process up front (this correctly handles
+    # the case where the user puts the output directory inside the input dir)
+    files_to_process = []
+    files_to_copy = []
+    for dir_name, _, file_list in os.walk(root_directory):
+      py_files = [f for f in file_list if f.endswith(".py")]
+      copy_files = [f for f in file_list if not f.endswith(".py")]
+      for filename in py_files:
+        fullpath = os.path.join(dir_name, filename)
+        fullpath_output = os.path.join(
+            output_root_directory, os.path.relpath(fullpath, root_directory))
+        files_to_process.append((fullpath, fullpath_output))
+      if copy_other_files:
+        for filename in copy_files:
+          fullpath = os.path.join(dir_name, filename)
+          fullpath_output = os.path.join(
+              output_root_directory, os.path.relpath(fullpath, root_directory))
+          files_to_copy.append((fullpath, fullpath_output))
+
+    file_count = 0
+    tree_errors = []
+    report = ""
+    report += ("=" * 80) + "\n"
+    report += "Input tree: %r\n" % root_directory
+    report += ("=" * 80) + "\n"
+
+    for input_path, output_path in files_to_process:
+      output_directory = os.path.dirname(output_path)
+      if not os.path.isdir(output_directory):
+        os.makedirs(output_directory)
+      file_count += 1
+      _, l_report, l_errors = self.process_file(input_path, output_path)
+      tree_errors += l_errors
+      report += l_report
+    for input_path, output_path in files_to_copy:
+      output_directory = os.path.dirname(output_path)
+      if not os.path.isdir(output_directory):
+        os.makedirs(output_directory)
+      shutil.copy(input_path, output_path)
+    return file_count, report, tree_errors
diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py
index 9a4a8ff71d9..72fe4a48cdd 100644
--- a/tensorflow/tools/compatibility/tf_upgrade.py
+++ b/tensorflow/tools/compatibility/tf_upgrade.py
@@ -17,17 +17,13 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+
 import argparse
-import ast
-import collections
-import os
-import shutil
-import sys
-import tempfile
-import traceback
+
+from tensorflow.tools.compatibility import ast_edits
 
 
-class APIChangeSpec(object):
+class TFAPIChangeSpec(ast_edits.APIChangeSpec):
   """List of maps that describe what changed in the API."""
 
   def __init__(self):
@@ -179,7 +175,9 @@ class APIChangeSpec(object):
     }
 
     # Specially handled functions.
-    self.function_handle = {"tf.reverse": self._reverse_handler}
+    self.function_handle = {
+        "tf.reverse": self._reverse_handler
+    }
 
   @staticmethod
   def _reverse_handler(file_edit_recorder, node):
@@ -196,450 +194,6 @@ class APIChangeSpec(object):
                            error="tf.reverse requires manual check.")
 
 
-class FileEditTuple(collections.namedtuple(
-    "FileEditTuple", ["comment", "line", "start", "old", "new"])):
-  """Each edit that is recorded by a FileEditRecorder.
-
-  Fields:
-    comment: A description of the edit and why it was made.
-    line: The line number in the file where the edit occurs (1-indexed).
-    start: The line number in the file where the edit occurs (0-indexed).
-    old: text string to remove (this must match what was in file).
-    new: text string to add in place of `old`.
-  """
-
-  __slots__ = ()
-
-
-class FileEditRecorder(object):
-  """Record changes that need to be done to the file."""
-
-  def __init__(self, filename):
-    # all edits are lists of chars
-    self._filename = filename
-
-    self._line_to_edit = collections.defaultdict(list)
-    self._errors = []
-
-  def process(self, text):
-    """Process a list of strings, each corresponding to the recorded changes.
-
-    Args:
-      text: A list of lines of text (assumed to contain newlines)
-    Returns:
-      A tuple of the modified text and a textual description of what is done.
-    Raises:
-      ValueError: if substitution source location does not have expected text.
-    """
-
-    change_report = ""
-
-    # Iterate of each line
-    for line, edits in self._line_to_edit.items():
-      offset = 0
-      # sort by column so that edits are processed in order in order to make
-      # indexing adjustments cumulative for changes that change the string
-      # length
-      edits.sort(key=lambda x: x.start)
-
-      # Extract each line to a list of characters, because mutable lists
-      # are editable, unlike immutable strings.
-      char_array = list(text[line - 1])
-
-      # Record a description of the change
-      change_report += "%r Line %d\n" % (self._filename, line)
-      change_report += "-" * 80 + "\n\n"
-      for e in edits:
-        change_report += "%s\n" % e.comment
-      change_report += "\n    Old: %s" % (text[line - 1])
-
-      # Make underscore buffers for underlining where in the line the edit was
-      change_list = [" "] * len(text[line - 1])
-      change_list_new = [" "] * len(text[line - 1])
-
-      # Iterate for each edit
-      for e in edits:
-        # Create effective start, end by accounting for change in length due
-        # to previous edits
-        start_eff = e.start + offset
-        end_eff = start_eff + len(e.old)
-
-        # Make sure the edit is changing what it should be changing
-        old_actual = "".join(char_array[start_eff:end_eff])
-        if old_actual != e.old:
-          raise ValueError("Expected text %r but got %r" %
-                           ("".join(e.old), "".join(old_actual)))
-        # Make the edit
-        char_array[start_eff:end_eff] = list(e.new)
-
-        # Create the underline highlighting of the before and after
-        change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
-        change_list_new[start_eff:end_eff] = "~" * len(e.new)
-
-        # Keep track of how to generate effective ranges
-        offset += len(e.new) - len(e.old)
-
-      # Finish the report comment
-      change_report += "         %s\n" % "".join(change_list)
-      text[line - 1] = "".join(char_array)
-      change_report += "    New: %s" % (text[line - 1])
-      change_report += "         %s\n\n" % "".join(change_list_new)
-    return "".join(text), change_report, self._errors
-
-  def add(self, comment, line, start, old, new, error=None):
-    """Add a new change that is needed.
-
-    Args:
-      comment: A description of what was changed
-      line: Line number (1 indexed)
-      start: Column offset (0 indexed)
-      old: old text
-      new: new text
-      error: this "edit" is something that cannot be fixed automatically
-    Returns:
-      None
-    """
-
-    self._line_to_edit[line].append(
-        FileEditTuple(comment, line, start, old, new))
-    if error:
-      self._errors.append("%s:%d: %s" % (self._filename, line, error))
-
-
-class TensorFlowCallVisitor(ast.NodeVisitor):
-  """AST Visitor that finds TensorFlow Function calls.
-
-  Updates function calls from old API version to new API version.
-  """
-
-  def __init__(self, filename, lines):
-    self._filename = filename
-    self._file_edit = FileEditRecorder(filename)
-    self._lines = lines
-    self._api_change_spec = APIChangeSpec()
-
-  def process(self, lines):
-    return self._file_edit.process(lines)
-
-  def generic_visit(self, node):
-    ast.NodeVisitor.generic_visit(self, node)
-
-  def _rename_functions(self, node, full_name):
-    function_renames = self._api_change_spec.function_renames
-    try:
-      new_name = function_renames[full_name]
-      self._file_edit.add("Renamed function %r to %r" % (full_name,
-                                                         new_name),
-                          node.lineno, node.col_offset, full_name, new_name)
-    except KeyError:
-      pass
-
-  def _get_attribute_full_path(self, node):
-    """Traverse an attribute to generate a full name e.g. tf.foo.bar.
-
-    Args:
-      node: A Node of type Attribute.
-
-    Returns:
-      a '.'-delimited full-name or None if the tree was not a simple form.
-      i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
-    """
-    curr = node
-    items = []
-    while not isinstance(curr, ast.Name):
-      if not isinstance(curr, ast.Attribute):
-        return None
-      items.append(curr.attr)
-      curr = curr.value
-    items.append(curr.id)
-    return ".".join(reversed(items))
-
-  def _find_true_position(self, node):
-    """Return correct line number and column offset for a given node.
-
-    This is necessary mainly because ListComp's location reporting reports
-    the next token after the list comprehension list opening.
-
-    Args:
-      node: Node for which we wish to know the lineno and col_offset
-    """
-    import re
-    find_open = re.compile("^\s*(\\[).*$")
-    find_string_chars = re.compile("['\"]")
-
-    if isinstance(node, ast.ListComp):
-      # Strangely, ast.ListComp returns the col_offset of the first token
-      # after the '[' token which appears to be a bug. Workaround by
-      # explicitly finding the real start of the list comprehension.
-      line = node.lineno
-      col = node.col_offset
-      # loop over lines
-      while 1:
-        # Reverse the text to and regular expression search for whitespace
-        text = self._lines[line-1]
-        reversed_preceding_text = text[:col][::-1]
-        # First find if a [ can be found with only whitespace between it and
-        # col.
-        m = find_open.match(reversed_preceding_text)
-        if m:
-          new_col_offset = col - m.start(1) - 1
-          return line, new_col_offset
-        else:
-          if (reversed_preceding_text=="" or
-             reversed_preceding_text.isspace()):
-            line = line - 1
-            prev_line = self._lines[line - 1]
-            # TODO(aselle):
-            # this is poor comment detection, but it is good enough for
-            # cases where the comment does not contain string literal starting/
-            # ending characters. If ast gave us start and end locations of the
-            # ast nodes rather than just start, we could use string literal
-            # node ranges to filter out spurious #'s that appear in string
-            # literals.
-            comment_start = prev_line.find("#")
-            if comment_start ==  -1:
-              col = len(prev_line) -1
-            elif find_string_chars.search(prev_line[comment_start:]) is None:
-              col = comment_start
-            else:
-              return None, None
-          else:
-            return None, None
-    # Most other nodes return proper locations (with notably does not), but
-    # it is not possible to use that in an argument.
-    return node.lineno, node.col_offset
-
-
-  def visit_Call(self, node):  # pylint: disable=invalid-name
-    """Handle visiting a call node in the AST.
-
-    Args:
-      node: Current Node
-    """
-
-
-    # Find a simple attribute name path e.g. "tf.foo.bar"
-    full_name = self._get_attribute_full_path(node.func)
-
-    # Make sure the func is marked as being part of a call
-    node.func.is_function_for_call = True
-
-    if full_name and full_name.startswith("tf."):
-      # Call special handlers
-      function_handles = self._api_change_spec.function_handle
-      if full_name in function_handles:
-        function_handles[full_name](self._file_edit, node)
-
-      # Examine any non-keyword argument and make it into a keyword argument
-      # if reordering required.
-      function_reorders = self._api_change_spec.function_reorders
-      function_keyword_renames = (
-          self._api_change_spec.function_keyword_renames)
-
-      if full_name in function_reorders:
-        reordered = function_reorders[full_name]
-        for idx, arg in enumerate(node.args):
-          lineno, col_offset = self._find_true_position(arg)
-          if lineno is None or col_offset is None:
-            self._file_edit.add(
-                "Failed to add keyword %r to reordered function %r"
-                % (reordered[idx], full_name), arg.lineno, arg.col_offset,
-                "", "",
-                error="A necessary keyword argument failed to be inserted.")
-          else:
-            keyword_arg = reordered[idx]
-            if (full_name in function_keyword_renames and
-                keyword_arg in function_keyword_renames[full_name]):
-              keyword_arg = function_keyword_renames[full_name][keyword_arg]
-            self._file_edit.add("Added keyword %r to reordered function %r"
-                                % (reordered[idx], full_name), lineno,
-                                col_offset, "", keyword_arg + "=")
-
-      # Examine each keyword argument and convert it to the final renamed form
-      renamed_keywords = ({} if full_name not in function_keyword_renames else
-                          function_keyword_renames[full_name])
-      for keyword in node.keywords:
-        argkey = keyword.arg
-        argval = keyword.value
-
-        if argkey in renamed_keywords:
-          argval_lineno, argval_col_offset = self._find_true_position(argval)
-          if (argval_lineno is not None and argval_col_offset is not None):
-            # TODO(aselle): We should scan backward to find the start of the
-            # keyword key. Unfortunately ast does not give you the location of
-            # keyword keys, so we are forced to infer it from the keyword arg
-            # value.
-            key_start = argval_col_offset - len(argkey) - 1
-            key_end = key_start + len(argkey) + 1
-            if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
-              self._file_edit.add("Renamed keyword argument from %r to %r" %
-                              (argkey, renamed_keywords[argkey]),
-                              argval_lineno,
-                              argval_col_offset - len(argkey) - 1,
-                              argkey + "=", renamed_keywords[argkey] + "=")
-              continue
-          self._file_edit.add(
-              "Failed to rename keyword argument from %r to %r" %
-              (argkey, renamed_keywords[argkey]),
-              argval.lineno,
-              argval.col_offset - len(argkey) - 1,
-              "", "",
-              error="Failed to find keyword lexographically. Fix manually.")
-
-    ast.NodeVisitor.generic_visit(self, node)
-
-  def visit_Attribute(self, node):  # pylint: disable=invalid-name
-    """Handle bare Attributes i.e. [tf.foo, tf.bar].
-
-    Args:
-      node: Node that is of type ast.Attribute
-    """
-    full_name = self._get_attribute_full_path(node)
-    if full_name and full_name.startswith("tf."):
-      self._rename_functions(node, full_name)
-    if full_name in self._api_change_spec.change_to_function:
-      if not hasattr(node, "is_function_for_call"):
-        new_text = full_name + "()"
-        self._file_edit.add("Changed %r to %r"%(full_name, new_text),
-                            node.lineno, node.col_offset, full_name, new_text)
-
-    ast.NodeVisitor.generic_visit(self, node)
-
-
-class TensorFlowCodeUpgrader(object):
-  """Class that handles upgrading a set of Python files to TensorFlow 1.0."""
-
-  def __init__(self):
-    pass
-
-  def process_file(self, in_filename, out_filename):
-    """Process the given python file for incompatible changes.
-
-    Args:
-      in_filename: filename to parse
-      out_filename: output file to write to
-    Returns:
-      A tuple representing number of files processed, log of actions, errors
-    """
-
-    # Write to a temporary file, just in case we are doing an implace modify.
-    with open(in_filename, "r") as in_file, \
-        tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
-      ret = self.process_opened_file(
-          in_filename, in_file, out_filename, temp_file)
-
-    shutil.move(temp_file.name, out_filename)
-    return ret
-
-  # Broad exceptions are required here because ast throws whatever it wants.
-  # pylint: disable=broad-except
-  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
-    """Process the given python file for incompatible changes.
-
-    This function is split out to facilitate StringIO testing from
-    tf_upgrade_test.py.
-
-    Args:
-      in_filename: filename to parse
-      in_file: opened file (or StringIO)
-      out_filename: output file to write to
-      out_file: opened file (or StringIO)
-    Returns:
-      A tuple representing number of files processed, log of actions, errors
-    """
-    process_errors = []
-    text = "-" * 80 + "\n"
-    text += "Processing file %r\n outputting to %r\n" % (in_filename,
-                                                         out_filename)
-    text += "-" * 80 + "\n\n"
-
-    parsed_ast = None
-    lines = in_file.readlines()
-    try:
-      parsed_ast = ast.parse("".join(lines))
-    except Exception:
-      text += "Failed to parse %r\n\n" % in_filename
-      text += traceback.format_exc()
-    if parsed_ast:
-      visitor = TensorFlowCallVisitor(in_filename, lines)
-      visitor.visit(parsed_ast)
-      out_text, new_text, process_errors = visitor.process(lines)
-      text += new_text
-      if out_file:
-        out_file.write(out_text)
-    text += "\n"
-    return 1, text, process_errors
-  # pylint: enable=broad-except
-
-  def process_tree(self, root_directory, output_root_directory, copy_other_files):
-    """Processes upgrades on an entire tree of python files in place.
-
-    Note that only Python files. If you have custom code in other languages,
-    you will need to manually upgrade those.
-
-    Args:
-      root_directory: Directory to walk and process.
-      output_root_directory: Directory to use as base
-    Returns:
-      A tuple of files processed, the report string ofr all files, and errors
-    """
-
-    # make sure output directory doesn't exist
-    if output_root_directory and os.path.exists(output_root_directory):
-      print("Output directory %r must not already exist." % (
-          output_root_directory))
-      sys.exit(1)
-
-    # make sure output directory does not overlap with root_directory
-    norm_root = os.path.split(os.path.normpath(root_directory))
-    norm_output = os.path.split(os.path.normpath(output_root_directory))
-    if norm_root == norm_output:
-      print("Output directory %r same as input directory %r" % (
-          root_directory, output_root_directory))
-      sys.exit(1)
-
-    # Collect list of files to process (we do this to correctly handle if the
-    # user puts the output directory in some sub directory of the input dir)
-    files_to_process = []
-    files_to_copy = []
-    for dir_name, _, file_list in os.walk(root_directory):
-      py_files = [f for f in file_list if f.endswith(".py")]
-      copy_files = [f for f in file_list if not f.endswith(".py")]
-      for filename in py_files:
-        fullpath = os.path.join(dir_name, filename)
-        fullpath_output = os.path.join(
-            output_root_directory, os.path.relpath(fullpath, root_directory))
-        files_to_process.append((fullpath, fullpath_output))
-      if copy_other_files:
-        for filename in copy_files:
-          fullpath = os.path.join(dir_name, filename)
-          fullpath_output = os.path.join(
-              output_root_directory, os.path.relpath(fullpath, root_directory))
-          files_to_copy.append((fullpath, fullpath_output))
-
-    file_count = 0
-    tree_errors = []
-    report = ""
-    report += ("=" * 80) + "\n"
-    report += "Input tree: %r\n" % root_directory
-    report += ("=" * 80) + "\n"
-
-    for input_path, output_path in files_to_process:
-      output_directory = os.path.dirname(output_path)
-      if not os.path.isdir(output_directory):
-        os.makedirs(output_directory)
-      file_count += 1
-      _, l_report, l_errors = self.process_file(input_path, output_path)
-      tree_errors += l_errors
-      report += l_report
-    for input_path, output_path in files_to_copy:
-      output_directory = os.path.dirname(output_path)
-      if not os.path.isdir(output_directory):
-        os.makedirs(output_directory)
-      shutil.copy(input_path, output_path)
-    return file_count, report, tree_errors
-
-
 if __name__ == "__main__":
   parser = argparse.ArgumentParser(
       formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -684,7 +238,7 @@ Simple usage:
       default="report.txt")
   args = parser.parse_args()
 
-  upgrade = TensorFlowCodeUpgrader()
+  upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec())
   report_text = None
   report_filename = args.report_filename
   files_processed = 0
diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py
index de4e3de73cd..ac838a2791f 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_test.py
@@ -22,6 +22,7 @@ import tempfile
 import six
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test as test_lib
+from tensorflow.tools.compatibility import ast_edits
 from tensorflow.tools.compatibility import tf_upgrade
 
 
@@ -36,7 +37,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
   def _upgrade(self, old_file_text):
     in_file = six.StringIO(old_file_text)
     out_file = six.StringIO()
-    upgrader = tf_upgrade.TensorFlowCodeUpgrader()
+    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
     count, report, errors = (
         upgrader.process_opened_file("test.py", in_file,
                                      "test_out.py", out_file))
@@ -139,7 +140,7 @@ class TestUpgradeFiles(test_util.TensorFlowTestCase):
     upgraded = "tf.multiply(a, b)\n"
     temp_file.write(original)
     temp_file.close()
-    upgrader = tf_upgrade.TensorFlowCodeUpgrader()
+    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
     upgrader.process_file(temp_file.name, temp_file.name)
     self.assertAllEqual(open(temp_file.name).read(), upgraded)
     os.unlink(temp_file.name)