diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 02f0e8401e5..07ca6cec1a2 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -411,17 +411,21 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "filter": "filters", }, "tf.contrib.summary.audio": { + "tensor": "data", "family": None, }, "tf.contrib.summary.histogram": { + "tensor": "data", "family": None, }, "tf.contrib.summary.image": { + "tensor": "data", "bad_color": None, "max_images": "max_outputs", "family": None, }, "tf.contrib.summary.scalar": { + "tensor": "data", "family": None, }, } @@ -962,6 +966,14 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): "only effects core estimator. If you are using " "tf.contrib.learn.Estimator, please switch to using core estimator.") + # TODO(b/124529441): if possible eliminate need for manual checking. + contrib_summary_comment = ( + ast_edits.WARNING, + "(Manual check required) tf.contrib.summary.* functions have been " + "migrated best-effort to tf.compat.v2.summary.* equivalents where " + "possible, but the resulting code may not always work. Please check " + "manually; you can report migration failures on b/124529441.") + # Function warnings. placeholder inside warnings will be # replaced by function name. # You can use *. to add items which do not check the FQN, and apply to e.g., @@ -1003,6 +1015,14 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): assert_rank_comment, "tf.assert_rank_in": assert_rank_comment, + "tf.contrib.summary.audio": + contrib_summary_comment, + "tf.contrib.summary.histogram": + contrib_summary_comment, + "tf.contrib.summary.image": + contrib_summary_comment, + "tf.contrib.summary.scalar": + contrib_summary_comment, "tf.debugging.assert_equal": assert_return_type_comment, "tf.debugging.assert_greater": @@ -1340,7 +1360,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # Specially handled functions # Each transformer is a callable which will be called with the arguments - # transformer(parent, node, full_name, name, logs, errors) + # transformer(parent, node, full_name, name, logs) # Where logs is a list to which (level, line, col, msg) tuples can be # appended, full_name is the FQN of the function called (or None if that is # unknown), name is the name of the function called (or None is that is @@ -1411,6 +1431,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): _add_argument_transformer, arg_name="data_format", arg_value_ast=ast.Str("NHWC")), + "tf.contrib.summary.audio": _add_summary_step_transformer, + "tf.contrib.summary.histogram": _add_summary_step_transformer, + "tf.contrib.summary.image": _add_summary_step_transformer, + "tf.contrib.summary.scalar": _add_summary_step_transformer, } self.module_deprecations = { @@ -1522,7 +1546,7 @@ def _add_argument_transformer(parent, node, full_name, name, logs, arg_name, arg_value_ast): """Adds an argument (as a final kwarg arg_name=arg_value_ast).""" node.keywords.append(ast.keyword(arg=arg_name, value=arg_value_ast)) - logs.add(( + logs.append(( ast_edits.INFO, node.lineno, node.col_offset, "Adding argument '%s' to call to %s." % (pasta.dump(node.keywords[-1], full_name or name)) @@ -1804,3 +1828,22 @@ def _extract_glimpse_transformer(parent, node, full_name, name, logs): "Changing uniform_noise arg of tf.image.extract_glimpse to " "noise, and recomputing value.\n")) return node + + +def _add_summary_step_transformer(parent, node, full_name, name, logs): + """Adds a step argument to the summary API call if not specified. + + The inserted argument value is tf.compat.v1.train.get_or_create_global_step(). + """ + for keyword_arg in node.keywords: + if keyword_arg.arg == "step": + return node + default_value = "tf.compat.v1.train.get_or_create_global_step()" + # Parse with pasta instead of ast to avoid emitting a spurious trailing \n. + ast_value = pasta.parse(default_value) + node.keywords.append(ast.keyword(arg="step", value=ast_value)) + logs.append(( + ast_edits.WARNING, node.lineno, node.col_offset, + "Summary API writing function %s now requires a 'step' argument; " + "inserting default of %s." % (full_name or name, default_value))) + return node diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 98d9cfc3817..e53f5ae79f2 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -1233,36 +1233,77 @@ def _log_prob(self, x): def test_contrib_summary_audio(self): text = "tf.contrib.summary.audio('foo', myval, 44100, 3, 'fam', 42)" - expected = ("tf.compat.v2.summary.audio(name='foo', tensor=myval, " + expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, " "sample_rate=44100, max_outputs=3, step=42)") _, _, errors, new_text = self._upgrade(text) self.assertEqual(expected, new_text) self.assertIn("'family' argument", errors[0]) + self.assertIn("Manual check required", errors[1]) def test_contrib_summary_histogram(self): text = "tf.contrib.summary.histogram('foo', myval, 'fam', 42)" - expected = ("tf.compat.v2.summary.histogram(name='foo', tensor=myval, " + expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, " "step=42)") _, _, errors, new_text = self._upgrade(text) self.assertEqual(expected, new_text) self.assertIn("'family' argument", errors[0]) + self.assertIn("Manual check required", errors[1]) def test_contrib_summary_image(self): text = "tf.contrib.summary.image('foo', myval, red, 3, 'fam', 42)" - expected = ("tf.compat.v2.summary.image(name='foo', tensor=myval, " + expected = ("tf.compat.v2.summary.image(name='foo', data=myval, " "max_outputs=3, step=42)") _, _, errors, new_text = self._upgrade(text) self.assertEqual(expected, new_text) self.assertIn("'bad_color' argument", errors[0]) self.assertIn("'family' argument", errors[1]) + self.assertIn("Manual check required", errors[2]) def test_contrib_summary_scalar(self): text = "tf.contrib.summary.scalar('foo', myval, 'fam', 42)" - expected = ("tf.compat.v2.summary.scalar(name='foo', tensor=myval, " + expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, " "step=42)") _, _, errors, new_text = self._upgrade(text) self.assertEqual(expected, new_text) self.assertIn("'family' argument", errors[0]) + self.assertIn("Manual check required", errors[1]) + + def test_contrib_summary_audio_nostep(self): + text = "tf.contrib.summary.audio('foo', myval, 44100)" + expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, " + "sample_rate=44100, " + "step=tf.compat.v1.train.get_or_create_global_step())") + _, _, errors, new_text = self._upgrade(text) + self.assertEqual(expected, new_text) + self.assertIn("'step' argument", errors[0]) + self.assertIn("Manual check required", errors[1]) + + def test_contrib_summary_histogram_nostep(self): + text = "tf.contrib.summary.histogram('foo', myval)" + expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, " + "step=tf.compat.v1.train.get_or_create_global_step())") + _, _, errors, new_text = self._upgrade(text) + self.assertEqual(expected, new_text) + self.assertIn("'step' argument", errors[0]) + self.assertIn("Manual check required", errors[1]) + + def test_contrib_summary_image_nostep(self): + text = "tf.contrib.summary.image('foo', myval)" + expected = ("tf.compat.v2.summary.image(name='foo', data=myval, " + "step=tf.compat.v1.train.get_or_create_global_step())") + _, _, errors, new_text = self._upgrade(text) + self.assertEqual(expected, new_text) + self.assertIn("'step' argument", errors[0]) + self.assertIn("Manual check required", errors[1]) + + def test_contrib_summary_scalar_nostep(self): + text = "tf.contrib.summary.scalar('foo', myval)" + expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, " + "step=tf.compat.v1.train.get_or_create_global_step())") + _, _, errors, new_text = self._upgrade(text) + self.assertEqual(expected, new_text) + self.assertIn("'step' argument", errors[0]) + self.assertIn("Manual check required", errors[1]) class TestUpgradeFiles(test_util.TensorFlowTestCase): diff --git a/third_party/pasta/workspace.bzl b/third_party/pasta/workspace.bzl index e46cc4a45e4..9961835328e 100644 --- a/third_party/pasta/workspace.bzl +++ b/third_party/pasta/workspace.bzl @@ -6,11 +6,11 @@ def repo(): third_party_http_archive( name = "pasta", urls = [ - "https://mirror.bazel.build/github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz", - "https://github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz", + "https://mirror.bazel.build/github.com/google/pasta/archive/v0.1.2.tar.gz", + "https://github.com/google/pasta/archive/v0.1.2.tar.gz", ], - strip_prefix = "pasta-c3d72cdee6fc806251949e912510444d58d7413c", - sha256 = "b5905f9cecc4b28363c563f3c4cb0545288bd35f7cc72c55066e97e53befc084", + strip_prefix = "pasta-0.1.2", + sha256 = "53e4c009a5eac38e942deb48bfc2d3cfca62cd457255fa86ffedb7e40f726a0c", build_file = "//third_party/pasta:BUILD.bazel", system_build_file = "//third_party/pasta:BUILD.system", )