Fix tf.contrib.summary migrations in tf_upgrade_v2

This addresses a few things I missed in 754b8a18b2:

- "tensor" kwarg got renamed to "data"
- "step" argument, if not originally passed, should be set to a default value because in TF 2.0 it's a mandatory argument. For now this uses tf.compat.v1.train.get_or_create_global_step() to preserve the same behavior as TF 1.x.

The pasta upgrade is required because the previous version has a bug that the code I added hits.

PiperOrigin-RevId: 234224028
This commit is contained in:
Nick Felt 2019-02-15 15:26:42 -08:00 committed by TensorFlower Gardener
parent 579fb17325
commit 4f787752f9
3 changed files with 94 additions and 10 deletions

View File

@ -411,17 +411,21 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"filter": "filters",
},
"tf.contrib.summary.audio": {
"tensor": "data",
"family": None,
},
"tf.contrib.summary.histogram": {
"tensor": "data",
"family": None,
},
"tf.contrib.summary.image": {
"tensor": "data",
"bad_color": None,
"max_images": "max_outputs",
"family": None,
},
"tf.contrib.summary.scalar": {
"tensor": "data",
"family": None,
},
}
@ -962,6 +966,14 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"only effects core estimator. If you are using "
"tf.contrib.learn.Estimator, please switch to using core estimator.")
# TODO(b/124529441): if possible eliminate need for manual checking.
contrib_summary_comment = (
ast_edits.WARNING,
"(Manual check required) tf.contrib.summary.* functions have been "
"migrated best-effort to tf.compat.v2.summary.* equivalents where "
"possible, but the resulting code may not always work. Please check "
"manually; you can report migration failures on b/124529441.")
# Function warnings. <function name> placeholder inside warnings will be
# replaced by function name.
# You can use *. to add items which do not check the FQN, and apply to e.g.,
@ -1003,6 +1015,14 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
assert_rank_comment,
"tf.assert_rank_in":
assert_rank_comment,
"tf.contrib.summary.audio":
contrib_summary_comment,
"tf.contrib.summary.histogram":
contrib_summary_comment,
"tf.contrib.summary.image":
contrib_summary_comment,
"tf.contrib.summary.scalar":
contrib_summary_comment,
"tf.debugging.assert_equal":
assert_return_type_comment,
"tf.debugging.assert_greater":
@ -1340,7 +1360,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Specially handled functions
# Each transformer is a callable which will be called with the arguments
# transformer(parent, node, full_name, name, logs, errors)
# transformer(parent, node, full_name, name, logs)
# Where logs is a list to which (level, line, col, msg) tuples can be
# appended, full_name is the FQN of the function called (or None if that is
# unknown), name is the name of the function called (or None is that is
@ -1411,6 +1431,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
_add_argument_transformer,
arg_name="data_format",
arg_value_ast=ast.Str("NHWC")),
"tf.contrib.summary.audio": _add_summary_step_transformer,
"tf.contrib.summary.histogram": _add_summary_step_transformer,
"tf.contrib.summary.image": _add_summary_step_transformer,
"tf.contrib.summary.scalar": _add_summary_step_transformer,
}
self.module_deprecations = {
@ -1522,7 +1546,7 @@ def _add_argument_transformer(parent, node, full_name, name, logs,
arg_name, arg_value_ast):
"""Adds an argument (as a final kwarg arg_name=arg_value_ast)."""
node.keywords.append(ast.keyword(arg=arg_name, value=arg_value_ast))
logs.add((
logs.append((
ast_edits.INFO, node.lineno, node.col_offset,
"Adding argument '%s' to call to %s." % (pasta.dump(node.keywords[-1],
full_name or name))
@ -1804,3 +1828,22 @@ def _extract_glimpse_transformer(parent, node, full_name, name, logs):
"Changing uniform_noise arg of tf.image.extract_glimpse to "
"noise, and recomputing value.\n"))
return node
def _add_summary_step_transformer(parent, node, full_name, name, logs):
"""Adds a step argument to the summary API call if not specified.
The inserted argument value is tf.compat.v1.train.get_or_create_global_step().
"""
for keyword_arg in node.keywords:
if keyword_arg.arg == "step":
return node
default_value = "tf.compat.v1.train.get_or_create_global_step()"
# Parse with pasta instead of ast to avoid emitting a spurious trailing \n.
ast_value = pasta.parse(default_value)
node.keywords.append(ast.keyword(arg="step", value=ast_value))
logs.append((
ast_edits.WARNING, node.lineno, node.col_offset,
"Summary API writing function %s now requires a 'step' argument; "
"inserting default of %s." % (full_name or name, default_value)))
return node

View File

@ -1233,36 +1233,77 @@ def _log_prob(self, x):
def test_contrib_summary_audio(self):
text = "tf.contrib.summary.audio('foo', myval, 44100, 3, 'fam', 42)"
expected = ("tf.compat.v2.summary.audio(name='foo', tensor=myval, "
expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, "
"sample_rate=44100, max_outputs=3, step=42)")
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("'family' argument", errors[0])
self.assertIn("Manual check required", errors[1])
def test_contrib_summary_histogram(self):
text = "tf.contrib.summary.histogram('foo', myval, 'fam', 42)"
expected = ("tf.compat.v2.summary.histogram(name='foo', tensor=myval, "
expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, "
"step=42)")
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("'family' argument", errors[0])
self.assertIn("Manual check required", errors[1])
def test_contrib_summary_image(self):
text = "tf.contrib.summary.image('foo', myval, red, 3, 'fam', 42)"
expected = ("tf.compat.v2.summary.image(name='foo', tensor=myval, "
expected = ("tf.compat.v2.summary.image(name='foo', data=myval, "
"max_outputs=3, step=42)")
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("'bad_color' argument", errors[0])
self.assertIn("'family' argument", errors[1])
self.assertIn("Manual check required", errors[2])
def test_contrib_summary_scalar(self):
text = "tf.contrib.summary.scalar('foo', myval, 'fam', 42)"
expected = ("tf.compat.v2.summary.scalar(name='foo', tensor=myval, "
expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, "
"step=42)")
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("'family' argument", errors[0])
self.assertIn("Manual check required", errors[1])
def test_contrib_summary_audio_nostep(self):
text = "tf.contrib.summary.audio('foo', myval, 44100)"
expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, "
"sample_rate=44100, "
"step=tf.compat.v1.train.get_or_create_global_step())")
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("'step' argument", errors[0])
self.assertIn("Manual check required", errors[1])
def test_contrib_summary_histogram_nostep(self):
text = "tf.contrib.summary.histogram('foo', myval)"
expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, "
"step=tf.compat.v1.train.get_or_create_global_step())")
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("'step' argument", errors[0])
self.assertIn("Manual check required", errors[1])
def test_contrib_summary_image_nostep(self):
text = "tf.contrib.summary.image('foo', myval)"
expected = ("tf.compat.v2.summary.image(name='foo', data=myval, "
"step=tf.compat.v1.train.get_or_create_global_step())")
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("'step' argument", errors[0])
self.assertIn("Manual check required", errors[1])
def test_contrib_summary_scalar_nostep(self):
text = "tf.contrib.summary.scalar('foo', myval)"
expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, "
"step=tf.compat.v1.train.get_or_create_global_step())")
_, _, errors, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
self.assertIn("'step' argument", errors[0])
self.assertIn("Manual check required", errors[1])
class TestUpgradeFiles(test_util.TensorFlowTestCase):

View File

@ -6,11 +6,11 @@ def repo():
third_party_http_archive(
name = "pasta",
urls = [
"https://mirror.bazel.build/github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
"https://github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
"https://mirror.bazel.build/github.com/google/pasta/archive/v0.1.2.tar.gz",
"https://github.com/google/pasta/archive/v0.1.2.tar.gz",
],
strip_prefix = "pasta-c3d72cdee6fc806251949e912510444d58d7413c",
sha256 = "b5905f9cecc4b28363c563f3c4cb0545288bd35f7cc72c55066e97e53befc084",
strip_prefix = "pasta-0.1.2",
sha256 = "53e4c009a5eac38e942deb48bfc2d3cfca62cd457255fa86ffedb7e40f726a0c",
build_file = "//third_party/pasta:BUILD.bazel",
system_build_file = "//third_party/pasta:BUILD.system",
)