Initial API compatibility script for TF 2.0. It largely reuses the 1.0 conversion script, but feeds it V2 data. Also removes code from tf_upgrade.py that is duplicated in ast_edits.py.
PiperOrigin-RevId: 205887317
This commit is contained in:
parent
57d051e7b1
commit
76e8f7b7fd
tensorflow/tools/compatibility
@ -8,10 +8,17 @@ load(
|
||||
"tf_cc_test", # @unused
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ast_edits",
|
||||
srcs = ["ast_edits.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "tf_upgrade",
|
||||
srcs = ["tf_upgrade.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":ast_edits"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
@ -26,6 +33,28 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "tf_upgrade_v2",
|
||||
srcs = [
|
||||
"renames_v2.py",
|
||||
"tf_upgrade_v2.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":ast_edits"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tf_upgrade_v2_test",
|
||||
srcs = ["tf_upgrade_v2_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tf_upgrade_v2",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
# Keep for reference, this test will succeed in 0.11 but fail in 1.0
|
||||
# py_test(
|
||||
# name = "test_file_v0_11",
|
||||
@ -62,9 +91,37 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"tf_upgrade.py",
|
||||
"testdata/test_file_v0_11.py",
|
||||
genrule(
|
||||
name = "generate_upgraded_file_v2",
|
||||
testonly = 1,
|
||||
srcs = ["testdata/test_file_v1_10.py"],
|
||||
outs = [
|
||||
"test_file_v2_0.py",
|
||||
"report_v2.txt",
|
||||
],
|
||||
cmd = ("$(location :tf_upgrade_v2)" +
|
||||
" --infile $(location testdata/test_file_v1_10.py)" +
|
||||
" --outfile $(location test_file_v2_0.py)" +
|
||||
" --reportfile $(location report_v2.txt)"),
|
||||
tools = [":tf_upgrade_v2"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_file_v2_0",
|
||||
size = "small",
|
||||
srcs = ["test_file_v2_0.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"ast_edits.py",
|
||||
"tf_upgrade.py",
|
||||
"renames_v2.py",
|
||||
"testdata/test_file_v0_11.py",
|
||||
"testdata/test_file_v1_10.py",
|
||||
],
|
||||
)
|
||||
|
134
tensorflow/tools/compatibility/renames_v2.py
Normal file
134
tensorflow/tools/compatibility/renames_v2.py
Normal file
@ -0,0 +1,134 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=line-too-long
|
||||
"""List of renames to apply when converting from TF 1.0 to TF 2.0.
|
||||
|
||||
THIS FILE IS AUTOGENERATED: To update, please run:
|
||||
bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map
|
||||
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map
|
||||
This file should be updated whenever endpoints are deprecated.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Mapping of deprecated TF 1.x top-level symbol -> its namespaced TF 2.0
# replacement. AUTOGENERATED (see module docstring) — do not edit by hand.
renames = {
    'tf.acos': 'tf.math.acos',
    'tf.acosh': 'tf.math.acosh',
    'tf.add': 'tf.math.add',
    'tf.as_string': 'tf.dtypes.as_string',
    'tf.asin': 'tf.math.asin',
    'tf.asinh': 'tf.math.asinh',
    'tf.atan': 'tf.math.atan',
    'tf.atan2': 'tf.math.atan2',
    'tf.atanh': 'tf.math.atanh',
    'tf.batch_to_space_nd': 'tf.manip.batch_to_space_nd',
    'tf.betainc': 'tf.math.betainc',
    'tf.ceil': 'tf.math.ceil',
    'tf.check_numerics': 'tf.debugging.check_numerics',
    'tf.cholesky': 'tf.linalg.cholesky',
    'tf.cos': 'tf.math.cos',
    'tf.cosh': 'tf.math.cosh',
    'tf.cross': 'tf.linalg.cross',
    'tf.decode_base64': 'tf.io.decode_base64',
    'tf.decode_compressed': 'tf.io.decode_compressed',
    'tf.decode_json_example': 'tf.io.decode_json_example',
    'tf.decode_raw': 'tf.io.decode_raw',
    'tf.dequantize': 'tf.quantization.dequantize',
    'tf.diag': 'tf.linalg.tensor_diag',
    'tf.diag_part': 'tf.linalg.tensor_diag_part',
    'tf.digamma': 'tf.math.digamma',
    'tf.encode_base64': 'tf.io.encode_base64',
    'tf.equal': 'tf.math.equal',
    'tf.erfc': 'tf.math.erfc',
    'tf.exp': 'tf.math.exp',
    'tf.expm1': 'tf.math.expm1',
    'tf.extract_image_patches': 'tf.image.extract_image_patches',
    'tf.fake_quant_with_min_max_args': 'tf.quantization.fake_quant_with_min_max_args',
    'tf.fake_quant_with_min_max_args_gradient': 'tf.quantization.fake_quant_with_min_max_args_gradient',
    'tf.fake_quant_with_min_max_vars': 'tf.quantization.fake_quant_with_min_max_vars',
    'tf.fake_quant_with_min_max_vars_gradient': 'tf.quantization.fake_quant_with_min_max_vars_gradient',
    'tf.fake_quant_with_min_max_vars_per_channel': 'tf.quantization.fake_quant_with_min_max_vars_per_channel',
    'tf.fake_quant_with_min_max_vars_per_channel_gradient': 'tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient',
    'tf.fft': 'tf.spectral.fft',
    'tf.floor': 'tf.math.floor',
    'tf.gather_nd': 'tf.manip.gather_nd',
    'tf.greater': 'tf.math.greater',
    'tf.greater_equal': 'tf.math.greater_equal',
    'tf.ifft': 'tf.spectral.ifft',
    'tf.igamma': 'tf.math.igamma',
    'tf.igammac': 'tf.math.igammac',
    'tf.invert_permutation': 'tf.math.invert_permutation',
    'tf.is_finite': 'tf.debugging.is_finite',
    'tf.is_inf': 'tf.debugging.is_inf',
    'tf.is_nan': 'tf.debugging.is_nan',
    'tf.less': 'tf.math.less',
    'tf.less_equal': 'tf.math.less_equal',
    'tf.lgamma': 'tf.math.lgamma',
    'tf.log': 'tf.math.log',
    'tf.log1p': 'tf.math.log1p',
    'tf.logical_and': 'tf.math.logical_and',
    'tf.logical_not': 'tf.math.logical_not',
    'tf.logical_or': 'tf.math.logical_or',
    'tf.matching_files': 'tf.io.matching_files',
    'tf.matrix_band_part': 'tf.linalg.band_part',
    'tf.matrix_determinant': 'tf.linalg.det',
    'tf.matrix_diag': 'tf.linalg.diag',
    'tf.matrix_diag_part': 'tf.linalg.diag_part',
    'tf.matrix_inverse': 'tf.linalg.inv',
    'tf.matrix_set_diag': 'tf.linalg.set_diag',
    'tf.matrix_solve': 'tf.linalg.solve',
    'tf.matrix_triangular_solve': 'tf.linalg.triangular_solve',
    'tf.maximum': 'tf.math.maximum',
    'tf.minimum': 'tf.math.minimum',
    'tf.not_equal': 'tf.math.not_equal',
    'tf.parse_tensor': 'tf.io.parse_tensor',
    'tf.polygamma': 'tf.math.polygamma',
    'tf.qr': 'tf.linalg.qr',
    'tf.quantized_concat': 'tf.quantization.quantized_concat',
    'tf.read_file': 'tf.io.read_file',
    'tf.reciprocal': 'tf.math.reciprocal',
    'tf.regex_replace': 'tf.strings.regex_replace',
    'tf.reshape': 'tf.manip.reshape',
    'tf.reverse': 'tf.manip.reverse',
    # NOTE: two old names may map to the same new name (reverse/reverse_v2).
    'tf.reverse_v2': 'tf.manip.reverse',
    'tf.rint': 'tf.math.rint',
    'tf.rsqrt': 'tf.math.rsqrt',
    'tf.scatter_nd': 'tf.manip.scatter_nd',
    'tf.segment_max': 'tf.math.segment_max',
    'tf.segment_mean': 'tf.math.segment_mean',
    'tf.segment_min': 'tf.math.segment_min',
    'tf.segment_prod': 'tf.math.segment_prod',
    'tf.segment_sum': 'tf.math.segment_sum',
    'tf.sin': 'tf.math.sin',
    'tf.sinh': 'tf.math.sinh',
    'tf.space_to_batch_nd': 'tf.manip.space_to_batch_nd',
    'tf.squared_difference': 'tf.math.squared_difference',
    'tf.string_join': 'tf.strings.join',
    'tf.string_strip': 'tf.strings.strip',
    'tf.string_to_hash_bucket': 'tf.strings.to_hash_bucket',
    'tf.string_to_hash_bucket_fast': 'tf.strings.to_hash_bucket_fast',
    'tf.string_to_hash_bucket_strong': 'tf.strings.to_hash_bucket_strong',
    'tf.string_to_number': 'tf.strings.to_number',
    'tf.substr': 'tf.strings.substr',
    'tf.tan': 'tf.math.tan',
    'tf.tile': 'tf.manip.tile',
    'tf.unsorted_segment_max': 'tf.math.unsorted_segment_max',
    'tf.unsorted_segment_min': 'tf.math.unsorted_segment_min',
    'tf.unsorted_segment_prod': 'tf.math.unsorted_segment_prod',
    'tf.unsorted_segment_sum': 'tf.math.unsorted_segment_sum',
    'tf.write_file': 'tf.io.write_file',
    'tf.zeta': 'tf.math.zeta'
}
|
34
tensorflow/tools/compatibility/testdata/test_file_v1_10.py
vendored
Normal file
34
tensorflow/tools/compatibility/testdata/test_file_v1_10.py
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf upgrader."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test as test_lib
|
||||
|
||||
|
||||
class TestUpgrade(test_util.TensorFlowTestCase):
  """Test various APIs that have been changed in 2.0."""

  def testRenames(self):
    # Smoke-check that the deprecated 1.x top-level names still resolve:
    # acos(0.5) == pi/3 ~= 1.04719755 and rsqrt(4.0) == 0.5. This file is
    # also the input fixture the genrule feeds through tf_upgrade_v2.
    with self.test_session():
      self.assertAllClose(1.04719755, tf.acos(0.5).eval())
      self.assertAllClose(0.5, tf.rsqrt(4.0).eval())
||||
# Script entry point: run the test suite when executed directly.
if __name__ == "__main__":
  test_lib.main()
|
@ -19,491 +19,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import collections
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
|
||||
class APIChangeSpec(object):
  """Describes the set of transformations an upgrade script performs.

  Concrete subclasses (see `TFAPIChangeSpec` for an example) must populate
  the following attributes:

  * `function_keyword_renames`: dict mapping a function name to a dict of
    old argument name -> new argument name.
  * `function_renames`: dict mapping old function names to their new names.
  * `change_to_function`: set of names that have become callable and should
    trigger a notification.
  * `function_reorders`: dict mapping each function whose positional
    argument order changed to the list of argument names in the new order.
  * `function_handle`: dict mapping function names to custom handler
    callbacks for that function.
  """
|
||||
|
||||
|
||||
class _FileEditTuple(
    collections.namedtuple("_FileEditTuple",
                           ["comment", "line", "start", "old", "new"])):
  """Each edit that is recorded by a _FileEditRecorder.

  Fields:
    comment: A description of the edit and why it was made.
    line: The line number in the file where the edit occurs (1-indexed).
    start: The column offset in the line where the edit occurs (0-indexed).
    old: text string to remove (this must match what was in file).
    new: text string to add in place of `old`.
  """

  # No per-instance __dict__: edits are plain immutable records.
  __slots__ = ()
|
||||
|
||||
|
||||
class _FileEditRecorder(object):
  """Record changes that need to be done to the file."""

  def __init__(self, filename):
    # Filename is only used for report/error messages.
    self._filename = filename

    # Maps 1-indexed line number -> list of _FileEditTuple for that line.
    self._line_to_edit = collections.defaultdict(list)
    self._errors = []

  def process(self, text):
    """Process a list of strings, each corresponding to the recorded changes.

    Args:
      text: A list of lines of text (assumed to contain newlines)
    Returns:
      A tuple of the modified text, a textual description of what was done,
      and the list of accumulated error strings.
    Raises:
      ValueError: if substitution source location does not have expected text.
    """

    change_report = ""

    # Iterate over each line that has at least one recorded edit.
    for line, edits in self._line_to_edit.items():
      offset = 0
      # Sort by column so that edits are processed in order, making the
      # indexing adjustments cumulative for changes that change the string
      # length.
      edits.sort(key=lambda x: x.start)

      # Extract each line to a list of characters, because mutable lists
      # are editable, unlike immutable strings.
      char_array = list(text[line - 1])

      # Record a description of the change.
      change_report += "%r Line %d\n" % (self._filename, line)
      change_report += "-" * 80 + "\n\n"
      for e in edits:
        change_report += "%s\n" % e.comment
      change_report += "\n Old: %s" % (text[line - 1])

      # Make underscore buffers for underlining where in the line the edit was.
      change_list = [" "] * len(text[line - 1])
      change_list_new = [" "] * len(text[line - 1])

      # Iterate for each edit.
      for e in edits:
        # Create effective start, end by accounting for change in length due
        # to previous edits.
        start_eff = e.start + offset
        end_eff = start_eff + len(e.old)

        # Make sure the edit is changing what it should be changing.
        old_actual = "".join(char_array[start_eff:end_eff])
        if old_actual != e.old:
          raise ValueError("Expected text %r but got %r" %
                           ("".join(e.old), "".join(old_actual)))
        # Make the edit.
        char_array[start_eff:end_eff] = list(e.new)

        # Create the underline highlighting of the before and after.
        change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
        change_list_new[start_eff:end_eff] = "~" * len(e.new)

        # Keep track of how to generate effective ranges.
        offset += len(e.new) - len(e.old)

      # Finish the report comment.
      change_report += " %s\n" % "".join(change_list)
      text[line - 1] = "".join(char_array)
      change_report += " New: %s" % (text[line - 1])
      change_report += " %s\n\n" % "".join(change_list_new)
    return "".join(text), change_report, self._errors

  def add(self, comment, line, start, old, new, error=None):
    """Add a new change that is needed.

    Args:
      comment: A description of what was changed
      line: Line number (1 indexed)
      start: Column offset (0 indexed)
      old: old text
      new: new text
      error: this "edit" is something that cannot be fixed automatically
    Returns:
      None
    """

    self._line_to_edit[line].append(
        _FileEditTuple(comment, line, start, old, new))
    if error:
      self._errors.append("%s:%d: %s" % (self._filename, line, error))
|
||||
|
||||
|
||||
class _ASTCallVisitor(ast.NodeVisitor):
  """AST Visitor that processes function calls.

  Updates function calls from old API version to new API version using a given
  change spec.
  """

  def __init__(self, filename, lines, api_change_spec):
    self._filename = filename
    self._file_edit = _FileEditRecorder(filename)
    # Raw source lines; needed because ast only gives start locations.
    self._lines = lines
    self._api_change_spec = api_change_spec

  def process(self, lines):
    # Apply all recorded edits; returns (new_text, report, errors).
    return self._file_edit.process(lines)

  def generic_visit(self, node):
    ast.NodeVisitor.generic_visit(self, node)

  def _rename_functions(self, node, full_name):
    # Record a rename edit if `full_name` is in the spec; otherwise no-op.
    function_renames = self._api_change_spec.function_renames
    try:
      new_name = function_renames[full_name]
      self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
                          node.lineno, node.col_offset, full_name, new_name)
    except KeyError:
      pass

  def _get_attribute_full_path(self, node):
    """Traverse an attribute to generate a full name e.g. tf.foo.bar.

    Args:
      node: A Node of type Attribute.

    Returns:
      a '.'-delimited full-name or None if the tree was not a simple form.
      i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
    """
    curr = node
    items = []
    # Walk down the Attribute chain until we hit the root Name node.
    while not isinstance(curr, ast.Name):
      if not isinstance(curr, ast.Attribute):
        return None
      items.append(curr.attr)
      curr = curr.value
    items.append(curr.id)
    return ".".join(reversed(items))

  def _find_true_position(self, node):
    """Return correct line number and column offset for a given node.

    This is necessary mainly because ListComp's location reporting reports
    the next token after the list comprehension list opening.

    Args:
      node: Node for which we wish to know the lineno and col_offset
    """
    import re
    find_open = re.compile("^\s*(\\[).*$")
    find_string_chars = re.compile("['\"]")

    if isinstance(node, ast.ListComp):
      # Strangely, ast.ListComp returns the col_offset of the first token
      # after the '[' token which appears to be a bug. Workaround by
      # explicitly finding the real start of the list comprehension.
      line = node.lineno
      col = node.col_offset
      # loop over lines, scanning backwards for the opening '['
      while 1:
        # Reverse the text so a regular expression can scan leftwards from
        # the reported column.
        text = self._lines[line - 1]
        reversed_preceding_text = text[:col][::-1]
        # First find if a [ can be found with only whitespace between it and
        # col.
        m = find_open.match(reversed_preceding_text)
        if m:
          new_col_offset = col - m.start(1) - 1
          return line, new_col_offset
        else:
          if (reversed_preceding_text == "" or
              reversed_preceding_text.isspace()):
            line = line - 1
            prev_line = self._lines[line - 1]
            # TODO(aselle):
            # this is poor comment detection, but it is good enough for
            # cases where the comment does not contain string literal starting/
            # ending characters. If ast gave us start and end locations of the
            # ast nodes rather than just start, we could use string literal
            # node ranges to filter out spurious #'s that appear in string
            # literals.
            comment_start = prev_line.find("#")
            if comment_start == -1:
              col = len(prev_line) - 1
            elif find_string_chars.search(prev_line[comment_start:]) is None:
              col = comment_start
            else:
              return None, None
          else:
            return None, None
    # Most other nodes return proper locations (ListComp notably does not),
    # so fall back to the node's own reported position.
    return node.lineno, node.col_offset

  def visit_Call(self, node):  # pylint: disable=invalid-name
    """Handle visiting a call node in the AST.

    Args:
      node: Current Node
    """

    # Find a simple attribute name path e.g. "tf.foo.bar"
    full_name = self._get_attribute_full_path(node.func)

    # Make sure the func is marked as being part of a call, so that
    # visit_Attribute does not also treat it as a bare attribute.
    node.func.is_function_for_call = True

    if full_name:
      # Call special handlers
      function_handles = self._api_change_spec.function_handle
      if full_name in function_handles:
        function_handles[full_name](self._file_edit, node)

      # Examine any non-keyword argument and make it into a keyword argument
      # if reordering required.
      function_reorders = self._api_change_spec.function_reorders
      function_keyword_renames = (
          self._api_change_spec.function_keyword_renames)

      if full_name in function_reorders:
        reordered = function_reorders[full_name]
        for idx, arg in enumerate(node.args):
          lineno, col_offset = self._find_true_position(arg)
          if lineno is None or col_offset is None:
            self._file_edit.add(
                "Failed to add keyword %r to reordered function %r" %
                (reordered[idx], full_name),
                arg.lineno,
                arg.col_offset,
                "",
                "",
                error="A necessary keyword argument failed to be inserted.")
          else:
            keyword_arg = reordered[idx]
            if (full_name in function_keyword_renames and
                keyword_arg in function_keyword_renames[full_name]):
              keyword_arg = function_keyword_renames[full_name][keyword_arg]
            self._file_edit.add("Added keyword %r to reordered function %r" %
                                (reordered[idx], full_name), lineno, col_offset,
                                "", keyword_arg + "=")

      # Examine each keyword argument and convert it to the final renamed form
      renamed_keywords = ({} if full_name not in function_keyword_renames else
                          function_keyword_renames[full_name])
      for keyword in node.keywords:
        argkey = keyword.arg
        argval = keyword.value

        if argkey in renamed_keywords:
          argval_lineno, argval_col_offset = self._find_true_position(argval)
          if argval_lineno is not None and argval_col_offset is not None:
            # TODO(aselle): We should scan backward to find the start of the
            # keyword key. Unfortunately ast does not give you the location of
            # keyword keys, so we are forced to infer it from the keyword arg
            # value.
            key_start = argval_col_offset - len(argkey) - 1
            key_end = key_start + len(argkey) + 1
            if (self._lines[argval_lineno - 1][key_start:key_end] == argkey +
                "="):
              self._file_edit.add("Renamed keyword argument from %r to %r" %
                                  (argkey,
                                   renamed_keywords[argkey]), argval_lineno,
                                  argval_col_offset - len(argkey) - 1,
                                  argkey + "=", renamed_keywords[argkey] + "=")
              continue
          self._file_edit.add(
              "Failed to rename keyword argument from %r to %r" %
              (argkey, renamed_keywords[argkey]),
              argval.lineno,
              argval.col_offset - len(argkey) - 1,
              "",
              "",
              error="Failed to find keyword lexographically. Fix manually.")

    ast.NodeVisitor.generic_visit(self, node)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar].

    Args:
      node: Node that is of type ast.Attribute
    """
    full_name = self._get_attribute_full_path(node)
    if full_name:
      self._rename_functions(node, full_name)
    if full_name in self._api_change_spec.change_to_function:
      # Only convert to a call if this attribute is not itself the callee of
      # a Call node (visit_Call marks those with is_function_for_call).
      if not hasattr(node, "is_function_for_call"):
        new_text = full_name + "()"
        self._file_edit.add("Changed %r to %r" % (full_name, new_text),
                            node.lineno, node.col_offset, full_name, new_text)

    ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
|
||||
class ASTCodeUpgrader(object):
  """Handles upgrading a set of Python files using a given API change spec."""

  def __init__(self, api_change_spec):
    if not isinstance(api_change_spec, APIChangeSpec):
      raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
                      type(api_change_spec))
    self._api_change_spec = api_change_spec

  def process_file(self, in_filename, out_filename):
    """Process the given python file for incompatible changes.

    Args:
      in_filename: filename to parse
      out_filename: output file to write to
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """

    # Write to a temporary file, just in case we are doing an in-place modify.
    with open(in_filename, "r") as in_file, \
        tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
      ret = self.process_opened_file(in_filename, in_file, out_filename,
                                     temp_file)

    shutil.move(temp_file.name, out_filename)
    return ret

  # Broad exceptions are required here because ast throws whatever it wants.
  # pylint: disable=broad-except
  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
    """Process the given python file for incompatible changes.

    This function is split out to facilitate StringIO testing from
    tf_upgrade_test.py.

    Args:
      in_filename: filename to parse
      in_file: opened file (or StringIO)
      out_filename: output file to write to
      out_file: opened file (or StringIO)
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    process_errors = []
    text = "-" * 80 + "\n"
    text += "Processing file %r\n outputting to %r\n" % (in_filename,
                                                         out_filename)
    text += "-" * 80 + "\n\n"

    parsed_ast = None
    lines = in_file.readlines()
    try:
      parsed_ast = ast.parse("".join(lines))
    except Exception:
      # A file that does not parse is reported but does not abort the run.
      text += "Failed to parse %r\n\n" % in_filename
      text += traceback.format_exc()
    if parsed_ast:
      visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
      visitor.visit(parsed_ast)
      out_text, new_text, process_errors = visitor.process(lines)
      text += new_text
      if out_file:
        out_file.write(out_text)
    text += "\n"
    return 1, text, process_errors

  # pylint: enable=broad-except

  def process_tree(self, root_directory, output_root_directory,
                   copy_other_files):
    """Processes upgrades on an entire tree of python files in place.

    Note that only Python files are processed. If you have custom code in
    other languages, you will need to manually upgrade those.

    Args:
      root_directory: Directory to walk and process.
      output_root_directory: Directory to use as base.
      copy_other_files: Copy files that are not touched by this converter.

    Returns:
      A tuple of files processed, the report string for all files, and errors
    """

    # make sure output directory doesn't exist
    if output_root_directory and os.path.exists(output_root_directory):
      print("Output directory %r must not already exist." %
            (output_root_directory))
      sys.exit(1)

    # make sure output directory does not overlap with root_directory
    norm_root = os.path.split(os.path.normpath(root_directory))
    norm_output = os.path.split(os.path.normpath(output_root_directory))
    if norm_root == norm_output:
      print("Output directory %r same as input directory %r" %
            (root_directory, output_root_directory))
      sys.exit(1)

    # Collect list of files to process (we do this to correctly handle if the
    # user puts the output directory in some sub directory of the input dir)
    files_to_process = []
    files_to_copy = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [f for f in file_list if f.endswith(".py")]
      copy_files = [f for f in file_list if not f.endswith(".py")]
      for filename in py_files:
        fullpath = os.path.join(dir_name, filename)
        fullpath_output = os.path.join(output_root_directory,
                                       os.path.relpath(fullpath,
                                                       root_directory))
        files_to_process.append((fullpath, fullpath_output))
      if copy_other_files:
        for filename in copy_files:
          fullpath = os.path.join(dir_name, filename)
          fullpath_output = os.path.join(output_root_directory,
                                         os.path.relpath(
                                             fullpath, root_directory))
          files_to_copy.append((fullpath, fullpath_output))

    file_count = 0
    tree_errors = []
    report = ""
    report += ("=" * 80) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += ("=" * 80) + "\n"

    for input_path, output_path in files_to_process:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
      file_count += 1
      _, l_report, l_errors = self.process_file(input_path, output_path)
      tree_errors += l_errors
      report += l_report
    for input_path, output_path in files_to_copy:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
      shutil.copy(input_path, output_path)
    return file_count, report, tree_errors
|
||||
|
||||
|
||||
class TFAPIChangeSpec(APIChangeSpec):
|
||||
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"""List of maps that describe what changed in the API."""
|
||||
|
||||
def __init__(self):
|
||||
@ -718,7 +238,7 @@ Simple usage:
|
||||
default="report.txt")
|
||||
args = parser.parse_args()
|
||||
|
||||
upgrade = ASTCodeUpgrader(TFAPIChangeSpec())
|
||||
upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec())
|
||||
report_text = None
|
||||
report_filename = args.report_filename
|
||||
files_processed = 0
|
||||
|
@ -22,6 +22,7 @@ import tempfile
|
||||
import six
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test as test_lib
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import tf_upgrade
|
||||
|
||||
|
||||
@ -36,7 +37,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
def _upgrade(self, old_file_text):
|
||||
in_file = six.StringIO(old_file_text)
|
||||
out_file = six.StringIO()
|
||||
upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
|
||||
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
|
||||
count, report, errors = (
|
||||
upgrader.process_opened_file("test.py", in_file,
|
||||
"test_out.py", out_file))
|
||||
@ -139,7 +140,7 @@ class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||
upgraded = "tf.multiply(a, b)\n"
|
||||
temp_file.write(original)
|
||||
temp_file.close()
|
||||
upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
|
||||
upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec())
|
||||
upgrader.process_file(temp_file.name, temp_file.name)
|
||||
self.assertAllEqual(open(temp_file.name).read(), upgraded)
|
||||
os.unlink(temp_file.name)
|
||||
|
115
tensorflow/tools/compatibility/tf_upgrade_v2.py
Normal file
115
tensorflow/tools/compatibility/tf_upgrade_v2.py
Normal file
@ -0,0 +1,115 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Upgrader for Python scripts from 1.* TensorFlow to 2.0 TensorFlow."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import renames_v2
|
||||
|
||||
|
||||
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
  """Compatibility spec consumed by ast_edits for the 1.x -> 2.0 upgrade."""

  def __init__(self):
    # Function renames are the only populated table for the 2.0 upgrade so
    # far; the mapping itself lives in renames_v2, which is regenerated by
    # tools/compatibility/update/generate_v2_renames_map.py.
    self.function_renames = renames_v2.renames

    # Per-function mapping from old keyword-argument names to new ones.
    # Empty: no keyword renames are applied yet.
    self.function_keyword_renames = {}

    # Variables that should be rewritten as function calls. Empty.
    self.change_to_function = {}

    # Functions whose positional-argument order changed; calls listed here
    # get positional arguments rewritten as keywords for safety. If a caller
    # already reordered arguments by hand, that rewrite would be wrong —
    # hence the table is opt-in per function. Empty for now.
    self.function_reorders = {}

    # Functions that need custom, non-table-driven handling. Empty.
    self.function_handle = {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Command-line driver: convert a single file (--infile/--outfile) or a
  # whole tree (--intree/--outtree), then write and summarize the report.
  parser = argparse.ArgumentParser(
      formatter_class=argparse.RawDescriptionHelpFormatter,
      description="""Convert a TensorFlow Python file to 2.0

Simple usage:
  tf_convert_v2.py --infile foo.py --outfile bar.py
  tf_convert_v2.py --intree ~/code/old --outtree ~/code/new
""")
  parser.add_argument(
      "--infile",
      dest="input_file",
      help="If converting a single file, the name of the file "
      "to convert")
  parser.add_argument(
      "--outfile",
      dest="output_file",
      help="If converting a single file, the output filename.")
  parser.add_argument(
      "--intree",
      dest="input_tree",
      help="If converting a whole tree of files, the directory "
      "to read from (relative or absolute).")
  parser.add_argument(
      "--outtree",
      dest="output_tree",
      help="If converting a whole tree of files, the output "
      "directory (relative or absolute).")
  parser.add_argument(
      "--copyotherfiles",
      dest="copy_other_files",
      help=("If converting a whole tree of files, whether to "
            "copy the other files."),
      # BUG FIX: type=bool treats ANY non-empty string — including "False" —
      # as True, because bool("False") is True. Parse the text explicitly;
      # `--copyotherfiles True` keeps working as before.
      type=lambda value: value.strip().lower() in ("1", "true", "yes"),
      default=False)
  parser.add_argument(
      "--reportfile",
      dest="report_filename",
      help=("The name of the file where the report log is "
            "stored."
            "(default: %(default)s)"),
      default="report.txt")
  args = parser.parse_args()

  upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec())
  report_text = None
  report_filename = args.report_filename
  files_processed = 0
  # BUG FIX: initialize `errors` so the summary below cannot raise a
  # NameError when neither --infile nor --intree is supplied.
  errors = []
  if args.input_file:
    files_processed, report_text, errors = upgrade.process_file(
        args.input_file, args.output_file)
    # A single-file run always counts as exactly one processed file,
    # regardless of the count process_file reports.
    files_processed = 1
  elif args.input_tree:
    files_processed, report_text, errors = upgrade.process_tree(
        args.input_tree, args.output_tree, args.copy_other_files)
  else:
    parser.print_help()
  if report_text:
    # Close the report file deterministically instead of relying on GC to
    # flush the unclosed handle from open(...).write(...).
    with open(report_filename, "w") as report_file:
      report_file.write(report_text)
  print("TensorFlow 2.0 Upgrade Script")
  print("-----------------------------")
  print("Converted %d files\n" % files_processed)
  print("Detected %d errors that require attention" % len(errors))
  print("-" * 80)
  print("\n".join(errors))
  print("\nMake sure to read the detailed log %r\n" % report_filename)
|
83
tensorflow/tools/compatibility/tf_upgrade_v2_test.py
Normal file
83
tensorflow/tools/compatibility/tf_upgrade_v2_test.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf 2.0 upgrader."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import tempfile
|
||||
import six
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test as test_lib
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import tf_upgrade_v2
|
||||
|
||||
|
||||
class TestUpgrade(test_util.TensorFlowTestCase):
  """Test various APIs that have been changed in 2.0.

  We also test whether a converted file is executable. test_file_v1_10.py
  aims to exhaustively test that API changes are convertible and actually
  work when run with current TensorFlow.
  """

  def _upgrade(self, old_file_text):
    """Runs the v2 upgrader over a source string held in memory.

    Args:
      old_file_text: Source text of a TF 1.x Python file to convert.

    Returns:
      Tuple of (processed-file count, report text, error list,
      converted source text).
    """
    in_file = six.StringIO(old_file_text)
    out_file = six.StringIO()
    upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
    count, report, errors = (
        upgrader.process_opened_file("test.py", in_file,
                                     "test_out.py", out_file))
    return count, report, errors, out_file.getvalue()

  def testParseError(self):
    # Unparseable input must be reported, not raised.
    _, report, unused_errors, unused_new_text = self._upgrade(
        "import tensorflow as tf\na + \n")
    self.assertIn("Failed to parse", report)

  def testReport(self):
    text = "tf.acos(a)\n"
    _, report, unused_errors, unused_new_text = self._upgrade(text)
    # This is not a complete test, but it is a sanity test that a report
    # is generating information.
    # BUG FIX: the original assertTrue(report.find(...)) passed even when
    # the substring was missing, because str.find returns -1, which is
    # truthy. assertIn actually checks for the substring.
    self.assertIn("Renamed function `tf.acos` to `tf.math.acos`", report)

  def testRename(self):
    text = "tf.acos(a)\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.math.acos(a)\n")
    # Nested calls must both be renamed.
    text = "tf.rsqrt(tf.log(3.8))\n"
    _, unused_report, unused_errors, new_text = self._upgrade(text)
    self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log(3.8))\n")
|
||||
|
||||
|
||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
  """Tests for converting files on disk (rather than in-memory strings)."""

  def testInplace(self):
    """Check to make sure we don't have a file system race."""
    temp_file = tempfile.NamedTemporaryFile("w", delete=False)
    original = "tf.acos(a, b)\n"
    upgraded = "tf.math.acos(a, b)\n"
    temp_file.write(original)
    temp_file.close()
    try:
      upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec())
      upgrader.process_file(temp_file.name, temp_file.name)
      # BUG FIX: the original open(temp_file.name).read() leaked the file
      # handle until GC; close it deterministically.
      with open(temp_file.name) as result_file:
        self.assertAllEqual(result_file.read(), upgraded)
    finally:
      # BUG FIX: the delete=False temp file was previously left behind
      # whenever an assertion above failed; always remove it.
      os.unlink(temp_file.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Hand off to the TensorFlow test runner (tensorflow.python.platform.test).
  test_lib.main()
|
15
tensorflow/tools/compatibility/update/BUILD
Normal file
15
tensorflow/tools/compatibility/update/BUILD
Normal file
@ -0,0 +1,15 @@
|
||||
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:private"])

# Tool that regenerates tensorflow/tools/compatibility/renames_v2.py by
# traversing the public TF API; see generate_v2_renames_map.py for usage.
py_binary(
    name = "generate_v2_renames_map",
    srcs = ["generate_v2_renames_map.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:lib",
        "//tensorflow/tools/common:public_api",
        "//tensorflow/tools/common:traverse",
    ],
)
|
103
tensorflow/tools/compatibility/update/generate_v2_renames_map.py
Normal file
103
tensorflow/tools/compatibility/update/generate_v2_renames_map.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=line-too-long
|
||||
"""Script for updating tensorflow/tools/compatibility/renames_v2.py.
|
||||
|
||||
To update renames_v2.py, run:
|
||||
bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map
|
||||
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_export
|
||||
from tensorflow.tools.common import public_api
|
||||
from tensorflow.tools.common import traverse
|
||||
|
||||
|
||||
# Workspace-relative destination the generated rename map is written to.
_OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/renames_v2.py'
# Verbatim preamble emitted at the top of the generated renames_v2.py:
# license text plus an "autogenerated, do not hand-edit" notice. The \"\"\"
# sequences are escaped so they survive inside this outer triple-quoted
# string and become the generated module's own docstring quotes.
_FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
\"\"\"List of renames to apply when converting from TF 1.0 to TF 2.0.

THIS FILE IS AUTOGENERATED: To update, please run:
  bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map
  bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map
This file should be updated whenever endpoints are deprecated.
\"\"\"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

"""
|
||||
|
||||
|
||||
def update_renames_v2(output_file_path):
  """Writes a Python dictionary mapping deprecated to canonical API names.

  Args:
    output_file_path: File path to write output to. Any existing contents
      would be replaced.
  """
  # Rename entries, each formatted as: 'tf.deprecated_name': 'tf.canonical_name'
  collected_renames = set()
  # Attribute name under which tf_export records a symbol's API names.
  api_names_attr = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names

  def visit(unused_path, unused_parent, children):
    """Visitor that collects rename strings to add to collected_renames."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      if not hasattr(attr, '__dict__'):
        continue
      api_names = attr.__dict__.get(api_names_attr, [])
      deprecated_api_names = attr.__dict__.get('_tf_deprecated_api_names', [])
      canonical_name = tf_export.get_canonical_name(
          api_names, deprecated_api_names)
      for name in deprecated_api_names:
        collected_renames.add(" 'tf.%s': 'tf.%s'" % (name, canonical_name))

  visitor = public_api.PublicAPIVisitor(visit)
  # contrib has no 2.0 counterpart; don't walk into it.
  visitor.do_not_descend_map['tf'].append('contrib')
  traverse.traverse(tf, visitor)

  # Emit a deterministic file: header, then the entries in sorted order.
  entries = ',\n'.join(sorted(collected_renames))
  renames_file_text = '%srenames = {\n%s\n}\n' % (_FILE_HEADER, entries)
  file_io.write_string_to_file(output_file_path, renames_file_text)
|
||||
|
||||
|
||||
def main(unused_argv):
  """Entry point: regenerates renames_v2.py at its checked-in location."""
  update_renames_v2(_OUTPUT_FILE_PATH)


if __name__ == '__main__':
  tf.app.run(main=main)
|
Loading…
Reference in New Issue
Block a user