Fix tf.contrib.summary migrations in tf_upgrade_v2
This addresses a few things I missed in 754b8a18b2
:
- "tensor" kwarg got renamed to "data"
- "step" argument, if not originally passed, should be set to a default value because in TF 2.0 it's a mandatory argument. For now this uses tf.compat.v1.train.get_or_create_global_step() to preserve the same behavior as TF 1.x.
The pasta upgrade is required because the previous version has a bug that the code I added hits.
PiperOrigin-RevId: 234224028
This commit is contained in:
parent
579fb17325
commit
4f787752f9
@ -411,17 +411,21 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"filter": "filters",
|
||||
},
|
||||
"tf.contrib.summary.audio": {
|
||||
"tensor": "data",
|
||||
"family": None,
|
||||
},
|
||||
"tf.contrib.summary.histogram": {
|
||||
"tensor": "data",
|
||||
"family": None,
|
||||
},
|
||||
"tf.contrib.summary.image": {
|
||||
"tensor": "data",
|
||||
"bad_color": None,
|
||||
"max_images": "max_outputs",
|
||||
"family": None,
|
||||
},
|
||||
"tf.contrib.summary.scalar": {
|
||||
"tensor": "data",
|
||||
"family": None,
|
||||
},
|
||||
}
|
||||
@ -962,6 +966,14 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"only effects core estimator. If you are using "
|
||||
"tf.contrib.learn.Estimator, please switch to using core estimator.")
|
||||
|
||||
# TODO(b/124529441): if possible eliminate need for manual checking.
|
||||
contrib_summary_comment = (
|
||||
ast_edits.WARNING,
|
||||
"(Manual check required) tf.contrib.summary.* functions have been "
|
||||
"migrated best-effort to tf.compat.v2.summary.* equivalents where "
|
||||
"possible, but the resulting code may not always work. Please check "
|
||||
"manually; you can report migration failures on b/124529441.")
|
||||
|
||||
# Function warnings. <function name> placeholder inside warnings will be
|
||||
# replaced by function name.
|
||||
# You can use *. to add items which do not check the FQN, and apply to e.g.,
|
||||
@ -1003,6 +1015,14 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
assert_rank_comment,
|
||||
"tf.assert_rank_in":
|
||||
assert_rank_comment,
|
||||
"tf.contrib.summary.audio":
|
||||
contrib_summary_comment,
|
||||
"tf.contrib.summary.histogram":
|
||||
contrib_summary_comment,
|
||||
"tf.contrib.summary.image":
|
||||
contrib_summary_comment,
|
||||
"tf.contrib.summary.scalar":
|
||||
contrib_summary_comment,
|
||||
"tf.debugging.assert_equal":
|
||||
assert_return_type_comment,
|
||||
"tf.debugging.assert_greater":
|
||||
@ -1340,7 +1360,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
|
||||
# Specially handled functions
|
||||
# Each transformer is a callable which will be called with the arguments
|
||||
# transformer(parent, node, full_name, name, logs, errors)
|
||||
# transformer(parent, node, full_name, name, logs)
|
||||
# Where logs is a list to which (level, line, col, msg) tuples can be
|
||||
# appended, full_name is the FQN of the function called (or None if that is
|
||||
# unknown), name is the name of the function called (or None is that is
|
||||
@ -1411,6 +1431,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
_add_argument_transformer,
|
||||
arg_name="data_format",
|
||||
arg_value_ast=ast.Str("NHWC")),
|
||||
"tf.contrib.summary.audio": _add_summary_step_transformer,
|
||||
"tf.contrib.summary.histogram": _add_summary_step_transformer,
|
||||
"tf.contrib.summary.image": _add_summary_step_transformer,
|
||||
"tf.contrib.summary.scalar": _add_summary_step_transformer,
|
||||
}
|
||||
|
||||
self.module_deprecations = {
|
||||
@ -1522,7 +1546,7 @@ def _add_argument_transformer(parent, node, full_name, name, logs,
|
||||
arg_name, arg_value_ast):
|
||||
"""Adds an argument (as a final kwarg arg_name=arg_value_ast)."""
|
||||
node.keywords.append(ast.keyword(arg=arg_name, value=arg_value_ast))
|
||||
logs.add((
|
||||
logs.append((
|
||||
ast_edits.INFO, node.lineno, node.col_offset,
|
||||
"Adding argument '%s' to call to %s." % (pasta.dump(node.keywords[-1],
|
||||
full_name or name))
|
||||
@ -1804,3 +1828,22 @@ def _extract_glimpse_transformer(parent, node, full_name, name, logs):
|
||||
"Changing uniform_noise arg of tf.image.extract_glimpse to "
|
||||
"noise, and recomputing value.\n"))
|
||||
return node
|
||||
|
||||
|
||||
def _add_summary_step_transformer(parent, node, full_name, name, logs):
|
||||
"""Adds a step argument to the summary API call if not specified.
|
||||
|
||||
The inserted argument value is tf.compat.v1.train.get_or_create_global_step().
|
||||
"""
|
||||
for keyword_arg in node.keywords:
|
||||
if keyword_arg.arg == "step":
|
||||
return node
|
||||
default_value = "tf.compat.v1.train.get_or_create_global_step()"
|
||||
# Parse with pasta instead of ast to avoid emitting a spurious trailing \n.
|
||||
ast_value = pasta.parse(default_value)
|
||||
node.keywords.append(ast.keyword(arg="step", value=ast_value))
|
||||
logs.append((
|
||||
ast_edits.WARNING, node.lineno, node.col_offset,
|
||||
"Summary API writing function %s now requires a 'step' argument; "
|
||||
"inserting default of %s." % (full_name or name, default_value)))
|
||||
return node
|
||||
|
@ -1233,36 +1233,77 @@ def _log_prob(self, x):
|
||||
|
||||
def test_contrib_summary_audio(self):
|
||||
text = "tf.contrib.summary.audio('foo', myval, 44100, 3, 'fam', 42)"
|
||||
expected = ("tf.compat.v2.summary.audio(name='foo', tensor=myval, "
|
||||
expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, "
|
||||
"sample_rate=44100, max_outputs=3, step=42)")
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
self.assertIn("'family' argument", errors[0])
|
||||
self.assertIn("Manual check required", errors[1])
|
||||
|
||||
def test_contrib_summary_histogram(self):
|
||||
text = "tf.contrib.summary.histogram('foo', myval, 'fam', 42)"
|
||||
expected = ("tf.compat.v2.summary.histogram(name='foo', tensor=myval, "
|
||||
expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, "
|
||||
"step=42)")
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
self.assertIn("'family' argument", errors[0])
|
||||
self.assertIn("Manual check required", errors[1])
|
||||
|
||||
def test_contrib_summary_image(self):
|
||||
text = "tf.contrib.summary.image('foo', myval, red, 3, 'fam', 42)"
|
||||
expected = ("tf.compat.v2.summary.image(name='foo', tensor=myval, "
|
||||
expected = ("tf.compat.v2.summary.image(name='foo', data=myval, "
|
||||
"max_outputs=3, step=42)")
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
self.assertIn("'bad_color' argument", errors[0])
|
||||
self.assertIn("'family' argument", errors[1])
|
||||
self.assertIn("Manual check required", errors[2])
|
||||
|
||||
def test_contrib_summary_scalar(self):
|
||||
text = "tf.contrib.summary.scalar('foo', myval, 'fam', 42)"
|
||||
expected = ("tf.compat.v2.summary.scalar(name='foo', tensor=myval, "
|
||||
expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, "
|
||||
"step=42)")
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
self.assertIn("'family' argument", errors[0])
|
||||
self.assertIn("Manual check required", errors[1])
|
||||
|
||||
def test_contrib_summary_audio_nostep(self):
|
||||
text = "tf.contrib.summary.audio('foo', myval, 44100)"
|
||||
expected = ("tf.compat.v2.summary.audio(name='foo', data=myval, "
|
||||
"sample_rate=44100, "
|
||||
"step=tf.compat.v1.train.get_or_create_global_step())")
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
self.assertIn("'step' argument", errors[0])
|
||||
self.assertIn("Manual check required", errors[1])
|
||||
|
||||
def test_contrib_summary_histogram_nostep(self):
|
||||
text = "tf.contrib.summary.histogram('foo', myval)"
|
||||
expected = ("tf.compat.v2.summary.histogram(name='foo', data=myval, "
|
||||
"step=tf.compat.v1.train.get_or_create_global_step())")
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
self.assertIn("'step' argument", errors[0])
|
||||
self.assertIn("Manual check required", errors[1])
|
||||
|
||||
def test_contrib_summary_image_nostep(self):
|
||||
text = "tf.contrib.summary.image('foo', myval)"
|
||||
expected = ("tf.compat.v2.summary.image(name='foo', data=myval, "
|
||||
"step=tf.compat.v1.train.get_or_create_global_step())")
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
self.assertIn("'step' argument", errors[0])
|
||||
self.assertIn("Manual check required", errors[1])
|
||||
|
||||
def test_contrib_summary_scalar_nostep(self):
|
||||
text = "tf.contrib.summary.scalar('foo', myval)"
|
||||
expected = ("tf.compat.v2.summary.scalar(name='foo', data=myval, "
|
||||
"step=tf.compat.v1.train.get_or_create_global_step())")
|
||||
_, _, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected, new_text)
|
||||
self.assertIn("'step' argument", errors[0])
|
||||
self.assertIn("Manual check required", errors[1])
|
||||
|
||||
|
||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||
|
8
third_party/pasta/workspace.bzl
vendored
8
third_party/pasta/workspace.bzl
vendored
@ -6,11 +6,11 @@ def repo():
|
||||
third_party_http_archive(
|
||||
name = "pasta",
|
||||
urls = [
|
||||
"https://mirror.bazel.build/github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
|
||||
"https://github.com/google/pasta/archive/c3d72cdee6fc806251949e912510444d58d7413c.tar.gz",
|
||||
"https://mirror.bazel.build/github.com/google/pasta/archive/v0.1.2.tar.gz",
|
||||
"https://github.com/google/pasta/archive/v0.1.2.tar.gz",
|
||||
],
|
||||
strip_prefix = "pasta-c3d72cdee6fc806251949e912510444d58d7413c",
|
||||
sha256 = "b5905f9cecc4b28363c563f3c4cb0545288bd35f7cc72c55066e97e53befc084",
|
||||
strip_prefix = "pasta-0.1.2",
|
||||
sha256 = "53e4c009a5eac38e942deb48bfc2d3cfca62cd457255fa86ffedb7e40f726a0c",
|
||||
build_file = "//third_party/pasta:BUILD.bazel",
|
||||
system_build_file = "//third_party/pasta:BUILD.system",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user