Fix assert_* conversions, ensure emit warnings and change to compat.v1
PiperOrigin-RevId: 229154096
This commit is contained in:
parent
8f593c48c8
commit
ac2f7471b9
@ -603,6 +603,48 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.image.resize",
|
||||
"tf.random_poisson":
|
||||
"tf.random.poisson",
|
||||
"tf.debugging.assert_greater":
|
||||
"tf.compat.v1.debugging.assert_greater",
|
||||
"tf.debugging.assert_greater_equal":
|
||||
"tf.compat.v1.debugging.assert_greater_equal",
|
||||
"tf.debugging.assert_integer":
|
||||
"tf.compat.v1.debugging.assert_integer",
|
||||
"tf.debugging.assert_less":
|
||||
"tf.compat.v1.debugging.assert_less",
|
||||
"tf.debugging.assert_less_equal":
|
||||
"tf.compat.v1.debugging.assert_less_equal",
|
||||
"tf.debugging.assert_near":
|
||||
"tf.compat.v1.debugging.assert_near",
|
||||
"tf.debugging.assert_negative":
|
||||
"tf.compat.v1.debugging.assert_negative",
|
||||
"tf.debugging.assert_non_negative":
|
||||
"tf.compat.v1.debugging.assert_non_negative",
|
||||
"tf.debugging.assert_non_positive":
|
||||
"tf.compat.v1.debugging.assert_non_positive",
|
||||
"tf.debugging.assert_none_equal":
|
||||
"tf.compat.v1.debugging.assert_none_equal",
|
||||
"tf.debugging.assert_type":
|
||||
"tf.compat.v1.debugging.assert_type",
|
||||
"tf.debugging.assert_positive":
|
||||
"tf.compat.v1.debugging.assert_positive",
|
||||
"tf.debugging.assert_equal":
|
||||
"tf.compat.v1.debugging.assert_equal",
|
||||
"tf.debugging.assert_scalar":
|
||||
"tf.compat.v1.debugging.assert_scalar",
|
||||
"tf.assert_equal":
|
||||
"tf.compat.v1.assert_equal",
|
||||
"tf.assert_less":
|
||||
"tf.compat.v1.assert_less",
|
||||
"tf.assert_greater":
|
||||
"tf.compat.v1.assert_greater",
|
||||
"tf.debugging.assert_rank":
|
||||
"tf.compat.v1.debugging.assert_rank",
|
||||
"tf.debugging.assert_rank_at_least":
|
||||
"tf.compat.v1.debugging.assert_rank_at_least",
|
||||
"tf.debugging.assert_rank_in":
|
||||
"tf.compat.v1.debugging.assert_rank_in",
|
||||
"tf.assert_rank":
|
||||
"tf.compat.v1.assert_rank",
|
||||
}
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
@ -851,10 +893,40 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
assert_return_type_comment,
|
||||
"tf.assert_equal":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_none_equal":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_less":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_negative":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_positive":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_non_negative":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_non_positive":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_near":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_less":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_less_equal":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_greater":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_greater_equal":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_integer":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_type":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_scalar":
|
||||
assert_return_type_comment,
|
||||
"tf.assert_rank":
|
||||
assert_rank_comment,
|
||||
"tf.assert_rank_at_least":
|
||||
assert_rank_comment,
|
||||
"tf.assert_rank_in":
|
||||
assert_rank_comment,
|
||||
"tf.debugging.assert_equal":
|
||||
assert_return_type_comment,
|
||||
"tf.debugging.assert_greater":
|
||||
@ -879,6 +951,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
assert_return_type_comment,
|
||||
"tf.debugging.assert_positive":
|
||||
assert_return_type_comment,
|
||||
"tf.debugging.assert_type":
|
||||
assert_return_type_comment,
|
||||
"tf.debugging.assert_scalar":
|
||||
assert_return_type_comment,
|
||||
"tf.debugging.assert_rank":
|
||||
assert_rank_comment,
|
||||
"tf.debugging.assert_rank_at_least":
|
||||
|
@ -114,12 +114,12 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(report.find("Failed to parse") != -1)
|
||||
|
||||
def testReport(self):
|
||||
text = "tf.assert_near(a)\n"
|
||||
text = "tf.angle(a)\n"
|
||||
_, report, unused_errors, unused_new_text = self._upgrade(text)
|
||||
# This is not a complete test, but it is a sanity test that a report
|
||||
# is generating information.
|
||||
self.assertTrue(report.find("Renamed function `tf.assert_near` to "
|
||||
"`tf.debugging.assert_near`"))
|
||||
self.assertTrue(report.find("Renamed function `tf.angle` to "
|
||||
"`tf.math.angle`"))
|
||||
|
||||
def testRename(self):
|
||||
text = "tf.conj(a)\n"
|
||||
@ -937,6 +937,39 @@ def _log_prob(self, x):
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
|
||||
def testAssertStatements(self):
|
||||
for name in ["assert_greater", "assert_equal", "assert_none_equal",
|
||||
"assert_less", "assert_negative", "assert_positive",
|
||||
"assert_non_negative", "assert_non_positive", "assert_near",
|
||||
"assert_less", "assert_less_equal", "assert_greater",
|
||||
"assert_greater_equal", "assert_integer", "assert_type",
|
||||
"assert_scalar"]:
|
||||
text = "tf.%s(a)" % name
|
||||
expected_text = "tf.compat.v1.%s(a)" % name
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
self.assertIn("assert_* functions", errors[0])
|
||||
|
||||
text = "tf.debugging.%s(a)" % name
|
||||
expected_text = "tf.compat.v1.debugging.%s(a)" % name
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
self.assertIn("assert_* functions", errors[0])
|
||||
|
||||
def testAssertRankStatements(self):
|
||||
for name in ["assert_rank", "assert_rank_at_least", "assert_rank_in"]:
|
||||
text = "tf.%s(a)" % name
|
||||
expected_text = "tf.compat.v1.%s(a)" % name
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
self.assertIn("assert_rank_* functions", errors[0])
|
||||
|
||||
text = "tf.debugging.%s(a)" % name
|
||||
expected_text = "tf.compat.v1.debugging.%s(a)" % name
|
||||
_, unused_report, errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(expected_text, new_text)
|
||||
self.assertIn("assert_rank_* functions", errors[0])
|
||||
|
||||
|
||||
class TestUpgradeFiles(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user