From 2ac6545b3a5c2ed2b7d6f4e554492469e97555e0 Mon Sep 17 00:00:00 2001 From: Martin Wicke Date: Mon, 28 Jan 2019 19:52:50 -0800 Subject: [PATCH] Clean up conversion warning language, and use the ability to distinguish info, warning, and errors better. Also improves the conversion of some symbols (changes a warning to a conversion for more cases that is now possible). PiperOrigin-RevId: 231334490 --- tensorflow/tools/compatibility/ast_edits.py | 71 +- .../tools/compatibility/ast_edits_test.py | 2 +- tensorflow/tools/compatibility/reorders_v2.py | 9 + tensorflow/tools/compatibility/tf_upgrade.py | 7 +- .../tools/compatibility/tf_upgrade_v2.py | 819 +++++++++++------- .../tools/compatibility/tf_upgrade_v2_main.py | 2 +- .../tools/compatibility/tf_upgrade_v2_test.py | 85 +- 7 files changed, 632 insertions(+), 363 deletions(-) diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py index eabb0be4e6c..572bf384e4c 100644 --- a/tensorflow/tools/compatibility/ast_edits.py +++ b/tensorflow/tools/compatibility/ast_edits.py @@ -62,6 +62,45 @@ def full_name_node(name, ctx=ast.Load()): return node +def get_arg_value(node, arg_name, arg_pos=None): + """Get the value of an argument from a ast.Call node. + + This function goes through the positional and keyword arguments to check + whether a given argument was used, and if so, returns its value (the node + representing its value). + + This cannot introspect *args or **args, but it safely handles *args in + Python3.5+. + + Args: + node: The ast.Call node to extract arg values from. + arg_name: The name of the argument to extract. + arg_pos: The position of the argument (in case it's passed as a positional + argument). + + Returns: + A tuple (arg_present, arg_value) containing a boolean indicating whether + the argument is present, and its value in case it is. + """ + # Check keyword args + if arg_name is not None: + for kw in node.keywords: + if kw.arg == arg_name: + return (True, kw.value) + + # Check positional args + if arg_pos is not None: + idx = 0 + for arg in node.args: + if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred): + continue # Can't parse Starred + if idx == arg_pos: + return (True, arg) + idx += 1 + + return (False, None) + + class APIChangeSpec(object): """This class defines the transformations that need to happen. @@ -195,11 +234,10 @@ class _PastaEditVisitor(ast.NodeVisitor): """Adds an error to be printed about full_name at node.""" function_warnings = self._api_change_spec.function_warnings if full_name in function_warnings: - warning_message = function_warnings[full_name] - warning_message = warning_message.replace("", full_name) - self.add_log(WARNING, node.lineno, node.col_offset, - "%s requires manual check. %s" % (full_name, - warning_message)) + level, message = function_warnings[full_name] + message = message.replace("", full_name) + self.add_log(level, node.lineno, node.col_offset, + "%s requires manual check. %s" % (full_name, message)) return True else: return False @@ -232,13 +270,13 @@ class _PastaEditVisitor(ast.NodeVisitor): arg_warnings = self._get_applicable_dict("function_arg_warnings", full_name, name) - used_args = [kw.arg for kw in node.keywords] - for (kwarg, arg), warning in arg_warnings.items(): - if kwarg in used_args or len(node.args) > arg: + for (kwarg, arg), (level, warning) in arg_warnings.items(): + present, _ = get_arg_value(node, kwarg, arg) + if present: warned = True warning_message = warning.replace("", full_name or name) - self.add_log(WARNING, node.lineno, node.col_offset, - "%s called with %s argument requires manual check: %s." % + self.add_log(level, node.lineno, node.col_offset, + "%s called with %s argument requires manual check: %s" % (full_name or name, kwarg, warning_message)) return warned @@ -280,12 +318,14 @@ class _PastaEditVisitor(ast.NodeVisitor): if full_name in function_reorders: reordered = function_reorders[full_name] new_keywords = [] - for idx, arg in enumerate(node.args): + idx = 0 + for arg in node.args: if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred): continue # Can't move Starred to keywords keyword_arg = reordered[idx] keyword = ast.keyword(arg=keyword_arg, value=arg) new_keywords.append(keyword) + idx += 1 if new_keywords: self.add_log(INFO, node.lineno, node.col_offset, @@ -374,11 +414,10 @@ class _PastaEditVisitor(ast.NodeVisitor): logs = [] new_node = transformer(parent, node, full_name, name, logs) self.add_logs(logs) - if new_node: - if new_node is not node: - pasta.ast_utils.replace_child(parent, node, new_node) - node = new_node - self._stack[-1] = node + if new_node and new_node is not node: + pasta.ast_utils.replace_child(parent, node, new_node) + node = new_node + self._stack[-1] = node self.generic_visit(node) diff --git a/tensorflow/tools/compatibility/ast_edits_test.py b/tensorflow/tools/compatibility/ast_edits_test.py index 70494791529..366ea2cb72f 100644 --- a/tensorflow/tools/compatibility/ast_edits_test.py +++ b/tensorflow/tools/compatibility/ast_edits_test.py @@ -417,7 +417,7 @@ class TestAstEdits(test_util.TensorFlowTestCase): def __init__(self): NoUpdateSpec.__init__(self) - self.function_warnings = {"*.foo": "not good"} + self.function_warnings = {"*.foo": (ast_edits.WARNING, "not good")} texts = ["object.foo()", "get_object().foo()", "get_object().foo()", "object.foo().bar()"] diff --git a/tensorflow/tools/compatibility/reorders_v2.py b/tensorflow/tools/compatibility/reorders_v2.py index f9b0e3f9d8e..01556b1225d 100644 --- a/tensorflow/tools/compatibility/reorders_v2.py +++ b/tensorflow/tools/compatibility/reorders_v2.py @@ -37,6 +37,7 @@ reorders = { 'tf.decode_csv': ['records', 'record_defaults', 'field_delim', 'use_quote_delim', 'name', 'na_value', 'select_cols'], 'tf.depth_to_space': ['input', 'block_size', 'name', 'data_format'], 'tf.feature_column.categorical_column_with_vocabulary_file': ['key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'default_value', 'dtype'], + 'tf.image.sample_distorted_bounding_box': ['image_size', 'bounding_boxes', 'seed', 'seed2', 'min_object_covered', 'aspect_ratio_range', 'area_range', 'max_attempts', 'use_image_if_no_bounding_boxes', 'name'], 'tf.io.decode_csv': ['records', 'record_defaults', 'field_delim', 'use_quote_delim', 'name', 'na_value', 'select_cols'], 'tf.io.parse_example': ['serialized', 'features', 'name', 'example_names'], 'tf.io.parse_single_example': ['serialized', 'features', 'name', 'example_names'], @@ -56,12 +57,19 @@ reorders = { 'tf.math.reduce_prod': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'], 'tf.math.reduce_sum': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'], 'tf.multinomial': ['logits', 'num_samples', 'seed', 'name', 'output_dtype'], + 'tf.nn.conv1d': ['value', 'filters', 'stride', 'padding', 'use_cudnn_on_gpu', 'data_format', 'name'], + 'tf.nn.conv2d': ['input', 'filter', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'], + 'tf.nn.conv2d_backprop_filter': ['input', 'filter_sizes', 'out_backprop', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'], + 'tf.nn.conv2d_backprop_input': ['input_sizes', 'filter', 'out_backprop', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'], 'tf.nn.convolution': ['input', 'filter', 'padding', 'strides', 'dilation_rate', 'name', 'data_format'], 'tf.nn.crelu': ['features', 'name', 'axis'], + 'tf.nn.ctc_beam_search_decoder': ['inputs', 'sequence_length', 'beam_width', 'top_paths', 'merge_repeated'], 'tf.nn.depth_to_space': ['input', 'block_size', 'name', 'data_format'], 'tf.nn.depthwise_conv2d': ['input', 'filter', 'strides', 'padding', 'rate', 'name', 'data_format'], 'tf.nn.embedding_lookup': ['params', 'ids', 'partition_strategy', 'name', 'validate_indices', 'max_norm'], 'tf.nn.embedding_lookup_sparse': ['params', 'sp_ids', 'sp_weights', 'partition_strategy', 'name', 'combiner', 'max_norm'], + 'tf.nn.fractional_avg_pool': ['value', 'pooling_ratio', 'pseudo_random', 'overlapping', 'deterministic', 'seed', 'seed2', 'name'], + 'tf.nn.fractional_max_pool': ['value', 'pooling_ratio', 'pseudo_random', 'overlapping', 'deterministic', 'seed', 'seed2', 'name'], 'tf.nn.in_top_k': ['predictions', 'targets', 'k', 'name'], 'tf.nn.moments': ['x', 'axes', 'shift', 'name', 'keep_dims'], 'tf.nn.pool': ['input', 'window_shape', 'pooling_type', 'padding', 'dilation_rate', 'strides', 'name', 'data_format'], @@ -113,6 +121,7 @@ reorders = { 'tf.strings.reduce_join': ['inputs', 'axis', 'keep_dims', 'separator', 'name', 'reduction_indices'], 'tf.strings.substr': ['input', 'pos', 'len', 'name', 'unit'], 'tf.substr': ['input', 'pos', 'len', 'name', 'unit'], + 'tf.test.assert_equal_graph_def': ['actual', 'expected', 'checkpoint_v2'], 'tf.transpose': ['a', 'perm', 'name', 'conjugate'], 'tf.tuple': ['tensors', 'name', 'control_inputs'], 'tf.while_loop': ['cond', 'body', 'loop_vars', 'shape_invariants', 'parallel_iterations', 'back_prop', 'swap_memory', 'name', 'maximum_iterations', 'return_same_structure'] diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index 241b08510f6..8d16c679b52 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -177,10 +177,11 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # Warnings that should be printed if corresponding functions are used. self.function_warnings = { - "tf.reverse": - "ERROR: tf.reverse has had its argument semantics changed " + "tf.reverse": ( + ast_edits.ERROR, + "tf.reverse has had its argument semantics changed " "significantly. The converter cannot detect this reliably, so " - "you need to inspect this usage manually.\n", + "you need to inspect this usage manually.\n"), } diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 03ecf5f303f..2abd3eac1fc 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import ast +import functools +import sys import pasta import six @@ -27,6 +29,9 @@ from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import renames_v2 from tensorflow.tools.compatibility import reorders_v2 +# These pylint warnings are a mistake. +# pylint: disable=g-explicit-bool-comparison,g-bool-id-comparison + class TFAPIChangeSpec(ast_edits.APIChangeSpec): """List of maps that describe what changed in the API.""" @@ -38,6 +43,15 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # Only keyword args are handled, so make sure to also put any function in # function_reorders to ensure that all args are made into keywords first. self.function_keyword_renames = { + "tf.test.assert_equal_graph_def": { + "checkpoint_v2": None, + }, + "tf.nn.embedding_lookup": { + "validate_indices": None, + }, + "tf.image.sample_distorted_bounding_box": { + "seed2": None, + }, "tf.gradients": { "colocate_gradients_with_ops": None, }, @@ -380,6 +394,21 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.nn.weighted_moments": { "keep_dims": "keepdims" }, + "tf.nn.conv1d": { + "value": "input", + "use_cudnn_on_gpu": None, + }, + "tf.nn.conv2d": { + "filter": "filters", + "use_cudnn_on_gpu": None, + }, + "tf.nn.conv2d_backprop_filter": { + "use_cudnn_on_gpu": None, + }, + "tf.nn.conv2d_backprop_input": { + "filter": "filters", + "use_cudnn_on_gpu": None, + }, } # pylint: disable=line-too-long @@ -662,6 +691,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # Mapping from function to the new name of the function self.symbol_renames = renames_v2.renames self.symbol_renames.update(self.manual_symbol_renames) + self.symbol_renames = { + name: new_name + for name, new_name in self.symbol_renames.items() + } # Variables that should be changed to functions. self.change_to_function = {} @@ -684,6 +717,11 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.nn.space_to_batch", "tf.boolean_mask", "tf.convert_to_tensor", + "tf.nn.conv1d", + "tf.nn.conv2d", + "tf.nn.conv2d_backprop_filter", + "tf.nn.conv2d_backprop_input", + "tf.nn.ctc_beam_search_decoder", "tf.nn.moments", "tf.nn.convolution", "tf.nn.crelu", @@ -746,6 +784,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "tf.nn.embedding_lookup_sparse", "tf.nn.in_top_k", "tf.nn.space_to_depth", + "tf.test.assert_equal_graph_def", "tf.linalg.norm", "tf.norm", "tf.reverse_sequence", @@ -754,6 +793,9 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # keyword arguments. Add keyword arguments in rare case when they # are not specified. "tf.nn.softmax_cross_entropy_with_logits", + "tf.nn.fractional_avg_pool", + "tf.nn.fractional_max_pool", + "tf.image.sample_distorted_bounding_box", } # Functions that were reordered should be changed to the new keyword args @@ -761,43 +803,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # positional arguments yourself, this could do the wrong thing. self.function_reorders = reorders_v2.reorders - # Specially handled functions (pasta version) - # Each transformer is a callable which will be called with the arguments - # transformer(parent, node, full_name, name, logs, errors) - # Where logs and errors are lists to which (line, col, msg) tuples can be - # appended, full_name is the FQN of the function called (or None if that is - # unknown), name is the name of the function called (or None is that is - # unknown). node is an ast.Call node representing this function call, and - # parent is its parent in the AST. - # The function may modify node (but not parent), and must return - # - none, if nothing was modified - # - node, if node was modified in place (make sure to use - # pasta.ast_utils.replace_child to swap out children, otherwise formatting - # may get messy) - # - a replacement for node, if the whole call node was replaced. The caller - # will take care of changing parent. - self.function_transformers = { - "*.make_initializable_iterator": self._iterator_transformer, - "*.make_one_shot_iterator": self._iterator_transformer, - "tf.nn.dropout": self._dropout_transformer, - "tf.batch_gather": self._batch_gather_transformer, - "tf.to_bfloat16": self._cast_transformer, - "tf.to_complex128": self._cast_transformer, - "tf.to_complex64": self._cast_transformer, - "tf.to_double": self._cast_transformer, - "tf.to_float": self._cast_transformer, - "tf.to_int32": self._cast_transformer, - "tf.to_int64": self._cast_transformer, - "tf.nn.softmax_cross_entropy_with_logits": - self._softmax_cross_entropy_with_logits_transformer, - "tf.image.resize_area": self._image_resize_transformer, - "tf.image.resize_bicubic": self._image_resize_transformer, - "tf.image.resize_bilinear": self._image_resize_transformer, - "tf.image.resize_nearest_neighbor": self._image_resize_transformer, - - } - decay_function_comment = ( + ast_edits.INFO, " has been changed to return a callable instead " "of a tensor when graph building, but its functionality remains " "unchanged during eager execution (returns a callable like " @@ -806,64 +813,66 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): " be correct).\n" ) - # TODO(b/118888586): add default value change to update script. - default_loss_reduction_changed = ( - "default value of loss_reduction has been changed to " - "SUM_OVER_BATCH_SIZE.\n" - ) - assert_return_type_comment = ( - "assert_* functions have been changed to return None, the " + ast_edits.INFO, + " has been changed to return None, the " "data argument has been removed, and arguments have been reordered." "\nThe calls have been converted to compat.v1 for safety (even though " " they may already have been correct)." ) assert_rank_comment = ( - "assert_rank_* functions have been changed to return None, and" + ast_edits.INFO, + " has been changed to return None, and" " the data and summarize arguments have been removed." "\nThe calls have been converted to compat.v1 for safety (even though " " they may already have been correct)." ) - tf_01s_like_no_optimize_comment = ( - "tf.zeros_like and tf.ones_like no longer have the optimize " - "argument in TF 2.0 or after (also, `tensor' argument is renamed to " - "`input')." - "\nThe calls have been converted to compat.v1 for safety (even though " - " they may already have been correct)." - ) - - deprecate_partition_strategy_comment = ( - "`partition_strategy` has been removed from `%s` " - " The 'div' strategy is used by default.") - initializers_no_dtype_comment = ( - "tf.initializers and tf.keras.initializers no longer have the " + ast_edits.INFO, + "Initializers no longer have the " "dtype argument in the constructor or partition_info argument in the " - "call method in TF 2.0 and after. The only API symbols are now " - "tf.keras.initializers.* or tf.initializers.*." - "\nThe calls have been converted to compat.v1 for safety (even though " - "they may already have been correct).") - - uniform_unit_scaling_initializer_comment = ( - "uniform_unit_scaling_initializer has been removed. Please use" - " tf.initializers.variance_scaling instead with distribution=uniform " - "to get equivalent behaviour.") + "__call__ method.\nThe calls have been converted to compat.v1 for" + "safety (even though they may already have been correct).") metrics_comment = ( - "tf.metrics have been converted to object oriented versions in" + ast_edits.INFO, + "tf.metrics have been replaced with object oriented versions in" " TF 2.0 and after. The metric function calls have been converted to " "compat.v1 for backward compatibility. Please update these calls to " "the TF 2.0 versions.") losses_comment = ( - "tf.losses have been converted to object oriented versions in" + ast_edits.INFO, + "tf.losses have been replaced with object oriented versions in" " TF 2.0 and after. The loss function calls have been converted to " "compat.v1 for backward compatibility. Please update these calls to " "the TF 2.0 versions.") +# This could be done with a _rename_if_arg_not_found_transformer + deprecate_partition_strategy_comment = ( + ast_edits.WARNING, + "`partition_strategy` has been removed from . " + " The 'div' strategy will be used by default.") + +# TODO(b/118888586): add default value change to update script. + default_loss_reduction_changed = ( + ast_edits.WARNING, + "default value of loss_reduction has been changed to " + "SUM_OVER_BATCH_SIZE.\n" + ) + +# make change instead + uniform_unit_scaling_initializer_comment = ( + ast_edits.ERROR, + "uniform_unit_scaling_initializer has been removed. Please use" + " tf.initializers.variance_scaling instead with distribution=uniform " + "to get equivalent behaviour.") + +# Make change instead (issue warning about strip_...) export_saved_model_renamed = ( + ast_edits.ERROR, "(Manual edit required) Please rename the method export_savedmodel() " "to export_saved_model(). Two things to note:\n\t(1) The argument " "strip_default_attributes has been removed. The function will always " @@ -947,13 +956,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): assert_rank_comment, "tf.debugging.assert_rank_in": assert_rank_comment, - "tf.device": - "tf.device no longer takes function as an argument. " - "'device_name_or_function' argument has been renamed to " - "'device_name'.", - "tf.flags": + "tf.flags": ( + ast_edits.ERROR, "tf.flags has been removed, please use the argparse or absl" - " module if you need command line parsing.", + " modules if you need command line parsing."), "tf.train.exponential_decay": decay_function_comment, "tf.train.piecewise_constant_decay": @@ -988,55 +994,12 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): default_loss_reduction_changed, "tf.estimator.BaselineRegressor": default_loss_reduction_changed, - "tf.nn.conv1d": - "WARNING: use_cudnn_on_gpu argument has been removed and \"value\"" - " was renamed to \"input\"", - "tf.nn.conv2d": - "WARNING: use_cudnn_on_gpu argument has been removed and " - "\"filter\" was renamed to \"filters\"", - "tf.nn.conv2d_backprop_filter": - "WARNING: use_cudnn_on_gpu argument has been removed", - "tf.nn.conv2d_backprop_input": - "WARNING: use_cudnn_on_gpu argument has been removed and " - "\"filter\" was renamed to \"filters\"", - "tf.nn.erosion2d": - "WARNING: now requires a data_format argument", "tf.nn.nce_loss": - deprecate_partition_strategy_comment % "tf.nn.nce_loss", + deprecate_partition_strategy_comment, "tf.nn.safe_embedding_lookup_sparse": - deprecate_partition_strategy_comment % - "tf.nn.safe_embedding_lookup_sparse", + deprecate_partition_strategy_comment, "tf.nn.sampled_softmax_loss": - deprecate_partition_strategy_comment % "tf.nn.sampled_softmax_loss", - "tf.zeros_like": - tf_01s_like_no_optimize_comment, - "tf.ones_like": - tf_01s_like_no_optimize_comment, - "tf.nn.embedding_lookup": - "WARNING: validate_indices argument has been removed.", - "tf.while_loop": - "tf.while_loop no longer takes 'return_same_structure' argument. " - "'return_same_structure' now defaults to True. Also, 'name'" - "argument is now the last argument.", - "tf.image.sample_distorted_bounding_box": - "tf.image.sample_distorted_bounding_box no longer takes 'seed2' " - "argument.", - "tf.nn.ctc_beam_search_decoder": - "tf.nn.ctc_beam_search_decoder no longer takes 'merge_repeated' " - "argument. 'merge_repeated' now defaults to False.", - "tf.nn.fractional_avg_pool": - "tf.nn.fractional_avg_pool no longer takes 'seed2' and " - "'deterministic' arguments. Now it takes a single 'seed' arg. If " - "'seed' is zero, the execution is random and deterministic " - "otherwise", - "tf.nn.fractional_max_pool": - "tf.nn.fractional_max_pool no longer takes 'seed2' and " - "'deterministic' arguments. Now it takes a single 'seed' arg. If " - "'seed' is zero, the execution is random and deterministic " - "otherwise", - "tf.test.assert_equal_graph_def": - "tf.assert_equal_graph_def no longer takes 'checkpoint_v2' " - "argument. 'checkpoint_v2' now defaults to True.", + deprecate_partition_strategy_comment, "tf.keras.initializers.Zeros": initializers_no_dtype_comment, "tf.keras.initializers.zeros": @@ -1211,227 +1174,489 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # Warnings that are emitted only if a specific arg is found. self.function_arg_warnings = { + "tf.nn.conv1d": { + ("use_cudnn_on_gpu", 4): ( + ast_edits.WARNING, + "use_cudnn_on_gpu has been removed, behavior is now equivalent" + "to setting it to True."), + }, + "tf.nn.conv2d": { + ("use_cudnn_on_gpu", 4): ( + ast_edits.WARNING, + "use_cudnn_on_gpu has been removed, behavior is now equivalent" + "to setting it to True."), + }, + "tf.nn.conv2d_backprop_filter": { + ("use_cudnn_on_gpu", 5): ( + ast_edits.WARNING, + "use_cudnn_on_gpu has been removed, behavior is now equivalent" + "to setting it to True."), + }, + "tf.nn.conv2d_backprop_input": { + ("use_cudnn_on_gpu", 5): ( + ast_edits.WARNING, + "use_cudnn_on_gpu has been removed, behavior is now equivalent" + "to setting it to True."), + }, "tf.gradients": { - ("colocate_gradients_with_ops", 4): + ("colocate_gradients_with_ops", 4): ( + ast_edits.INFO, "tf.gradients no longer takes " "'colocate_gradients_with_ops' argument, it behaves as if it " - "was set to True.", + "was set to True."), }, "*.minimize": { - ("colocate_gradients_with_ops", 5): + ("colocate_gradients_with_ops", 5): ( + ast_edits.INFO, "Optimizer.minimize no longer takes " "'colocate_gradients_with_ops' argument, it behaves as if it " - "was set to True.", + "was set to True."), }, "*.compute_gradients": { - ("colocate_gradients_with_ops", 4): + ("colocate_gradients_with_ops", 4): ( + ast_edits.INFO, "Optimizer.compute_gradients no " "longer takes 'colocate_gradients_with_ops' argument, it " - "behaves as if it was set to True.", + "behaves as if it was set to True."), }, "tf.cond": { - ("strict", 3): + ("strict", 3): ( + ast_edits.WARNING, "tf.cond no longer takes 'strict' argument, it behaves as " - "if was set to True." + "if was set to True.") }, } - self.symbol_renames = { - name: new_name - for name, new_name in self.symbol_renames.items() + # Specially handled functions + # Each transformer is a callable which will be called with the arguments + # transformer(parent, node, full_name, name, logs, errors) + # Where logs is a list to which (level, line, col, msg) tuples can be + # appended, full_name is the FQN of the function called (or None if that is + # unknown), name is the name of the function called (or None is that is + # unknown). node is an ast.Call node representing this function call, and + # parent is its parent in the AST. + # The function may modify node (but not parent), and must return + # - none, if nothing was modified + # - node, if node was modified in place (make sure to use + # pasta.ast_utils.replace_child to swap out children, otherwise formatting + # may get messy) + # - a replacement for node, if the whole call node was replaced. The caller + # will take care of changing parent. + self.function_transformers = { + "*.make_initializable_iterator": _iterator_transformer, + "*.make_one_shot_iterator": _iterator_transformer, + "tf.nn.dropout": _dropout_transformer, + "tf.batch_gather": _batch_gather_transformer, + "tf.to_bfloat16": _cast_transformer, + "tf.to_complex128": _cast_transformer, + "tf.to_complex64": _cast_transformer, + "tf.to_double": _cast_transformer, + "tf.to_float": _cast_transformer, + "tf.to_int32": _cast_transformer, + "tf.to_int64": _cast_transformer, + "tf.nn.softmax_cross_entropy_with_logits": + _softmax_cross_entropy_with_logits_transformer, + "tf.image.resize_area": _image_resize_transformer, + "tf.image.resize_bicubic": _image_resize_transformer, + "tf.image.resize_bilinear": _image_resize_transformer, + "tf.image.resize_nearest_neighbor": _image_resize_transformer, + "tf.nn.fractional_avg_pool": _pool_seed_transformer, + "tf.nn.fractional_max_pool": _pool_seed_transformer, + "tf.device": functools.partial( + _rename_if_arg_found_transformer, arg_name="device_name", + arg_ok_predicate=_is_ast_str, remove_if_ok=False, + message="tf.device no longer takes functions as an argument. " + "We could not determine that the argument value is a string, so " + "the call was converted to compat.v1."), + "tf.zeros_like": functools.partial( + _rename_if_arg_found_transformer, arg_name="optimize", + arg_ok_predicate=_is_ast_true, remove_if_ok=True, + message="tf.zeros_like no longer takes an optimize argument, and " + "behaves as if optimize=True. This call site specifies something " + "other than optimize=True, so it was converted to compat.v1."), + "tf.ones_like": functools.partial( + _rename_if_arg_found_transformer, arg_name="optimize", + arg_ok_predicate=_is_ast_true, remove_if_ok=True, + message="tf.ones_like no longer takes an optimize argument, and " + "behaves as if optimize=True. This call site specifies something " + "other than optimize=True, so it was converted to compat.v1."), + "tf.while_loop": functools.partial( + _rename_if_arg_found_transformer, + arg_name="return_same_structure", + arg_ok_predicate=_is_ast_true, remove_if_ok=True, + message="tf.while_loop no longer takes 'return_same_structure' " + "argument and behaves as if return_same_structure=True. This call " + "site specifies something other than return_same_structure=True, " + "so it was converted to compat.v1."), + "tf.nn.ctc_beam_search_decoder": functools.partial( + _rename_if_arg_found_transformer, + arg_name="merge_repeated", + arg_ok_predicate=_is_ast_false, remove_if_ok=True, + message="tf.nn.ctc_beam_search_decoder no longer takes the " + "'merge_repeated' argument and behaves as if merge_repeated=False. " + "This call site specifies something other than " + "merge_repeated=False, so it was converted to compat.v1."), + "tf.nn.erosion2d": functools.partial( + _add_argument_transformer, + arg_name="data_format", + arg_value_ast=ast.Str("NHWC")), } - @staticmethod - def _iterator_transformer(parent, node, full_name, name, logs): - # First, check that node.func.value is not already something we like - # (tf.compat.v1.data), or something which is handled in the rename - # (tf.data). This transformer only handles the method call to function call - # conversion. - if full_name and (full_name.startswith("tf.compat.v1.data") or - full_name.startswith("tf.data")): - return - # This should never happen, since we're only called for Attribute nodes. - if not isinstance(node.func, ast.Attribute): - return +def _is_ast_str(node): + """Determine whether this node represents a string.""" + allowed_types = [ast.Str] + if hasattr(ast, "Bytes"): + allowed_types += [ast.Bytes] + if hasattr(ast, "JoinedStr"): + allowed_types += [ast.JoinedStr] + if hasattr(ast, "FormattedValue"): + allowed_types += [ast.FormattedValue] + return isinstance(node, allowed_types) - # Transform from x.f(y) to tf.compat.v1.data.f(x, y) - # Fortunately, node.func.value should already have valid position info - node.args = [node.func.value] + node.args - node.func.value = ast_edits.full_name_node("tf.compat.v1.data") - logs.append((ast_edits.WARNING, node.lineno, node.col_offset, - "Changing dataset.%s() to tf.compat.v1.data.%s(dataset). " - "Please check this transformation.\n" % (name, name))) +def _is_ast_true(node): + if hasattr(ast, "NameConstant"): + return isinstance(node, ast.NameConstant) and node.value is True + else: + return isinstance(node, ast.Name) and node.id == "True" - return node - @staticmethod - def _dropout_transformer(parent, node, full_name, name, logs): - def _replace_keep_prob_node(parent, old_value): - """Replaces old_value with 1-(old_value).""" - one = ast.Num(n=1) - one.lineno = 0 - one.col_offset = 0 - new_value = ast.BinOp(left=one, op=ast.Sub(), - right=old_value) - # This copies the prefix and suffix on old_value to new_value. - pasta.ast_utils.replace_child(parent, old_value, new_value) - ast.copy_location(new_value, old_value) - # Put parentheses around keep_prob.value (and remove the old prefix/ - # suffix, they should only be around new_value). - pasta.base.formatting.set(old_value, "prefix", "(") - pasta.base.formatting.set(old_value, "suffix", ")") +def _is_ast_false(node): + if hasattr(ast, "NameConstant"): + return isinstance(node, ast.NameConstant) and node.value is False + else: + return isinstance(node, ast.Name) and node.id == "False" - # Check if we have a keep_prob keyword arg - for keep_prob in node.keywords: - if keep_prob.arg == "keep_prob": - logs.append((ast_edits.INFO, node.lineno, node.col_offset, - "Changing keep_prob arg of tf.nn.dropout to rate\n")) - keep_prob.arg = "rate" - _replace_keep_prob_node(keep_prob, keep_prob.value) - return node - # Maybe it was a positional arg - if len(node.args) < 2: - logs.append((ast_edits.ERROR, node.lineno, node.col_offset, - "tf.nn.dropout called without arguments, so " - "automatic fix was disabled. tf.nn.dropout has changed " - "the semantics of the second argument.")) +# Lots of unused arguments below, since these are called in a standard manner. +# pylint: disable=unused-argument + + +def _rename_if_arg_found_transformer(parent, node, full_name, name, logs, + arg_name=None, + arg_ok_predicate=None, + remove_if_ok=False, + message=None): + """Replaces the given call with tf.compat.v1 if the given arg is found. + + This requires the function to be called with all named args, so for using + this transformer, the function should also be added to renames. + + If the arg is not found, the call site is left alone. + + If the arg is found, and if arg_ok_predicate is given, it is called with + the ast Expression representing the argument value found. If it returns + True, the function is left alone. + + If the arg is found, arg_ok_predicate is not None and returns ok, and + remove_if_ok is True, the argument is removed from the call. + + Otherwise, `compat.v1` is inserted between tf and the function name. + + Args: + parent: Parent of node. + node: ast.Call node to maybe modify. + full_name: full name of function to modify + name: name of function to modify + logs: list of logs to append to + arg_name: name of the argument to look for + arg_ok_predicate: predicate callable with the ast of the argument value, + returns whether the argument value is allowed. + remove_if_ok: remove the argument if present and ok as determined by + arg_ok_predicate. + message: message to print if a non-ok arg is found (and hence, the function + is renamed to its compat.v1 version). + + Returns: + node, if it was modified, else None. + """ + # Check whether arg is there. + arg_present, arg_value = ast_edits.get_arg_value(node, arg_name) + if not arg_present: + return + + # Check whether arg is problematic (and if not, maybe remove it). + if arg_ok_predicate and arg_ok_predicate(arg_value): + if remove_if_ok: + for i, kw in enumerate(node.keywords): + if kw.arg == arg_name: + node.keywords.pop(i) + logs.append((ast_edits.INFO, node.lineno, node.col_offset, + "Removed argument %s for function %s" % ( + arg_name, full_name or name))) + break + return node else: - _replace_keep_prob_node(node, node.args[1]) - logs.append((ast_edits.INFO, node.lineno, node.col_offset, - "Changing keep_prob arg of tf.nn.dropout to rate, and " - "recomputing value.\n")) + return + # All conditions met, insert v1 and log what we did. + # We must have a full name, so the func is an attribute. + new_name = full_name.replace("tf.", "tf.compat.v1.", 1) + node.func = ast_edits.full_name_node(new_name) + logs.append(( + ast_edits.INFO, node.lineno, node.col_offset, + "Renaming %s to %s because argument %s is present. %s" % + (full_name, new_name, arg_name, message if message is not None else "") + )) + return node + + +def _add_argument_transformer(parent, node, full_name, name, logs, + arg_name, arg_value_ast): + """Adds an argument (as a final kwarg arg_name=arg_value_ast).""" + node.keywords.append(ast.keyword(arg=arg_name, value=arg_value_ast)) + logs.add(( + ast_edits.INFO, node.lineno, node.col_offset, + "Adding argument '%s' to call to %s." % (pasta.dump(node.keywords[-1], + full_name or name)) + )) + return node + + +def _iterator_transformer(parent, node, full_name, name, logs): + """Transform iterator methods to compat function calls.""" + # First, check that node.func.value is not already something we like + # (tf.compat.v1.data), or something which is handled in the rename + # (tf.data). This transformer only handles the method call to function call + # conversion. + if full_name and (full_name.startswith("tf.compat.v1.data") or + full_name.startswith("tf.data")): + return + + # This should never happen, since we're only called for Attribute nodes. + if not isinstance(node.func, ast.Attribute): + return + + # Transform from x.f(y) to tf.compat.v1.data.f(x, y) + # Fortunately, node.func.value should already have valid position info + node.args = [node.func.value] + node.args + node.func.value = ast_edits.full_name_node("tf.compat.v1.data") + + logs.append((ast_edits.WARNING, node.lineno, node.col_offset, + "Changing dataset.%s() to tf.compat.v1.data.%s(dataset). " + "Please check this transformation.\n" % (name, name))) + + return node + + +def _dropout_transformer(parent, node, full_name, name, logs): + """Replace keep_prob with 1-rate.""" + def _replace_keep_prob_node(parent, old_value): + """Replaces old_value with 1-(old_value).""" + one = ast.Num(n=1) + one.lineno = 0 + one.col_offset = 0 + new_value = ast.BinOp(left=one, op=ast.Sub(), + right=old_value) + # This copies the prefix and suffix on old_value to new_value. + pasta.ast_utils.replace_child(parent, old_value, new_value) + ast.copy_location(new_value, old_value) + # Put parentheses around keep_prob.value (and remove the old prefix/ + # suffix, they should only be around new_value). + pasta.base.formatting.set(old_value, "prefix", "(") + pasta.base.formatting.set(old_value, "suffix", ")") + + # Check if we have a keep_prob keyword arg + for keep_prob in node.keywords: + if keep_prob.arg == "keep_prob": + logs.append((ast_edits.INFO, node.lineno, node.col_offset, + "Changing keep_prob arg of tf.nn.dropout to rate\n")) + keep_prob.arg = "rate" + _replace_keep_prob_node(keep_prob, keep_prob.value) return node - @staticmethod - def _cast_transformer(parent, node, full_name, name, logs): - """Transforms to_int and to_float to cast(..., dtype=...).""" - - # Find out the dtype to cast to from the function name - dtype_str = name[3:] - # Special cases where the full dtype is not given - if dtype_str == "float": - dtype_str = "float32" - elif dtype_str == "double": - dtype_str = "float64" - new_arg = ast.keyword(arg="dtype", - value=ast.Attribute(value=ast.Name(id="tf", - ctx=ast.Load()), - attr=dtype_str, ctx=ast.Load())) - # Ensures a valid transformation when a positional name arg is given - if len(node.args) == 2: - name_arg = ast.keyword(arg="name", - value=node.args[-1]) - node.args = node.args[:-1] - node.keywords.append(name_arg) - - # Python3 ast requires the args for the Attribute, but codegen will mess up - # the arg order if we just set them to 0. - new_arg.value.lineno = node.lineno - new_arg.value.col_offset = node.col_offset+100 - - node.keywords.append(new_arg) - if isinstance(node.func, ast.Attribute): - node.func.attr = "cast" - else: - assert isinstance(node.func, ast.Name) - node.func.id = "cast" - + # Maybe it was a positional arg + if len(node.args) < 2: + logs.append((ast_edits.ERROR, node.lineno, node.col_offset, + "tf.nn.dropout called without arguments, so " + "automatic fix was disabled. tf.nn.dropout has changed " + "the semantics of the second argument.")) + else: + _replace_keep_prob_node(node, node.args[1]) logs.append((ast_edits.INFO, node.lineno, node.col_offset, - "Changed %s call to tf.cast(..., dtype=tf.%s)." % (full_name, - dtype_str))) + "Changing keep_prob arg of tf.nn.dropout to rate, and " + "recomputing value.\n")) + return node - @staticmethod - def _softmax_cross_entropy_with_logits_transformer( - parent, node, full_name, name, logs): - def _wrap_label(parent, old_value): - """Wrap labels with tf.stop_gradient.""" - if six.PY3: - new_value = ast.Call( - ast.Name(id="tf.stop_gradient", ctx=ast.Load()), - [old_value], []) - else: - new_value = ast.Call( - ast.Name(id="tf.stop_gradient", ctx=ast.Load()), - [old_value], [], None, None) - # This copies the prefix and suffix on old_value to new_value. - pasta.ast_utils.replace_child(parent, old_value, new_value) - ast.copy_location(new_value, old_value) +def _cast_transformer(parent, node, full_name, name, logs): + """Transforms to_int and to_float to cast(..., dtype=...).""" - # Check if we have a labels keyword arg - for karg in node.keywords: - if karg.arg == "labels": - logs.append((ast_edits.INFO, node.lineno, node.col_offset, - "Changing labels arg of " - "tf.nn.softmax_cross_entropy_with_logits to " - "tf.stop_gradient(labels). Please check this " - "transformation.\n")) - _wrap_label(karg, karg.value) - return node - return node + # Find out the dtype to cast to from the function name + dtype_str = name[3:] + # Special cases where the full dtype is not given + if dtype_str == "float": + dtype_str = "float32" + elif dtype_str == "double": + dtype_str = "float64" + new_arg = ast.keyword(arg="dtype", + value=ast.Attribute(value=ast.Name(id="tf", + ctx=ast.Load()), + attr=dtype_str, ctx=ast.Load())) + # Ensures a valid transformation when a positional name arg is given + if len(node.args) == 2: + name_arg = ast.keyword(arg="name", + value=node.args[-1]) + node.args = node.args[:-1] + node.keywords.append(name_arg) - @staticmethod - def _batch_gather_transformer(parent, node, full_name, name, logs): - # Check if the call already has a batch_dims argument - if any([kw.arg == "batch_dims" for kw in node.keywords]): + # Python3 ast requires the args for the Attribute, but codegen will mess up + # the arg order if we just set them to 0. + new_arg.value.lineno = node.lineno + new_arg.value.col_offset = node.col_offset+100 + + node.keywords.append(new_arg) + if isinstance(node.func, ast.Attribute): + node.func.attr = "cast" + else: + assert isinstance(node.func, ast.Name) + node.func.id = "cast" + + logs.append((ast_edits.INFO, node.lineno, node.col_offset, + "Changed %s call to tf.cast(..., dtype=tf.%s)." % (full_name, + dtype_str))) + return node + + +def _softmax_cross_entropy_with_logits_transformer( + parent, node, full_name, name, logs): + """Wrap labels argument with stop_gradients.""" + def _wrap_label(parent, old_value): + """Wrap labels with tf.stop_gradient.""" + if six.PY3: + new_value = ast.Call( + ast.Name(id="tf.stop_gradient", ctx=ast.Load()), + [old_value], []) + else: + new_value = ast.Call( + ast.Name(id="tf.stop_gradient", ctx=ast.Load()), + [old_value], [], None, None) + + # This copies the prefix and suffix on old_value to new_value. + pasta.ast_utils.replace_child(parent, old_value, new_value) + ast.copy_location(new_value, old_value) + + # Check if we have a labels keyword arg + for karg in node.keywords: + if karg.arg == "labels": logs.append((ast_edits.INFO, node.lineno, node.col_offset, - "tf.batch_gather already has batch_dims argument. Neat.")) - return None + "Changing labels arg of " + "tf.nn.softmax_cross_entropy_with_logits to " + "tf.stop_gradient(labels). Please check this " + "transformation.\n")) + _wrap_label(karg, karg.value) + return node + return node - minus_one = ast.Num(n=-1) - minus_one.lineno = 0 - minus_one.col_offset = 0 - new_arg = ast.keyword("batch_dims", minus_one) - node.keywords.append(new_arg) + +def _batch_gather_transformer(parent, node, full_name, name, logs): + """Add batch_dims argument for gather calls.""" + # Check if the call already has a batch_dims argument + if any([kw.arg == "batch_dims" for kw in node.keywords]): logs.append((ast_edits.INFO, node.lineno, node.col_offset, - "Added keyword argument batch_dims=-1 to tf.batch_gather.")) - return node + "tf.batch_gather already has batch_dims argument. Neat.")) + return None - @staticmethod - def _image_resize_transformer(parent, node, full_name, name, logs): - """Transforms image.resize_* to image.resize(..., method=*, ...).""" + minus_one = ast.Num(n=-1) + minus_one.lineno = 0 + minus_one.col_offset = 0 + new_arg = ast.keyword("batch_dims", minus_one) + node.keywords.append(new_arg) + logs.append((ast_edits.INFO, node.lineno, node.col_offset, + "Added keyword argument batch_dims=-1 to tf.batch_gather.")) + return node - resize_method = name[7:].upper() - new_arg = ast.keyword(arg="method", - value=ast.Attribute( - value=ast.Attribute( - value=ast.Attribute( - value=ast.Name(id="tf", ctx=ast.Load()), - attr="image", ctx=ast.Load()), - attr="ResizeMethod", ctx=ast.Load()), - attr=resize_method, ctx=ast.Load())) - # Ensures a valid transformation when a positional name arg is given - if len(node.args) == 4: - pos_arg = ast.keyword(arg="preserve_aspect_ratio", - value=node.args[-1]) - node.args = node.args[:-1] - node.keywords.append(pos_arg) - if len(node.args) == 3: - pos_arg = ast.keyword(arg="align_corners", - value=node.args[-1]) - node.args = node.args[:-1] - node.keywords.append(pos_arg) +def _image_resize_transformer(parent, node, full_name, name, logs): + """Transforms image.resize_* to image.resize(..., method=*, ...).""" + resize_method = name[7:].upper() + new_arg = ast.keyword(arg="method", + value=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="tf", ctx=ast.Load()), + attr="image", ctx=ast.Load()), + attr="ResizeMethod", ctx=ast.Load()), + attr=resize_method, ctx=ast.Load())) - # Python3 ast requires the args for the Attribute, but codegen will mess up - # the arg order if we just set them to 0. - new_arg.value.lineno = node.lineno - new_arg.value.col_offset = node.col_offset+100 + # Ensures a valid transformation when a positional name arg is given + if len(node.args) == 4: + pos_arg = ast.keyword(arg="preserve_aspect_ratio", + value=node.args[-1]) + node.args = node.args[:-1] + node.keywords.append(pos_arg) + if len(node.args) == 3: + pos_arg = ast.keyword(arg="align_corners", + value=node.args[-1]) + node.args = node.args[:-1] + node.keywords.append(pos_arg) - node.keywords.append(new_arg) - if isinstance(node.func, ast.Attribute): - node.func.attr = "resize" + # Python3 ast requires the args for the Attribute, but codegen will mess up + # the arg order if we just set them to 0. + new_arg.value.lineno = node.lineno + new_arg.value.col_offset = node.col_offset+100 + + node.keywords.append(new_arg) + if isinstance(node.func, ast.Attribute): + node.func.attr = "resize" + else: + assert isinstance(node.func, ast.Name) + node.func.id = "resize" + + logs.append((ast_edits.INFO, node.lineno, node.col_offset, + "Changed %s call to tf.image.resize(..., " + "method=tf.image.ResizeMethod.%s)." % (full_name, + resize_method))) + return node + + +def _pool_seed_transformer(parent, node, full_name, name, logs): + """Removes seed2 and deterministic, and adds non-zero seed if needed.""" + # This requires that this function uses all kwargs (add to renames!). + seed_arg = None + deterministic = False + modified = False + new_keywords = [] + + for kw in node.keywords: + if sys.version_info[:2] >= (3, 5) and isinstance(kw, ast.Starred): + pass + elif kw.arg == "seed": + seed_arg = kw + elif kw.arg == "seed2" or kw.arg == "deterministic": + lineno = getattr(kw, "lineno", node.lineno) + col_offset = getattr(kw, "col_offset", node.col_offset) + logs.append((ast_edits.INFO, lineno, col_offset, + "Removed argument %s for function %s" % ( + kw.arg, full_name or name))) + if kw.arg == "deterministic": + if not _is_ast_false(kw.value): + deterministic = True + modified = True + continue + new_keywords.append(kw) + + if deterministic: + if seed_arg is None: + new_keywords.append(ast.keyword(arg="seed", value=ast.Num(42))) + logs.add(( + ast_edits.INFO, node.lineno, node.col_offset, + "Adding seed=42 to call to %s since determinism was requested" % ( + full_name or name) + )) else: - assert isinstance(node.func, ast.Name) - node.func.id = "resize" + logs.add(( + ast_edits.WARNING, node.lineno, node.col_offset, + "The deterministic argument is deprecated for %s, pass a " + "non-zero seed for determinism. The deterministic argument is " + "present, possibly not False, and the seed is already set. The " + "converter cannot determine whether it is nonzero, please check." + )) - logs.append((ast_edits.INFO, node.lineno, node.col_offset, - "Changed %s call to tf.image.resize(..., " - "method=tf.image.ResizeMethod.%s)." % (full_name, - resize_method))) + if modified: + node.keywords = new_keywords return node + else: + return diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py index b446452cfe3..aecd74e71c8 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_main.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_main.py @@ -112,7 +112,7 @@ Simple usage: report = ("TensorFlow 2.0 Upgrade Script\n" "-----------------------------\n" "Converted %d files\n" % files_processed + - "Detected %d errors that require attention" % num_errors + "\n" + + "Detected %d issues that require attention" % num_errors + "\n" + "-" * 80 + "\n") + "".join(report) with open(report_filename, "w") as report_file: report_file.write(report) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 8ff5d01ae6d..2acaaacc180 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -376,19 +376,13 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map "tf.train.inverse_time_decay", "tf.train.cosine_decay", "tf.train.cosine_decay_restarts", "tf.train.linear_cosine_decay", - "tf.train.noisy_linear_cosine_decay"]: + "tf.train.noisy_linear_cosine_decay", + "tf.train.piecewise_constant_decay", + ]: text = "%s(a, b)\n" % decay - _, report, errors, _ = self._upgrade(text) - self.assertIn("%s requires manual check" % decay, errors[0]) - self.assertIn("%s has been changed" % decay, report) - - def testPiecewiseDecay(self): - text = "tf.train.piecewise_constant_decay(a, b)\n" - _, report, errors, _ = self._upgrade(text) - self.assertIn("tf.train.piecewise_constant_decay requires manual check", - errors[0]) - self.assertIn("tf.train.piecewise_constant_decay has been changed", report) + _, report, unused_errors, _ = self._upgrade(text) + self.assertIn("%s has been changed to return a callable" % decay, report) def testMetrics(self): metrics = [ @@ -427,16 +421,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map "true_positives_at_thresholds", ] for m in metrics: - ns = "tf.metrics." + m - text = ns + "(a, b)" - _, report, errors, new_text = self._upgrade(text) + text = "tf.metrics." + m + "(a, b)" + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual("tf.compat.v1.metrics." + m + "(a, b)", new_text) - self.assertIn("%s requires manual check" % ns, errors[0]) self.assertIn( - "tf.metrics have been converted to object oriented" - " versions in TF 2.0 and after. The metric function calls have been " - "converted to compat.v1 for backward compatibility. Please update " - "these calls to the TF 2.0 versions.", report) + "tf.metrics have been replaced with object oriented versions", report) def testLosses(self): losses = [ @@ -458,16 +447,11 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map "sparse_softmax_cross_entropy", ] for l in losses: - ns = "tf.losses." + l - text = ns + "(a, b)" - _, report, errors, new_text = self._upgrade(text) + text = "tf.losses." + l + "(a, b)" + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual("tf.compat.v1.losses." + l + "(a, b)", new_text) - self.assertIn("%s requires manual check" % ns, errors[0]) self.assertIn( - "tf.losses have been converted to object oriented" - " versions in TF 2.0 and after. The loss function calls have been " - "converted to compat.v1 for backward compatibility. Please update " - "these calls to the TF 2.0 versions.", report) + "tf.losses have been replaced with object oriented versions", report) def testEstimatorLossReductionChange(self): classes = [ @@ -605,10 +589,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map self.assertEqual(errors, []) text = "tf.gradients(a, colocate_gradients_with_ops=False)\n" - _, unused_report, errors, new_text = self._upgrade(text) + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual("tf.gradients(a)\n", new_text) - self.assertIn("tf.gradients", errors[0]) - self.assertIn("requires manual check", errors[0]) + self.assertIn("tf.gradients no longer takes", report) def testColocateGradientsWithOpsMinimize(self): text = "optimizer.minimize(a, foo=False)\n" @@ -617,10 +600,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map self.assertEqual(errors, []) text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n" - _, unused_report, errors, new_text = self._upgrade(text) + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual("optimizer.minimize(a)\n", new_text) - self.assertIn("requires manual check", errors[0]) - self.assertIn("minimize", errors[0]) + self.assertIn("Optimizer.minimize no longer takes", report) def testColocateGradientsWithOpsComputeGradients(self): text = "optimizer.compute_gradients(a, foo=False)\n" @@ -629,10 +611,9 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map self.assertEqual(errors, []) text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n" - _, unused_report, errors, new_text = self._upgrade(text) + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual("optimizer.compute_gradients(a)\n", new_text) - self.assertIn("requires manual check", errors[0]) - self.assertIn("compute_gradients", errors[0]) + self.assertIn("Optimizer.compute_gradients no longer takes", report) def testExportSavedModelRename(self): text = "self.est.export_savedmodel(path)" @@ -831,7 +812,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map "validate_indices, max_norm)") expected_text = ("tf.nn.embedding_lookup(params=params, ids=ids, " "partition_strategy=partition_strategy, name=name, " - "validate_indices=validate_indices, max_norm=max_norm)") + "max_norm=max_norm)") _, unused_report, unused_errors, new_text = self._upgrade(text) self.assertEqual(new_text, expected_text) @@ -1029,29 +1010,43 @@ def _log_prob(self, x): "assert_scalar"]: text = "tf.%s(a)" % name expected_text = "tf.compat.v1.%s(a)" % name - _, unused_report, errors, new_text = self._upgrade(text) + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual(expected_text, new_text) - self.assertIn("assert_* functions", errors[0]) + self.assertIn("%s has been" % name, report) text = "tf.debugging.%s(a)" % name expected_text = "tf.compat.v1.debugging.%s(a)" % name - _, unused_report, errors, new_text = self._upgrade(text) + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual(expected_text, new_text) - self.assertIn("assert_* functions", errors[0]) + self.assertIn("%s has been" % name, report) def testAssertRankStatements(self): for name in ["assert_rank", "assert_rank_at_least", "assert_rank_in"]: text = "tf.%s(a)" % name expected_text = "tf.compat.v1.%s(a)" % name - _, unused_report, errors, new_text = self._upgrade(text) + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual(expected_text, new_text) - self.assertIn("assert_rank_* functions", errors[0]) + self.assertIn("%s has been" % name, report) text = "tf.debugging.%s(a)" % name expected_text = "tf.compat.v1.debugging.%s(a)" % name - _, unused_report, errors, new_text = self._upgrade(text) + _, report, unused_errors, new_text = self._upgrade(text) self.assertEqual(expected_text, new_text) - self.assertIn("assert_rank_* functions", errors[0]) + self.assertIn("%s has been" % name, report) + + def test_assert_equal_graph_def(self): + text = "tf.test.assert_equal_graph_def(a, b, checkpoint_v2=x)" + expected = "tf.test.assert_equal_graph_def(actual=a, expected=b)" + _, _, _, new_text = self._upgrade(text) + self.assertEqual(expected, new_text) + + def test_sample_distorted_bounding_box(self): + # pylint: disable=line-too-long + text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)" + expected = "tf.image.sample_distorted_bounding_box(image_size=a, bounding_boxes=b, seed=c, min_object_covered=e, aspect_ratio_range=f, area_range=g, max_attempts=h, use_image_if_no_bounding_boxes=i, name=j)" + # pylint: enable=line-too-long + _, _, _, new_text = self._upgrade(text) + self.assertEqual(expected, new_text) class TestUpgradeFiles(test_util.TensorFlowTestCase):