Clean up conversion warning language, and use the ability to distinguish info, warning, and errors better.

Also improves the conversion of some symbols (changes a warning to a conversion for more cases that is now possible).

PiperOrigin-RevId: 231334490
This commit is contained in:
Martin Wicke 2019-01-28 19:52:50 -08:00 committed by TensorFlower Gardener
parent ad65038c90
commit 2ac6545b3a
7 changed files with 632 additions and 363 deletions

View File

@ -62,6 +62,45 @@ def full_name_node(name, ctx=ast.Load()):
return node
def get_arg_value(node, arg_name, arg_pos=None):
"""Get the value of an argument from a ast.Call node.
This function goes through the positional and keyword arguments to check
whether a given argument was used, and if so, returns its value (the node
representing its value).
This cannot introspect *args or **args, but it safely handles *args in
Python3.5+.
Args:
node: The ast.Call node to extract arg values from.
arg_name: The name of the argument to extract.
arg_pos: The position of the argument (in case it's passed as a positional
argument).
Returns:
A tuple (arg_present, arg_value) containing a boolean indicating whether
the argument is present, and its value in case it is.
"""
# Check keyword args
if arg_name is not None:
for kw in node.keywords:
if kw.arg == arg_name:
return (True, kw.value)
# Check positional args
if arg_pos is not None:
idx = 0
for arg in node.args:
if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
continue # Can't parse Starred
if idx == arg_pos:
return (True, arg)
idx += 1
return (False, None)
class APIChangeSpec(object):
"""This class defines the transformations that need to happen.
@ -195,11 +234,10 @@ class _PastaEditVisitor(ast.NodeVisitor):
"""Adds an error to be printed about full_name at node."""
function_warnings = self._api_change_spec.function_warnings
if full_name in function_warnings:
warning_message = function_warnings[full_name]
warning_message = warning_message.replace("<function name>", full_name)
self.add_log(WARNING, node.lineno, node.col_offset,
"%s requires manual check. %s" % (full_name,
warning_message))
level, message = function_warnings[full_name]
message = message.replace("<function name>", full_name)
self.add_log(level, node.lineno, node.col_offset,
"%s requires manual check. %s" % (full_name, message))
return True
else:
return False
@ -232,13 +270,13 @@ class _PastaEditVisitor(ast.NodeVisitor):
arg_warnings = self._get_applicable_dict("function_arg_warnings",
full_name, name)
used_args = [kw.arg for kw in node.keywords]
for (kwarg, arg), warning in arg_warnings.items():
if kwarg in used_args or len(node.args) > arg:
for (kwarg, arg), (level, warning) in arg_warnings.items():
present, _ = get_arg_value(node, kwarg, arg)
if present:
warned = True
warning_message = warning.replace("<function name>", full_name or name)
self.add_log(WARNING, node.lineno, node.col_offset,
"%s called with %s argument requires manual check: %s." %
self.add_log(level, node.lineno, node.col_offset,
"%s called with %s argument requires manual check: %s" %
(full_name or name, kwarg, warning_message))
return warned
@ -280,12 +318,14 @@ class _PastaEditVisitor(ast.NodeVisitor):
if full_name in function_reorders:
reordered = function_reorders[full_name]
new_keywords = []
for idx, arg in enumerate(node.args):
idx = 0
for arg in node.args:
if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
continue # Can't move Starred to keywords
keyword_arg = reordered[idx]
keyword = ast.keyword(arg=keyword_arg, value=arg)
new_keywords.append(keyword)
idx += 1
if new_keywords:
self.add_log(INFO, node.lineno, node.col_offset,
@ -374,11 +414,10 @@ class _PastaEditVisitor(ast.NodeVisitor):
logs = []
new_node = transformer(parent, node, full_name, name, logs)
self.add_logs(logs)
if new_node:
if new_node is not node:
pasta.ast_utils.replace_child(parent, node, new_node)
node = new_node
self._stack[-1] = node
if new_node and new_node is not node:
pasta.ast_utils.replace_child(parent, node, new_node)
node = new_node
self._stack[-1] = node
self.generic_visit(node)

View File

@ -417,7 +417,7 @@ class TestAstEdits(test_util.TensorFlowTestCase):
def __init__(self):
NoUpdateSpec.__init__(self)
self.function_warnings = {"*.foo": "not good"}
self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")}
texts = ["object.foo()", "get_object().foo()",
"get_object().foo()", "object.foo().bar()"]

View File

@ -37,6 +37,7 @@ reorders = {
'tf.decode_csv': ['records', 'record_defaults', 'field_delim', 'use_quote_delim', 'name', 'na_value', 'select_cols'],
'tf.depth_to_space': ['input', 'block_size', 'name', 'data_format'],
'tf.feature_column.categorical_column_with_vocabulary_file': ['key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'default_value', 'dtype'],
'tf.image.sample_distorted_bounding_box': ['image_size', 'bounding_boxes', 'seed', 'seed2', 'min_object_covered', 'aspect_ratio_range', 'area_range', 'max_attempts', 'use_image_if_no_bounding_boxes', 'name'],
'tf.io.decode_csv': ['records', 'record_defaults', 'field_delim', 'use_quote_delim', 'name', 'na_value', 'select_cols'],
'tf.io.parse_example': ['serialized', 'features', 'name', 'example_names'],
'tf.io.parse_single_example': ['serialized', 'features', 'name', 'example_names'],
@ -56,12 +57,19 @@ reorders = {
'tf.math.reduce_prod': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
'tf.math.reduce_sum': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
'tf.multinomial': ['logits', 'num_samples', 'seed', 'name', 'output_dtype'],
'tf.nn.conv1d': ['value', 'filters', 'stride', 'padding', 'use_cudnn_on_gpu', 'data_format', 'name'],
'tf.nn.conv2d': ['input', 'filter', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'],
'tf.nn.conv2d_backprop_filter': ['input', 'filter_sizes', 'out_backprop', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'],
'tf.nn.conv2d_backprop_input': ['input_sizes', 'filter', 'out_backprop', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'],
'tf.nn.convolution': ['input', 'filter', 'padding', 'strides', 'dilation_rate', 'name', 'data_format'],
'tf.nn.crelu': ['features', 'name', 'axis'],
'tf.nn.ctc_beam_search_decoder': ['inputs', 'sequence_length', 'beam_width', 'top_paths', 'merge_repeated'],
'tf.nn.depth_to_space': ['input', 'block_size', 'name', 'data_format'],
'tf.nn.depthwise_conv2d': ['input', 'filter', 'strides', 'padding', 'rate', 'name', 'data_format'],
'tf.nn.embedding_lookup': ['params', 'ids', 'partition_strategy', 'name', 'validate_indices', 'max_norm'],
'tf.nn.embedding_lookup_sparse': ['params', 'sp_ids', 'sp_weights', 'partition_strategy', 'name', 'combiner', 'max_norm'],
'tf.nn.fractional_avg_pool': ['value', 'pooling_ratio', 'pseudo_random', 'overlapping', 'deterministic', 'seed', 'seed2', 'name'],
'tf.nn.fractional_max_pool': ['value', 'pooling_ratio', 'pseudo_random', 'overlapping', 'deterministic', 'seed', 'seed2', 'name'],
'tf.nn.in_top_k': ['predictions', 'targets', 'k', 'name'],
'tf.nn.moments': ['x', 'axes', 'shift', 'name', 'keep_dims'],
'tf.nn.pool': ['input', 'window_shape', 'pooling_type', 'padding', 'dilation_rate', 'strides', 'name', 'data_format'],
@ -113,6 +121,7 @@ reorders = {
'tf.strings.reduce_join': ['inputs', 'axis', 'keep_dims', 'separator', 'name', 'reduction_indices'],
'tf.strings.substr': ['input', 'pos', 'len', 'name', 'unit'],
'tf.substr': ['input', 'pos', 'len', 'name', 'unit'],
'tf.test.assert_equal_graph_def': ['actual', 'expected', 'checkpoint_v2'],
'tf.transpose': ['a', 'perm', 'name', 'conjugate'],
'tf.tuple': ['tensors', 'name', 'control_inputs'],
'tf.while_loop': ['cond', 'body', 'loop_vars', 'shape_invariants', 'parallel_iterations', 'back_prop', 'swap_memory', 'name', 'maximum_iterations', 'return_same_structure']

View File

@ -177,10 +177,11 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Warnings that should be printed if corresponding functions are used.
self.function_warnings = {
"tf.reverse":
"ERROR: tf.reverse has had its argument semantics changed "
"tf.reverse": (
ast_edits.ERROR,
"tf.reverse has had its argument semantics changed "
"significantly. The converter cannot detect this reliably, so "
"you need to inspect this usage manually.\n",
"you need to inspect this usage manually.\n"),
}

View File

@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
import ast
import functools
import sys
import pasta
import six
@ -27,6 +29,9 @@ from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import renames_v2
from tensorflow.tools.compatibility import reorders_v2
# These pylint warnings are a mistake.
# pylint: disable=g-explicit-bool-comparison,g-bool-id-comparison
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"""List of maps that describe what changed in the API."""
@ -38,6 +43,15 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Only keyword args are handled, so make sure to also put any function in
# function_reorders to ensure that all args are made into keywords first.
self.function_keyword_renames = {
"tf.test.assert_equal_graph_def": {
"checkpoint_v2": None,
},
"tf.nn.embedding_lookup": {
"validate_indices": None,
},
"tf.image.sample_distorted_bounding_box": {
"seed2": None,
},
"tf.gradients": {
"colocate_gradients_with_ops": None,
},
@ -380,6 +394,21 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.nn.weighted_moments": {
"keep_dims": "keepdims"
},
"tf.nn.conv1d": {
"value": "input",
"use_cudnn_on_gpu": None,
},
"tf.nn.conv2d": {
"filter": "filters",
"use_cudnn_on_gpu": None,
},
"tf.nn.conv2d_backprop_filter": {
"use_cudnn_on_gpu": None,
},
"tf.nn.conv2d_backprop_input": {
"filter": "filters",
"use_cudnn_on_gpu": None,
},
}
# pylint: disable=line-too-long
@ -662,6 +691,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Mapping from function to the new name of the function
self.symbol_renames = renames_v2.renames
self.symbol_renames.update(self.manual_symbol_renames)
self.symbol_renames = {
name: new_name
for name, new_name in self.symbol_renames.items()
}
# Variables that should be changed to functions.
self.change_to_function = {}
@ -684,6 +717,11 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.nn.space_to_batch",
"tf.boolean_mask",
"tf.convert_to_tensor",
"tf.nn.conv1d",
"tf.nn.conv2d",
"tf.nn.conv2d_backprop_filter",
"tf.nn.conv2d_backprop_input",
"tf.nn.ctc_beam_search_decoder",
"tf.nn.moments",
"tf.nn.convolution",
"tf.nn.crelu",
@ -746,6 +784,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.nn.embedding_lookup_sparse",
"tf.nn.in_top_k",
"tf.nn.space_to_depth",
"tf.test.assert_equal_graph_def",
"tf.linalg.norm",
"tf.norm",
"tf.reverse_sequence",
@ -754,6 +793,9 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# keyword arguments. Add keyword arguments in rare case when they
# are not specified.
"tf.nn.softmax_cross_entropy_with_logits",
"tf.nn.fractional_avg_pool",
"tf.nn.fractional_max_pool",
"tf.image.sample_distorted_bounding_box",
}
# Functions that were reordered should be changed to the new keyword args
@ -761,43 +803,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# positional arguments yourself, this could do the wrong thing.
self.function_reorders = reorders_v2.reorders
# Specially handled functions (pasta version)
# Each transformer is a callable which will be called with the arguments
# transformer(parent, node, full_name, name, logs, errors)
# Where logs and errors are lists to which (line, col, msg) tuples can be
# appended, full_name is the FQN of the function called (or None if that is
# unknown), name is the name of the function called (or None is that is
# unknown). node is an ast.Call node representing this function call, and
# parent is its parent in the AST.
# The function may modify node (but not parent), and must return
# - none, if nothing was modified
# - node, if node was modified in place (make sure to use
# pasta.ast_utils.replace_child to swap out children, otherwise formatting
# may get messy)
# - a replacement for node, if the whole call node was replaced. The caller
# will take care of changing parent.
self.function_transformers = {
"*.make_initializable_iterator": self._iterator_transformer,
"*.make_one_shot_iterator": self._iterator_transformer,
"tf.nn.dropout": self._dropout_transformer,
"tf.batch_gather": self._batch_gather_transformer,
"tf.to_bfloat16": self._cast_transformer,
"tf.to_complex128": self._cast_transformer,
"tf.to_complex64": self._cast_transformer,
"tf.to_double": self._cast_transformer,
"tf.to_float": self._cast_transformer,
"tf.to_int32": self._cast_transformer,
"tf.to_int64": self._cast_transformer,
"tf.nn.softmax_cross_entropy_with_logits":
self._softmax_cross_entropy_with_logits_transformer,
"tf.image.resize_area": self._image_resize_transformer,
"tf.image.resize_bicubic": self._image_resize_transformer,
"tf.image.resize_bilinear": self._image_resize_transformer,
"tf.image.resize_nearest_neighbor": self._image_resize_transformer,
}
decay_function_comment = (
ast_edits.INFO,
"<function name> has been changed to return a callable instead "
"of a tensor when graph building, but its functionality remains "
"unchanged during eager execution (returns a callable like "
@ -806,64 +813,66 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
" be correct).\n"
)
# TODO(b/118888586): add default value change to update script.
default_loss_reduction_changed = (
"default value of loss_reduction has been changed to "
"SUM_OVER_BATCH_SIZE.\n"
)
assert_return_type_comment = (
"assert_* functions have been changed to return None, the "
ast_edits.INFO,
"<function name> has been changed to return None, the "
"data argument has been removed, and arguments have been reordered."
"\nThe calls have been converted to compat.v1 for safety (even though "
" they may already have been correct)."
)
assert_rank_comment = (
"assert_rank_* functions have been changed to return None, and"
ast_edits.INFO,
"<function name> has been changed to return None, and"
" the data and summarize arguments have been removed."
"\nThe calls have been converted to compat.v1 for safety (even though "
" they may already have been correct)."
)
tf_01s_like_no_optimize_comment = (
"tf.zeros_like and tf.ones_like no longer have the optimize "
"argument in TF 2.0 or after (also, `tensor' argument is renamed to "
"`input')."
"\nThe calls have been converted to compat.v1 for safety (even though "
" they may already have been correct)."
)
deprecate_partition_strategy_comment = (
"`partition_strategy` has been removed from `%s` "
" The 'div' strategy is used by default.")
initializers_no_dtype_comment = (
"tf.initializers and tf.keras.initializers no longer have the "
ast_edits.INFO,
"Initializers no longer have the "
"dtype argument in the constructor or partition_info argument in the "
"call method in TF 2.0 and after. The only API symbols are now "
"tf.keras.initializers.* or tf.initializers.*."
"\nThe calls have been converted to compat.v1 for safety (even though "
"they may already have been correct).")
uniform_unit_scaling_initializer_comment = (
"uniform_unit_scaling_initializer has been removed. Please use"
" tf.initializers.variance_scaling instead with distribution=uniform "
"to get equivalent behaviour.")
"__call__ method.\nThe calls have been converted to compat.v1 for"
"safety (even though they may already have been correct).")
metrics_comment = (
"tf.metrics have been converted to object oriented versions in"
ast_edits.INFO,
"tf.metrics have been replaced with object oriented versions in"
" TF 2.0 and after. The metric function calls have been converted to "
"compat.v1 for backward compatibility. Please update these calls to "
"the TF 2.0 versions.")
losses_comment = (
"tf.losses have been converted to object oriented versions in"
ast_edits.INFO,
"tf.losses have been replaced with object oriented versions in"
" TF 2.0 and after. The loss function calls have been converted to "
"compat.v1 for backward compatibility. Please update these calls to "
"the TF 2.0 versions.")
# This could be done with a _rename_if_arg_not_found_transformer
deprecate_partition_strategy_comment = (
ast_edits.WARNING,
"`partition_strategy` has been removed from <function name>. "
" The 'div' strategy will be used by default.")
# TODO(b/118888586): add default value change to update script.
default_loss_reduction_changed = (
ast_edits.WARNING,
"default value of loss_reduction has been changed to "
"SUM_OVER_BATCH_SIZE.\n"
)
# make change instead
uniform_unit_scaling_initializer_comment = (
ast_edits.ERROR,
"uniform_unit_scaling_initializer has been removed. Please use"
" tf.initializers.variance_scaling instead with distribution=uniform "
"to get equivalent behaviour.")
# Make change instead (issue warning about strip_...)
export_saved_model_renamed = (
ast_edits.ERROR,
"(Manual edit required) Please rename the method export_savedmodel() "
"to export_saved_model(). Two things to note:\n\t(1) The argument "
"strip_default_attributes has been removed. The function will always "
@ -947,13 +956,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
assert_rank_comment,
"tf.debugging.assert_rank_in":
assert_rank_comment,
"tf.device":
"tf.device no longer takes function as an argument. "
"'device_name_or_function' argument has been renamed to "
"'device_name'.",
"tf.flags":
"tf.flags": (
ast_edits.ERROR,
"tf.flags has been removed, please use the argparse or absl"
" module if you need command line parsing.",
" modules if you need command line parsing."),
"tf.train.exponential_decay":
decay_function_comment,
"tf.train.piecewise_constant_decay":
@ -988,55 +994,12 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
default_loss_reduction_changed,
"tf.estimator.BaselineRegressor":
default_loss_reduction_changed,
"tf.nn.conv1d":
"WARNING: use_cudnn_on_gpu argument has been removed and \"value\""
" was renamed to \"input\"",
"tf.nn.conv2d":
"WARNING: use_cudnn_on_gpu argument has been removed and "
"\"filter\" was renamed to \"filters\"",
"tf.nn.conv2d_backprop_filter":
"WARNING: use_cudnn_on_gpu argument has been removed",
"tf.nn.conv2d_backprop_input":
"WARNING: use_cudnn_on_gpu argument has been removed and "
"\"filter\" was renamed to \"filters\"",
"tf.nn.erosion2d":
"WARNING: <function name> now requires a data_format argument",
"tf.nn.nce_loss":
deprecate_partition_strategy_comment % "tf.nn.nce_loss",
deprecate_partition_strategy_comment,
"tf.nn.safe_embedding_lookup_sparse":
deprecate_partition_strategy_comment %
"tf.nn.safe_embedding_lookup_sparse",
deprecate_partition_strategy_comment,
"tf.nn.sampled_softmax_loss":
deprecate_partition_strategy_comment % "tf.nn.sampled_softmax_loss",
"tf.zeros_like":
tf_01s_like_no_optimize_comment,
"tf.ones_like":
tf_01s_like_no_optimize_comment,
"tf.nn.embedding_lookup":
"WARNING: validate_indices argument has been removed.",
"tf.while_loop":
"tf.while_loop no longer takes 'return_same_structure' argument. "
"'return_same_structure' now defaults to True. Also, 'name'"
"argument is now the last argument.",
"tf.image.sample_distorted_bounding_box":
"tf.image.sample_distorted_bounding_box no longer takes 'seed2' "
"argument.",
"tf.nn.ctc_beam_search_decoder":
"tf.nn.ctc_beam_search_decoder no longer takes 'merge_repeated' "
"argument. 'merge_repeated' now defaults to False.",
"tf.nn.fractional_avg_pool":
"tf.nn.fractional_avg_pool no longer takes 'seed2' and "
"'deterministic' arguments. Now it takes a single 'seed' arg. If "
"'seed' is zero, the execution is random and deterministic "
"otherwise",
"tf.nn.fractional_max_pool":
"tf.nn.fractional_max_pool no longer takes 'seed2' and "
"'deterministic' arguments. Now it takes a single 'seed' arg. If "
"'seed' is zero, the execution is random and deterministic "
"otherwise",
"tf.test.assert_equal_graph_def":
"tf.assert_equal_graph_def no longer takes 'checkpoint_v2' "
"argument. 'checkpoint_v2' now defaults to True.",
deprecate_partition_strategy_comment,
"tf.keras.initializers.Zeros":
initializers_no_dtype_comment,
"tf.keras.initializers.zeros":
@ -1211,227 +1174,489 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Warnings that are emitted only if a specific arg is found.
self.function_arg_warnings = {
"tf.nn.conv1d": {
("use_cudnn_on_gpu", 4): (
ast_edits.WARNING,
"use_cudnn_on_gpu has been removed, behavior is now equivalent"
"to setting it to True."),
},
"tf.nn.conv2d": {
("use_cudnn_on_gpu", 4): (
ast_edits.WARNING,
"use_cudnn_on_gpu has been removed, behavior is now equivalent"
"to setting it to True."),
},
"tf.nn.conv2d_backprop_filter": {
("use_cudnn_on_gpu", 5): (
ast_edits.WARNING,
"use_cudnn_on_gpu has been removed, behavior is now equivalent"
"to setting it to True."),
},
"tf.nn.conv2d_backprop_input": {
("use_cudnn_on_gpu", 5): (
ast_edits.WARNING,
"use_cudnn_on_gpu has been removed, behavior is now equivalent"
"to setting it to True."),
},
"tf.gradients": {
("colocate_gradients_with_ops", 4):
("colocate_gradients_with_ops", 4): (
ast_edits.INFO,
"tf.gradients no longer takes "
"'colocate_gradients_with_ops' argument, it behaves as if it "
"was set to True.",
"was set to True."),
},
"*.minimize": {
("colocate_gradients_with_ops", 5):
("colocate_gradients_with_ops", 5): (
ast_edits.INFO,
"Optimizer.minimize no longer takes "
"'colocate_gradients_with_ops' argument, it behaves as if it "
"was set to True.",
"was set to True."),
},
"*.compute_gradients": {
("colocate_gradients_with_ops", 4):
("colocate_gradients_with_ops", 4): (
ast_edits.INFO,
"Optimizer.compute_gradients no "
"longer takes 'colocate_gradients_with_ops' argument, it "
"behaves as if it was set to True.",
"behaves as if it was set to True."),
},
"tf.cond": {
("strict", 3):
("strict", 3): (
ast_edits.WARNING,
"tf.cond no longer takes 'strict' argument, it behaves as "
"if was set to True."
"if was set to True.")
},
}
self.symbol_renames = {
name: new_name
for name, new_name in self.symbol_renames.items()
# Specially handled functions
# Each transformer is a callable which will be called with the arguments
# transformer(parent, node, full_name, name, logs, errors)
# Where logs is a list to which (level, line, col, msg) tuples can be
# appended, full_name is the FQN of the function called (or None if that is
# unknown), name is the name of the function called (or None is that is
# unknown). node is an ast.Call node representing this function call, and
# parent is its parent in the AST.
# The function may modify node (but not parent), and must return
# - none, if nothing was modified
# - node, if node was modified in place (make sure to use
# pasta.ast_utils.replace_child to swap out children, otherwise formatting
# may get messy)
# - a replacement for node, if the whole call node was replaced. The caller
# will take care of changing parent.
self.function_transformers = {
"*.make_initializable_iterator": _iterator_transformer,
"*.make_one_shot_iterator": _iterator_transformer,
"tf.nn.dropout": _dropout_transformer,
"tf.batch_gather": _batch_gather_transformer,
"tf.to_bfloat16": _cast_transformer,
"tf.to_complex128": _cast_transformer,
"tf.to_complex64": _cast_transformer,
"tf.to_double": _cast_transformer,
"tf.to_float": _cast_transformer,
"tf.to_int32": _cast_transformer,
"tf.to_int64": _cast_transformer,
"tf.nn.softmax_cross_entropy_with_logits":
_softmax_cross_entropy_with_logits_transformer,
"tf.image.resize_area": _image_resize_transformer,
"tf.image.resize_bicubic": _image_resize_transformer,
"tf.image.resize_bilinear": _image_resize_transformer,
"tf.image.resize_nearest_neighbor": _image_resize_transformer,
"tf.nn.fractional_avg_pool": _pool_seed_transformer,
"tf.nn.fractional_max_pool": _pool_seed_transformer,
"tf.device": functools.partial(
_rename_if_arg_found_transformer, arg_name="device_name",
arg_ok_predicate=_is_ast_str, remove_if_ok=False,
message="tf.device no longer takes functions as an argument. "
"We could not determine that the argument value is a string, so "
"the call was converted to compat.v1."),
"tf.zeros_like": functools.partial(
_rename_if_arg_found_transformer, arg_name="optimize",
arg_ok_predicate=_is_ast_true, remove_if_ok=True,
message="tf.zeros_like no longer takes an optimize argument, and "
"behaves as if optimize=True. This call site specifies something "
"other than optimize=True, so it was converted to compat.v1."),
"tf.ones_like": functools.partial(
_rename_if_arg_found_transformer, arg_name="optimize",
arg_ok_predicate=_is_ast_true, remove_if_ok=True,
message="tf.ones_like no longer takes an optimize argument, and "
"behaves as if optimize=True. This call site specifies something "
"other than optimize=True, so it was converted to compat.v1."),
"tf.while_loop": functools.partial(
_rename_if_arg_found_transformer,
arg_name="return_same_structure",
arg_ok_predicate=_is_ast_true, remove_if_ok=True,
message="tf.while_loop no longer takes 'return_same_structure' "
"argument and behaves as if return_same_structure=True. This call "
"site specifies something other than return_same_structure=True, "
"so it was converted to compat.v1."),
"tf.nn.ctc_beam_search_decoder": functools.partial(
_rename_if_arg_found_transformer,
arg_name="merge_repeated",
arg_ok_predicate=_is_ast_false, remove_if_ok=True,
message="tf.nn.ctc_beam_search_decoder no longer takes the "
"'merge_repeated' argument and behaves as if merge_repeated=False. "
"This call site specifies something other than "
"merge_repeated=False, so it was converted to compat.v1."),
"tf.nn.erosion2d": functools.partial(
_add_argument_transformer,
arg_name="data_format",
arg_value_ast=ast.Str("NHWC")),
}
@staticmethod
def _iterator_transformer(parent, node, full_name, name, logs):
# First, check that node.func.value is not already something we like
# (tf.compat.v1.data), or something which is handled in the rename
# (tf.data). This transformer only handles the method call to function call
# conversion.
if full_name and (full_name.startswith("tf.compat.v1.data") or
full_name.startswith("tf.data")):
return
# This should never happen, since we're only called for Attribute nodes.
if not isinstance(node.func, ast.Attribute):
return
def _is_ast_str(node):
"""Determine whether this node represents a string."""
allowed_types = [ast.Str]
if hasattr(ast, "Bytes"):
allowed_types += [ast.Bytes]
if hasattr(ast, "JoinedStr"):
allowed_types += [ast.JoinedStr]
if hasattr(ast, "FormattedValue"):
allowed_types += [ast.FormattedValue]
return isinstance(node, allowed_types)
# Transform from x.f(y) to tf.compat.v1.data.f(x, y)
# Fortunately, node.func.value should already have valid position info
node.args = [node.func.value] + node.args
node.func.value = ast_edits.full_name_node("tf.compat.v1.data")
logs.append((ast_edits.WARNING, node.lineno, node.col_offset,
"Changing dataset.%s() to tf.compat.v1.data.%s(dataset). "
"Please check this transformation.\n" % (name, name)))
def _is_ast_true(node):
if hasattr(ast, "NameConstant"):
return isinstance(node, ast.NameConstant) and node.value is True
else:
return isinstance(node, ast.Name) and node.id == "True"
return node
@staticmethod
def _dropout_transformer(parent, node, full_name, name, logs):
def _replace_keep_prob_node(parent, old_value):
"""Replaces old_value with 1-(old_value)."""
one = ast.Num(n=1)
one.lineno = 0
one.col_offset = 0
new_value = ast.BinOp(left=one, op=ast.Sub(),
right=old_value)
# This copies the prefix and suffix on old_value to new_value.
pasta.ast_utils.replace_child(parent, old_value, new_value)
ast.copy_location(new_value, old_value)
# Put parentheses around keep_prob.value (and remove the old prefix/
# suffix, they should only be around new_value).
pasta.base.formatting.set(old_value, "prefix", "(")
pasta.base.formatting.set(old_value, "suffix", ")")
def _is_ast_false(node):
if hasattr(ast, "NameConstant"):
return isinstance(node, ast.NameConstant) and node.value is False
else:
return isinstance(node, ast.Name) and node.id == "False"
# Check if we have a keep_prob keyword arg
for keep_prob in node.keywords:
if keep_prob.arg == "keep_prob":
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changing keep_prob arg of tf.nn.dropout to rate\n"))
keep_prob.arg = "rate"
_replace_keep_prob_node(keep_prob, keep_prob.value)
return node
# Maybe it was a positional arg
if len(node.args) < 2:
logs.append((ast_edits.ERROR, node.lineno, node.col_offset,
"tf.nn.dropout called without arguments, so "
"automatic fix was disabled. tf.nn.dropout has changed "
"the semantics of the second argument."))
# Lots of unused arguments below, since these are called in a standard manner.
# pylint: disable=unused-argument
def _rename_if_arg_found_transformer(parent, node, full_name, name, logs,
arg_name=None,
arg_ok_predicate=None,
remove_if_ok=False,
message=None):
"""Replaces the given call with tf.compat.v1 if the given arg is found.
This requires the function to be called with all named args, so for using
this transformer, the function should also be added to renames.
If the arg is not found, the call site is left alone.
If the arg is found, and if arg_ok_predicate is given, it is called with
the ast Expression representing the argument value found. If it returns
True, the function is left alone.
If the arg is found, arg_ok_predicate is not None and returns ok, and
remove_if_ok is True, the argument is removed from the call.
Otherwise, `compat.v1` is inserted between tf and the function name.
Args:
parent: Parent of node.
node: ast.Call node to maybe modify.
full_name: full name of function to modify
name: name of function to modify
logs: list of logs to append to
arg_name: name of the argument to look for
arg_ok_predicate: predicate callable with the ast of the argument value,
returns whether the argument value is allowed.
remove_if_ok: remove the argument if present and ok as determined by
arg_ok_predicate.
message: message to print if a non-ok arg is found (and hence, the function
is renamed to its compat.v1 version).
Returns:
node, if it was modified, else None.
"""
# Check whether arg is there.
arg_present, arg_value = ast_edits.get_arg_value(node, arg_name)
if not arg_present:
return
# Check whether arg is problematic (and if not, maybe remove it).
if arg_ok_predicate and arg_ok_predicate(arg_value):
if remove_if_ok:
for i, kw in enumerate(node.keywords):
if kw.arg == arg_name:
node.keywords.pop(i)
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Removed argument %s for function %s" % (
arg_name, full_name or name)))
break
return node
else:
_replace_keep_prob_node(node, node.args[1])
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changing keep_prob arg of tf.nn.dropout to rate, and "
"recomputing value.\n"))
return
# All conditions met, insert v1 and log what we did.
# We must have a full name, so the func is an attribute.
new_name = full_name.replace("tf.", "tf.compat.v1.", 1)
node.func = ast_edits.full_name_node(new_name)
logs.append((
ast_edits.INFO, node.lineno, node.col_offset,
"Renaming %s to %s because argument %s is present. %s" %
(full_name, new_name, arg_name, message if message is not None else "")
))
return node
def _add_argument_transformer(parent, node, full_name, name, logs,
arg_name, arg_value_ast):
"""Adds an argument (as a final kwarg arg_name=arg_value_ast)."""
node.keywords.append(ast.keyword(arg=arg_name, value=arg_value_ast))
logs.add((
ast_edits.INFO, node.lineno, node.col_offset,
"Adding argument '%s' to call to %s." % (pasta.dump(node.keywords[-1],
full_name or name))
))
return node
def _iterator_transformer(parent, node, full_name, name, logs):
"""Transform iterator methods to compat function calls."""
# First, check that node.func.value is not already something we like
# (tf.compat.v1.data), or something which is handled in the rename
# (tf.data). This transformer only handles the method call to function call
# conversion.
if full_name and (full_name.startswith("tf.compat.v1.data") or
full_name.startswith("tf.data")):
return
# This should never happen, since we're only called for Attribute nodes.
if not isinstance(node.func, ast.Attribute):
return
# Transform from x.f(y) to tf.compat.v1.data.f(x, y)
# Fortunately, node.func.value should already have valid position info
node.args = [node.func.value] + node.args
node.func.value = ast_edits.full_name_node("tf.compat.v1.data")
logs.append((ast_edits.WARNING, node.lineno, node.col_offset,
"Changing dataset.%s() to tf.compat.v1.data.%s(dataset). "
"Please check this transformation.\n" % (name, name)))
return node
def _dropout_transformer(parent, node, full_name, name, logs):
"""Replace keep_prob with 1-rate."""
def _replace_keep_prob_node(parent, old_value):
"""Replaces old_value with 1-(old_value)."""
one = ast.Num(n=1)
one.lineno = 0
one.col_offset = 0
new_value = ast.BinOp(left=one, op=ast.Sub(),
right=old_value)
# This copies the prefix and suffix on old_value to new_value.
pasta.ast_utils.replace_child(parent, old_value, new_value)
ast.copy_location(new_value, old_value)
# Put parentheses around keep_prob.value (and remove the old prefix/
# suffix, they should only be around new_value).
pasta.base.formatting.set(old_value, "prefix", "(")
pasta.base.formatting.set(old_value, "suffix", ")")
# Check if we have a keep_prob keyword arg
for keep_prob in node.keywords:
if keep_prob.arg == "keep_prob":
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changing keep_prob arg of tf.nn.dropout to rate\n"))
keep_prob.arg = "rate"
_replace_keep_prob_node(keep_prob, keep_prob.value)
return node
@staticmethod
def _cast_transformer(parent, node, full_name, name, logs):
"""Transforms to_int and to_float to cast(..., dtype=...)."""
# Find out the dtype to cast to from the function name
dtype_str = name[3:]
# Special cases where the full dtype is not given
if dtype_str == "float":
dtype_str = "float32"
elif dtype_str == "double":
dtype_str = "float64"
new_arg = ast.keyword(arg="dtype",
value=ast.Attribute(value=ast.Name(id="tf",
ctx=ast.Load()),
attr=dtype_str, ctx=ast.Load()))
# Ensures a valid transformation when a positional name arg is given
if len(node.args) == 2:
name_arg = ast.keyword(arg="name",
value=node.args[-1])
node.args = node.args[:-1]
node.keywords.append(name_arg)
# Python3 ast requires the args for the Attribute, but codegen will mess up
# the arg order if we just set them to 0.
new_arg.value.lineno = node.lineno
new_arg.value.col_offset = node.col_offset+100
node.keywords.append(new_arg)
if isinstance(node.func, ast.Attribute):
node.func.attr = "cast"
else:
assert isinstance(node.func, ast.Name)
node.func.id = "cast"
# Maybe it was a positional arg
if len(node.args) < 2:
logs.append((ast_edits.ERROR, node.lineno, node.col_offset,
"tf.nn.dropout called without arguments, so "
"automatic fix was disabled. tf.nn.dropout has changed "
"the semantics of the second argument."))
else:
_replace_keep_prob_node(node, node.args[1])
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changed %s call to tf.cast(..., dtype=tf.%s)." % (full_name,
dtype_str)))
"Changing keep_prob arg of tf.nn.dropout to rate, and "
"recomputing value.\n"))
return node
@staticmethod
def _softmax_cross_entropy_with_logits_transformer(
parent, node, full_name, name, logs):
def _wrap_label(parent, old_value):
"""Wrap labels with tf.stop_gradient."""
if six.PY3:
new_value = ast.Call(
ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
[old_value], [])
else:
new_value = ast.Call(
ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
[old_value], [], None, None)
# This copies the prefix and suffix on old_value to new_value.
pasta.ast_utils.replace_child(parent, old_value, new_value)
ast.copy_location(new_value, old_value)
def _cast_transformer(parent, node, full_name, name, logs):
"""Transforms to_int and to_float to cast(..., dtype=...)."""
# Check if we have a labels keyword arg
for karg in node.keywords:
if karg.arg == "labels":
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changing labels arg of "
"tf.nn.softmax_cross_entropy_with_logits to "
"tf.stop_gradient(labels). Please check this "
"transformation.\n"))
_wrap_label(karg, karg.value)
return node
return node
# Find out the dtype to cast to from the function name
dtype_str = name[3:]
# Special cases where the full dtype is not given
if dtype_str == "float":
dtype_str = "float32"
elif dtype_str == "double":
dtype_str = "float64"
new_arg = ast.keyword(arg="dtype",
value=ast.Attribute(value=ast.Name(id="tf",
ctx=ast.Load()),
attr=dtype_str, ctx=ast.Load()))
# Ensures a valid transformation when a positional name arg is given
if len(node.args) == 2:
name_arg = ast.keyword(arg="name",
value=node.args[-1])
node.args = node.args[:-1]
node.keywords.append(name_arg)
@staticmethod
def _batch_gather_transformer(parent, node, full_name, name, logs):
# Check if the call already has a batch_dims argument
if any([kw.arg == "batch_dims" for kw in node.keywords]):
# Python3 ast requires the args for the Attribute, but codegen will mess up
# the arg order if we just set them to 0.
new_arg.value.lineno = node.lineno
new_arg.value.col_offset = node.col_offset+100
node.keywords.append(new_arg)
if isinstance(node.func, ast.Attribute):
node.func.attr = "cast"
else:
assert isinstance(node.func, ast.Name)
node.func.id = "cast"
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changed %s call to tf.cast(..., dtype=tf.%s)." % (full_name,
dtype_str)))
return node
def _softmax_cross_entropy_with_logits_transformer(
parent, node, full_name, name, logs):
"""Wrap labels argument with stop_gradients."""
def _wrap_label(parent, old_value):
"""Wrap labels with tf.stop_gradient."""
if six.PY3:
new_value = ast.Call(
ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
[old_value], [])
else:
new_value = ast.Call(
ast.Name(id="tf.stop_gradient", ctx=ast.Load()),
[old_value], [], None, None)
# This copies the prefix and suffix on old_value to new_value.
pasta.ast_utils.replace_child(parent, old_value, new_value)
ast.copy_location(new_value, old_value)
# Check if we have a labels keyword arg
for karg in node.keywords:
if karg.arg == "labels":
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"tf.batch_gather already has batch_dims argument. Neat."))
return None
"Changing labels arg of "
"tf.nn.softmax_cross_entropy_with_logits to "
"tf.stop_gradient(labels). Please check this "
"transformation.\n"))
_wrap_label(karg, karg.value)
return node
return node
minus_one = ast.Num(n=-1)
minus_one.lineno = 0
minus_one.col_offset = 0
new_arg = ast.keyword("batch_dims", minus_one)
node.keywords.append(new_arg)
def _batch_gather_transformer(parent, node, full_name, name, logs):
"""Add batch_dims argument for gather calls."""
# Check if the call already has a batch_dims argument
if any([kw.arg == "batch_dims" for kw in node.keywords]):
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Added keyword argument batch_dims=-1 to tf.batch_gather."))
return node
"tf.batch_gather already has batch_dims argument. Neat."))
return None
@staticmethod
def _image_resize_transformer(parent, node, full_name, name, logs):
"""Transforms image.resize_* to image.resize(..., method=*, ...)."""
minus_one = ast.Num(n=-1)
minus_one.lineno = 0
minus_one.col_offset = 0
new_arg = ast.keyword("batch_dims", minus_one)
node.keywords.append(new_arg)
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Added keyword argument batch_dims=-1 to tf.batch_gather."))
return node
resize_method = name[7:].upper()
new_arg = ast.keyword(arg="method",
value=ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="tf", ctx=ast.Load()),
attr="image", ctx=ast.Load()),
attr="ResizeMethod", ctx=ast.Load()),
attr=resize_method, ctx=ast.Load()))
# Ensures a valid transformation when a positional name arg is given
if len(node.args) == 4:
pos_arg = ast.keyword(arg="preserve_aspect_ratio",
value=node.args[-1])
node.args = node.args[:-1]
node.keywords.append(pos_arg)
if len(node.args) == 3:
pos_arg = ast.keyword(arg="align_corners",
value=node.args[-1])
node.args = node.args[:-1]
node.keywords.append(pos_arg)
def _image_resize_transformer(parent, node, full_name, name, logs):
"""Transforms image.resize_* to image.resize(..., method=*, ...)."""
resize_method = name[7:].upper()
new_arg = ast.keyword(arg="method",
value=ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="tf", ctx=ast.Load()),
attr="image", ctx=ast.Load()),
attr="ResizeMethod", ctx=ast.Load()),
attr=resize_method, ctx=ast.Load()))
# Python3 ast requires the args for the Attribute, but codegen will mess up
# the arg order if we just set them to 0.
new_arg.value.lineno = node.lineno
new_arg.value.col_offset = node.col_offset+100
# Ensures a valid transformation when a positional name arg is given
if len(node.args) == 4:
pos_arg = ast.keyword(arg="preserve_aspect_ratio",
value=node.args[-1])
node.args = node.args[:-1]
node.keywords.append(pos_arg)
if len(node.args) == 3:
pos_arg = ast.keyword(arg="align_corners",
value=node.args[-1])
node.args = node.args[:-1]
node.keywords.append(pos_arg)
node.keywords.append(new_arg)
if isinstance(node.func, ast.Attribute):
node.func.attr = "resize"
# Python3 ast requires the args for the Attribute, but codegen will mess up
# the arg order if we just set them to 0.
new_arg.value.lineno = node.lineno
new_arg.value.col_offset = node.col_offset+100
node.keywords.append(new_arg)
if isinstance(node.func, ast.Attribute):
node.func.attr = "resize"
else:
assert isinstance(node.func, ast.Name)
node.func.id = "resize"
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changed %s call to tf.image.resize(..., "
"method=tf.image.ResizeMethod.%s)." % (full_name,
resize_method)))
return node
def _pool_seed_transformer(parent, node, full_name, name, logs):
"""Removes seed2 and deterministic, and adds non-zero seed if needed."""
# This requires that this function uses all kwargs (add to renames!).
seed_arg = None
deterministic = False
modified = False
new_keywords = []
for kw in node.keywords:
if sys.version_info[:2] >= (3, 5) and isinstance(kw, ast.Starred):
pass
elif kw.arg == "seed":
seed_arg = kw
elif kw.arg == "seed2" or kw.arg == "deterministic":
lineno = getattr(kw, "lineno", node.lineno)
col_offset = getattr(kw, "col_offset", node.col_offset)
logs.append((ast_edits.INFO, lineno, col_offset,
"Removed argument %s for function %s" % (
kw.arg, full_name or name)))
if kw.arg == "deterministic":
if not _is_ast_false(kw.value):
deterministic = True
modified = True
continue
new_keywords.append(kw)
if deterministic:
if seed_arg is None:
new_keywords.append(ast.keyword(arg="seed", value=ast.Num(42)))
logs.add((
ast_edits.INFO, node.lineno, node.col_offset,
"Adding seed=42 to call to %s since determinism was requested" % (
full_name or name)
))
else:
assert isinstance(node.func, ast.Name)
node.func.id = "resize"
logs.add((
ast_edits.WARNING, node.lineno, node.col_offset,
"The deterministic argument is deprecated for %s, pass a "
"non-zero seed for determinism. The deterministic argument is "
"present, possibly not False, and the seed is already set. The "
"converter cannot determine whether it is nonzero, please check."
))
logs.append((ast_edits.INFO, node.lineno, node.col_offset,
"Changed %s call to tf.image.resize(..., "
"method=tf.image.ResizeMethod.%s)." % (full_name,
resize_method)))
if modified:
node.keywords = new_keywords
return node
else:
return

View File

@ -112,7 +112,7 @@ Simple usage:
report = ("TensorFlow 2.0 Upgrade Script\n"
"-----------------------------\n"
"Converted %d files\n" % files_processed +
"Detected %d errors that require attention" % num_errors + "\n" +
"Detected %d issues that require attention" % num_errors + "\n" +
"-" * 80 + "\n") + "".join(report)
with open(report_filename, "w") as report_file:
report_file.write(report)

View File

@ -376,19 +376,13 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"tf.train.inverse_time_decay", "tf.train.cosine_decay",
"tf.train.cosine_decay_restarts",
"tf.train.linear_cosine_decay",
"tf.train.noisy_linear_cosine_decay"]:
"tf.train.noisy_linear_cosine_decay",
"tf.train.piecewise_constant_decay",
]:
text = "%s(a, b)\n" % decay
_, report, errors, _ = self._upgrade(text)
self.assertIn("%s requires manual check" % decay, errors[0])
self.assertIn("%s has been changed" % decay, report)
def testPiecewiseDecay(self):
text = "tf.train.piecewise_constant_decay(a, b)\n"
_, report, errors, _ = self._upgrade(text)
self.assertIn("tf.train.piecewise_constant_decay requires manual check",
errors[0])
self.assertIn("tf.train.piecewise_constant_decay has been changed", report)
_, report, unused_errors, _ = self._upgrade(text)
self.assertIn("%s has been changed to return a callable" % decay, report)
def testMetrics(self):
metrics = [
@ -427,16 +421,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"true_positives_at_thresholds",
]
for m in metrics:
ns = "tf.metrics." + m
text = ns + "(a, b)"
_, report, errors, new_text = self._upgrade(text)
text = "tf.metrics." + m + "(a, b)"
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual("tf.compat.v1.metrics." + m + "(a, b)", new_text)
self.assertIn("%s requires manual check" % ns, errors[0])
self.assertIn(
"tf.metrics have been converted to object oriented"
" versions in TF 2.0 and after. The metric function calls have been "
"converted to compat.v1 for backward compatibility. Please update "
"these calls to the TF 2.0 versions.", report)
"tf.metrics have been replaced with object oriented versions", report)
def testLosses(self):
losses = [
@ -458,16 +447,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"sparse_softmax_cross_entropy",
]
for l in losses:
ns = "tf.losses." + l
text = ns + "(a, b)"
_, report, errors, new_text = self._upgrade(text)
text = "tf.losses." + l + "(a, b)"
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual("tf.compat.v1.losses." + l + "(a, b)", new_text)
self.assertIn("%s requires manual check" % ns, errors[0])
self.assertIn(
"tf.losses have been converted to object oriented"
" versions in TF 2.0 and after. The loss function calls have been "
"converted to compat.v1 for backward compatibility. Please update "
"these calls to the TF 2.0 versions.", report)
"tf.losses have been replaced with object oriented versions", report)
def testEstimatorLossReductionChange(self):
classes = [
@ -605,10 +589,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
self.assertEqual(errors, [])
text = "tf.gradients(a, colocate_gradients_with_ops=False)\n"
_, unused_report, errors, new_text = self._upgrade(text)
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual("tf.gradients(a)\n", new_text)
self.assertIn("tf.gradients", errors[0])
self.assertIn("requires manual check", errors[0])
self.assertIn("tf.gradients no longer takes", report)
def testColocateGradientsWithOpsMinimize(self):
text = "optimizer.minimize(a, foo=False)\n"
@ -617,10 +600,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
self.assertEqual(errors, [])
text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n"
_, unused_report, errors, new_text = self._upgrade(text)
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual("optimizer.minimize(a)\n", new_text)
self.assertIn("requires manual check", errors[0])
self.assertIn("minimize", errors[0])
self.assertIn("Optimizer.minimize no longer takes", report)
def testColocateGradientsWithOpsComputeGradients(self):
text = "optimizer.compute_gradients(a, foo=False)\n"
@ -629,10 +611,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
self.assertEqual(errors, [])
text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n"
_, unused_report, errors, new_text = self._upgrade(text)
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual("optimizer.compute_gradients(a)\n", new_text)
self.assertIn("requires manual check", errors[0])
self.assertIn("compute_gradients", errors[0])
self.assertIn("Optimizer.compute_gradients no longer takes", report)
def testExportSavedModelRename(self):
text = "self.est.export_savedmodel(path)"
@ -831,7 +812,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"validate_indices, max_norm)")
expected_text = ("tf.nn.embedding_lookup(params=params, ids=ids, "
"partition_strategy=partition_strategy, name=name, "
"validate_indices=validate_indices, max_norm=max_norm)")
"max_norm=max_norm)")
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
@ -1029,29 +1010,43 @@ def _log_prob(self, x):
"assert_scalar"]:
text = "tf.%s(a)" % name
expected_text = "tf.compat.v1.%s(a)" % name
_, unused_report, errors, new_text = self._upgrade(text)
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
self.assertIn("assert_* functions", errors[0])
self.assertIn("%s has been" % name, report)
text = "tf.debugging.%s(a)" % name
expected_text = "tf.compat.v1.debugging.%s(a)" % name
_, unused_report, errors, new_text = self._upgrade(text)
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
self.assertIn("assert_* functions", errors[0])
self.assertIn("%s has been" % name, report)
def testAssertRankStatements(self):
for name in ["assert_rank", "assert_rank_at_least", "assert_rank_in"]:
text = "tf.%s(a)" % name
expected_text = "tf.compat.v1.%s(a)" % name
_, unused_report, errors, new_text = self._upgrade(text)
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
self.assertIn("assert_rank_* functions", errors[0])
self.assertIn("%s has been" % name, report)
text = "tf.debugging.%s(a)" % name
expected_text = "tf.compat.v1.debugging.%s(a)" % name
_, unused_report, errors, new_text = self._upgrade(text)
_, report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
self.assertIn("assert_rank_* functions", errors[0])
self.assertIn("%s has been" % name, report)
def test_assert_equal_graph_def(self):
text = "tf.test.assert_equal_graph_def(a, b, checkpoint_v2=x)"
expected = "tf.test.assert_equal_graph_def(actual=a, expected=b)"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
def test_sample_distorted_bounding_box(self):
# pylint: disable=line-too-long
text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)"
expected = "tf.image.sample_distorted_bounding_box(image_size=a, bounding_boxes=b, seed=c, min_object_covered=e, aspect_ratio_range=f, area_range=g, max_attempts=h, use_image_if_no_bounding_boxes=i, name=j)"
# pylint: enable=line-too-long
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
class TestUpgradeFiles(test_util.TensorFlowTestCase):