Test that all outputs of the converter are in the v2 API.
Fixes issue where piecewise_constant was a warning only despite not being a part of the v2 API. PiperOrigin-RevId: 223364226
This commit is contained in:
parent
f6e8f7a1fb
commit
bfec759230
@ -78,8 +78,11 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":tf_upgrade_v2",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/tools/common:public_api",
|
||||
"//tensorflow/tools/common:traverse",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
@ -97,6 +97,7 @@ renames = {
|
||||
'tf.check_numerics': 'tf.debugging.check_numerics',
|
||||
'tf.cholesky': 'tf.linalg.cholesky',
|
||||
'tf.cholesky_solve': 'tf.linalg.cholesky_solve',
|
||||
'tf.clip_by_average_norm': 'tf.compat.v1.clip_by_average_norm',
|
||||
'tf.colocate_with': 'tf.compat.v1.colocate_with',
|
||||
'tf.conj': 'tf.math.conj',
|
||||
'tf.container': 'tf.compat.v1.container',
|
||||
@ -107,7 +108,6 @@ renames = {
|
||||
'tf.create_partitioned_variables': 'tf.compat.v1.create_partitioned_variables',
|
||||
'tf.cross': 'tf.linalg.cross',
|
||||
'tf.cumprod': 'tf.math.cumprod',
|
||||
'tf.data.Iterator': 'tf.compat.v1.data.Iterator',
|
||||
'tf.debugging.is_finite': 'tf.math.is_finite',
|
||||
'tf.debugging.is_inf': 'tf.math.is_inf',
|
||||
'tf.debugging.is_nan': 'tf.math.is_nan',
|
||||
@ -358,7 +358,6 @@ renames = {
|
||||
'tf.nn.depthwise_conv2d_native_backprop_filter': 'tf.nn.depthwise_conv2d_backprop_filter',
|
||||
'tf.nn.depthwise_conv2d_native_backprop_input': 'tf.nn.depthwise_conv2d_backprop_input',
|
||||
'tf.nn.dynamic_rnn': 'tf.compat.v1.nn.dynamic_rnn',
|
||||
'tf.nn.fused_batch_norm': 'tf.compat.v1.nn.fused_batch_norm',
|
||||
'tf.nn.log_uniform_candidate_sampler': 'tf.random.log_uniform_candidate_sampler',
|
||||
'tf.nn.quantized_avg_pool': 'tf.compat.v1.nn.quantized_avg_pool',
|
||||
'tf.nn.quantized_conv2d': 'tf.compat.v1.nn.quantized_conv2d',
|
||||
@ -427,13 +426,14 @@ renames = {
|
||||
'tf.rsqrt': 'tf.math.rsqrt',
|
||||
'tf.saved_model.Builder': 'tf.compat.v1.saved_model.Builder',
|
||||
'tf.saved_model.LEGACY_INIT_OP_KEY': 'tf.compat.v1.saved_model.LEGACY_INIT_OP_KEY',
|
||||
'tf.saved_model.MAIN_OP_KEY': 'tf.compat.v1.saved_model.MAIN_OP_KEY',
|
||||
'tf.saved_model.TRAINING': 'tf.saved_model.TRANING',
|
||||
'tf.saved_model.build_tensor_info': 'tf.compat.v1.saved_model.build_tensor_info',
|
||||
'tf.saved_model.builder.SavedModelBuilder': 'tf.compat.v1.saved_model.builder.SavedModelBuilder',
|
||||
'tf.saved_model.constants.ASSETS_DIRECTORY': 'tf.saved_model.ASSETS_DIRECTORY',
|
||||
'tf.saved_model.constants.ASSETS_KEY': 'tf.saved_model.ASSETS_KEY',
|
||||
'tf.saved_model.constants.LEGACY_INIT_OP_KEY': 'tf.compat.v1.saved_model.constants.LEGACY_INIT_OP_KEY',
|
||||
'tf.saved_model.constants.MAIN_OP_KEY': 'tf.saved_model.MAIN_OP_KEY',
|
||||
'tf.saved_model.constants.MAIN_OP_KEY': 'tf.compat.v1.saved_model.constants.MAIN_OP_KEY',
|
||||
'tf.saved_model.constants.SAVED_MODEL_FILENAME_PB': 'tf.saved_model.SAVED_MODEL_FILENAME_PB',
|
||||
'tf.saved_model.constants.SAVED_MODEL_FILENAME_PBTXT': 'tf.saved_model.SAVED_MODEL_FILENAME_PBTXT',
|
||||
'tf.saved_model.constants.SAVED_MODEL_SCHEMA_VERSION': 'tf.saved_model.SAVED_MODEL_SCHEMA_VERSION',
|
||||
|
@ -433,6 +433,10 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.math.confusion_matrix",
|
||||
"tf.decode_csv":
|
||||
"tf.io.decode_csv",
|
||||
"tf.data.Iterator":
|
||||
"tf.compat.v1.data.Iterator",
|
||||
"tf.nn.fused_batch_norm":
|
||||
"tf.compat.v1.nn.fused_batch_norm",
|
||||
}
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
@ -681,7 +685,7 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
" module if you need command line parsing.",
|
||||
"tf.train.exponential_decay":
|
||||
decay_function_comment,
|
||||
"tf.train.piecewise_constant":
|
||||
"tf.train.piecewise_constant_decay":
|
||||
decay_function_comment,
|
||||
"tf.train.polynomial_decay":
|
||||
decay_function_comment,
|
||||
|
@ -17,11 +17,21 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import six
|
||||
import tensorflow as tf
|
||||
# OSS TF V2 import placeholder.
|
||||
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test as test_lib
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_export
|
||||
from tensorflow.tools.common import public_api
|
||||
from tensorflow.tools.common import traverse
|
||||
from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import tf_upgrade_v2
|
||||
|
||||
@ -64,6 +74,51 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log_sigmoid(3.8))\n")
|
||||
|
||||
def testAllAPI(self):
|
||||
if not hasattr(tf.compat, "v2"):
|
||||
return
|
||||
|
||||
v2_symbols = set([])
|
||||
attr_v2 = tf_export.API_ATTRS[
|
||||
tf_export.TENSORFLOW_API_NAME].names
|
||||
|
||||
def symbol_collector(unused_path, unused_parent, children):
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
if not hasattr(attr, "__dict__"):
|
||||
continue
|
||||
api_names_v2 = attr.__dict__.get(attr_v2, [])
|
||||
for name in api_names_v2:
|
||||
v2_symbols.add("tf." + name)
|
||||
|
||||
visitor = public_api.PublicAPIVisitor(symbol_collector)
|
||||
traverse.traverse(tf.compat.v2, visitor)
|
||||
|
||||
attr_v1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names)
|
||||
|
||||
# Converts all symbols in the v1 namespace to the v2 namespace, raising
|
||||
# an error if the target of the conversion is not in the v2 namespace.
|
||||
def conversion_visitor(unused_path, unused_parent, children):
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
if not hasattr(attr, "__dict__"):
|
||||
continue
|
||||
api_names = attr.__dict__.get(attr_v1, [])
|
||||
for name in api_names:
|
||||
_, _, _, text = self._upgrade("tf." + name)
|
||||
if (text and
|
||||
not text.startswith("tf.compat.v1") and
|
||||
text not in v2_symbols):
|
||||
self.assertFalse(
|
||||
True, "Symbol %s generated from %s not in v2 API" % (
|
||||
text, name))
|
||||
|
||||
visitor = public_api.PublicAPIVisitor(conversion_visitor)
|
||||
visitor.do_not_descend_map["tf"].append("contrib")
|
||||
visitor.private_map["tf.compat"] = ["v1", "v2"]
|
||||
traverse.traverse(tf.compat.v1, visitor)
|
||||
|
||||
def testRenameConstant(self):
|
||||
text = "tf.MONOLITHIC_BUILD\n"
|
||||
_, unused_report, unused_errors, new_text = self._upgrade(text)
|
||||
@ -89,7 +144,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
"tf.boolean_mask(tensor=a, mask=b, name=c, axis=d)\n")
|
||||
|
||||
def testLearningRateDecay(self):
|
||||
for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
|
||||
for decay in ["tf.train.exponential_decay",
|
||||
"tf.train.polynomial_decay", "tf.train.natural_exp_decay",
|
||||
"tf.train.inverse_time_decay", "tf.train.cosine_decay",
|
||||
"tf.train.cosine_decay_restarts",
|
||||
@ -101,6 +156,14 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay])
|
||||
self.assertIn("%s has been changed" % decay, report)
|
||||
|
||||
def testPiecewiseDecay(self):
|
||||
text = "tf.train.piecewise_constant_decay(a, b)\n"
|
||||
_, report, errors, _ = self._upgrade(text)
|
||||
self.assertEqual(
|
||||
errors,
|
||||
["test.py:1: tf.train.piecewise_constant_decay requires manual check."])
|
||||
self.assertIn("tf.train.piecewise_constant_decay has been changed", report)
|
||||
|
||||
def testEstimatorLossReductionChange(self):
|
||||
classes = [
|
||||
"LinearClassifier", "LinearRegressor", "DNNLinearCombinedClassifier",
|
||||
|
Loading…
Reference in New Issue
Block a user