Use "in symbol.__dict__" instead of "hasattr" to check if a symbol has api
names set. The former would behave correctly for subclasses. Also, moving get_v1_names|constants and get_v2_names|constants functions to tf_export.py to reduce code duplication. PiperOrigin-RevId: 225063242
This commit is contained in:
parent
f9dbe98610
commit
93439a5539
@ -147,6 +147,94 @@ def get_canonical_name(api_names, deprecated_api_names):
|
||||
return None
|
||||
|
||||
|
||||
def get_v1_names(symbol):
|
||||
"""Get a list of TF 1.* names for this symbol.
|
||||
|
||||
Args:
|
||||
symbol: symbol to get API names for.
|
||||
|
||||
Returns:
|
||||
List of all API names for this symbol including TensorFlow and
|
||||
Estimator names.
|
||||
"""
|
||||
names_v1 = []
|
||||
tensorflow_api_attr_v1 = API_ATTRS_V1[TENSORFLOW_API_NAME].names
|
||||
estimator_api_attr_v1 = API_ATTRS_V1[ESTIMATOR_API_NAME].names
|
||||
|
||||
if not hasattr(symbol, tensorflow_api_attr_v1):
|
||||
return names_v1
|
||||
if tensorflow_api_attr_v1 in symbol.__dict__:
|
||||
names_v1.extend(getattr(symbol, tensorflow_api_attr_v1))
|
||||
if estimator_api_attr_v1 in symbol.__dict__:
|
||||
names_v1.extend(getattr(symbol, estimator_api_attr_v1))
|
||||
return names_v1
|
||||
|
||||
|
||||
def get_v2_names(symbol):
|
||||
"""Get a list of TF 2.0 names for this symbol.
|
||||
|
||||
Args:
|
||||
symbol: symbol to get API names for.
|
||||
|
||||
Returns:
|
||||
List of all API names for this symbol including TensorFlow and
|
||||
Estimator names.
|
||||
"""
|
||||
names_v2 = []
|
||||
tensorflow_api_attr = API_ATTRS[TENSORFLOW_API_NAME].names
|
||||
estimator_api_attr = API_ATTRS[ESTIMATOR_API_NAME].names
|
||||
|
||||
if not hasattr(symbol, tensorflow_api_attr):
|
||||
return names_v2
|
||||
if tensorflow_api_attr in symbol.__dict__:
|
||||
names_v2.extend(getattr(symbol, tensorflow_api_attr))
|
||||
if estimator_api_attr in symbol.__dict__:
|
||||
names_v2.extend(getattr(symbol, estimator_api_attr))
|
||||
return names_v2
|
||||
|
||||
|
||||
def get_v1_constants(module):
|
||||
"""Get a list of TF 1.* constants in this module.
|
||||
|
||||
Args:
|
||||
module: TensorFlow module.
|
||||
|
||||
Returns:
|
||||
List of all API constants under the given module including TensorFlow and
|
||||
Estimator constants.
|
||||
"""
|
||||
constants_v1 = []
|
||||
tensorflow_constants_attr_v1 = API_ATTRS_V1[TENSORFLOW_API_NAME].constants
|
||||
estimator_constants_attr_v1 = API_ATTRS_V1[ESTIMATOR_API_NAME].constants
|
||||
|
||||
if hasattr(module, tensorflow_constants_attr_v1):
|
||||
constants_v1.extend(getattr(module, tensorflow_constants_attr_v1))
|
||||
if hasattr(module, estimator_constants_attr_v1):
|
||||
constants_v1.extend(getattr(module, estimator_constants_attr_v1))
|
||||
return constants_v1
|
||||
|
||||
|
||||
def get_v2_constants(module):
|
||||
"""Get a list of TF 2.0 constants in this module.
|
||||
|
||||
Args:
|
||||
module: TensorFlow module.
|
||||
|
||||
Returns:
|
||||
List of all API constants under the given module including TensorFlow and
|
||||
Estimator constants.
|
||||
"""
|
||||
constants_v2 = []
|
||||
tensorflow_constants_attr = API_ATTRS[TENSORFLOW_API_NAME].constants
|
||||
estimator_constants_attr = API_ATTRS[ESTIMATOR_API_NAME].constants
|
||||
|
||||
if hasattr(module, tensorflow_constants_attr):
|
||||
constants_v2.extend(getattr(module, tensorflow_constants_attr))
|
||||
if hasattr(module, estimator_constants_attr):
|
||||
constants_v2.extend(getattr(module, estimator_constants_attr))
|
||||
return constants_v2
|
||||
|
||||
|
||||
class api_export(object): # pylint: disable=invalid-name
|
||||
"""Provides ways to export symbols to the TensorFlow API."""
|
||||
|
||||
|
@ -62,6 +62,10 @@ class ValidateExportTest(test.TestCase):
|
||||
del symbol._tf_api_names
|
||||
if hasattr(symbol, '_tf_api_names_v1'):
|
||||
del symbol._tf_api_names_v1
|
||||
if hasattr(symbol, '_estimator_api_names'):
|
||||
del symbol._estimator_api_names
|
||||
if hasattr(symbol, '_estimator_api_names_v1'):
|
||||
del symbol._estimator_api_names_v1
|
||||
|
||||
def _CreateMockModule(self, name):
|
||||
mock_module = self.MockModule(name)
|
||||
@ -74,6 +78,10 @@ class ValidateExportTest(test.TestCase):
|
||||
decorated_function = export_decorator(_test_function)
|
||||
self.assertEquals(decorated_function, _test_function)
|
||||
self.assertEquals(('nameA', 'nameB'), decorated_function._tf_api_names)
|
||||
self.assertEquals(['nameA', 'nameB'],
|
||||
tf_export.get_v1_names(decorated_function))
|
||||
self.assertEquals(['nameA', 'nameB'],
|
||||
tf_export.get_v2_names(decorated_function))
|
||||
|
||||
def testExportMultipleFunctions(self):
|
||||
export_decorator1 = tf_export.tf_export('nameA', 'nameB')
|
||||
@ -95,6 +103,22 @@ class ValidateExportTest(test.TestCase):
|
||||
export_decorator_b(TestClassB)
|
||||
self.assertEquals(('TestClassA1',), TestClassA._tf_api_names)
|
||||
self.assertEquals(('TestClassB1',), TestClassB._tf_api_names)
|
||||
self.assertEquals(['TestClassA1'], tf_export.get_v1_names(TestClassA))
|
||||
self.assertEquals(['TestClassB1'], tf_export.get_v1_names(TestClassB))
|
||||
|
||||
def testExportClassInEstimator(self):
|
||||
export_decorator_a = tf_export.tf_export('TestClassA1')
|
||||
export_decorator_a(TestClassA)
|
||||
self.assertEquals(('TestClassA1',), TestClassA._tf_api_names)
|
||||
|
||||
export_decorator_b = tf_export.estimator_export(
|
||||
'estimator.TestClassB1')
|
||||
export_decorator_b(TestClassB)
|
||||
self.assertTrue('_tf_api_names' not in TestClassB.__dict__)
|
||||
self.assertEquals(('TestClassA1',), TestClassA._tf_api_names)
|
||||
self.assertEquals(['TestClassA1'], tf_export.get_v1_names(TestClassA))
|
||||
self.assertEquals(['estimator.TestClassB1'],
|
||||
tf_export.get_v1_names(TestClassB))
|
||||
|
||||
def testExportSingleConstant(self):
|
||||
module1 = self._CreateMockModule('module1')
|
||||
@ -103,6 +127,10 @@ class ValidateExportTest(test.TestCase):
|
||||
export_decorator.export_constant('module1', 'test_constant')
|
||||
self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')],
|
||||
module1._tf_api_constants)
|
||||
self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')],
|
||||
tf_export.get_v1_constants(module1))
|
||||
self.assertEquals([(('NAME_A', 'NAME_B'), 'test_constant')],
|
||||
tf_export.get_v2_constants(module1))
|
||||
|
||||
def testExportMultipleConstants(self):
|
||||
module1 = self._CreateMockModule('module1')
|
||||
|
@ -37,32 +37,6 @@ from tensorflow.tools.compatibility import ast_edits
|
||||
from tensorflow.tools.compatibility import tf_upgrade_v2
|
||||
|
||||
|
||||
_TENSORFLOW_API_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names)
|
||||
_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names
|
||||
_ESTIMATOR_API_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names)
|
||||
_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names
|
||||
|
||||
|
||||
def get_v1_names(symbol):
|
||||
names_v1 = []
|
||||
if hasattr(symbol, _TENSORFLOW_API_ATTR_V1):
|
||||
names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1))
|
||||
if hasattr(symbol, _ESTIMATOR_API_ATTR_V1):
|
||||
names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1))
|
||||
return names_v1
|
||||
|
||||
|
||||
def get_v2_names(symbol):
|
||||
names_v2 = set()
|
||||
if hasattr(symbol, _TENSORFLOW_API_ATTR):
|
||||
names_v2.update(getattr(symbol, _TENSORFLOW_API_ATTR))
|
||||
if hasattr(symbol, _ESTIMATOR_API_ATTR):
|
||||
names_v2.update(getattr(symbol, _ESTIMATOR_API_ATTR))
|
||||
return list(names_v2)
|
||||
|
||||
|
||||
def get_symbol_for_name(root, name):
|
||||
name_parts = name.split(".")
|
||||
symbol = root
|
||||
@ -118,7 +92,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
def symbol_collector(unused_path, unused_parent, children):
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
api_names_v2 = get_v2_names(attr)
|
||||
api_names_v2 = tf_export.get_v2_names(attr)
|
||||
for name in api_names_v2:
|
||||
cls.v2_symbols["tf." + name] = attr
|
||||
|
||||
@ -166,7 +140,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
def conversion_visitor(unused_path, unused_parent, children):
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
api_names = get_v1_names(attr)
|
||||
api_names = tf_export.get_v1_names(attr)
|
||||
for name in api_names:
|
||||
_, _, _, text = self._upgrade("tf." + name)
|
||||
if (text and
|
||||
@ -190,7 +164,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
def conversion_visitor(unused_path, unused_parent, children):
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
api_names = get_v1_names(attr)
|
||||
api_names = tf_export.get_v1_names(attr)
|
||||
for name in api_names:
|
||||
if collect:
|
||||
v1_symbols.add("tf." + name)
|
||||
@ -219,7 +193,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
def arg_test_visitor(unused_path, unused_parent, children):
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
names_v1 = get_v1_names(attr)
|
||||
names_v1 = tf_export.get_v1_names(attr)
|
||||
|
||||
for name in names_v1:
|
||||
name = "tf.%s" % name
|
||||
@ -270,7 +244,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
if not tf_inspect.isfunction(attr):
|
||||
continue
|
||||
names_v1 = get_v1_names(attr)
|
||||
names_v1 = tf_export.get_v1_names(attr)
|
||||
arg_names_v1 = get_args(attr)
|
||||
|
||||
for name in names_v1:
|
||||
@ -340,7 +314,7 @@ bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
|
||||
# get other names for this function
|
||||
attr = get_symbol_for_name(tf.compat.v1, name)
|
||||
_, attr = tf_decorator.unwrap(attr)
|
||||
v1_names = get_v1_names(attr)
|
||||
v1_names = tf_export.get_v1_names(attr)
|
||||
self.assertTrue(v1_names)
|
||||
v1_names = ["tf.%s" % n for n in v1_names]
|
||||
# check if any other name is in
|
||||
|
@ -64,58 +64,6 @@ from __future__ import print_function
|
||||
|
||||
"""
|
||||
|
||||
_TENSORFLOW_API_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names)
|
||||
_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names
|
||||
_TENSORFLOW_CONSTANTS_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].constants)
|
||||
_TENSORFLOW_CONSTANTS_ATTR = (
|
||||
tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].constants)
|
||||
|
||||
_ESTIMATOR_API_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names)
|
||||
_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names
|
||||
_ESTIMATOR_CONSTANTS_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].constants)
|
||||
_ESTIMATOR_CONSTANTS_ATTR = (
|
||||
tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].constants)
|
||||
|
||||
|
||||
def get_v1_names(symbol):
|
||||
names_v1 = []
|
||||
if hasattr(symbol, _TENSORFLOW_API_ATTR_V1):
|
||||
names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1))
|
||||
if hasattr(symbol, _ESTIMATOR_API_ATTR_V1):
|
||||
names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1))
|
||||
return names_v1
|
||||
|
||||
|
||||
def get_v2_names(symbol):
|
||||
names_v2 = []
|
||||
if hasattr(symbol, _TENSORFLOW_API_ATTR):
|
||||
names_v2.extend(getattr(symbol, _TENSORFLOW_API_ATTR))
|
||||
if hasattr(symbol, _ESTIMATOR_API_ATTR):
|
||||
names_v2.extend(getattr(symbol, _ESTIMATOR_API_ATTR))
|
||||
return list(names_v2)
|
||||
|
||||
|
||||
def get_v1_constants(module):
|
||||
constants_v1 = []
|
||||
if hasattr(module, _TENSORFLOW_CONSTANTS_ATTR_V1):
|
||||
constants_v1.extend(getattr(module, _TENSORFLOW_CONSTANTS_ATTR_V1))
|
||||
if hasattr(module, _ESTIMATOR_CONSTANTS_ATTR_V1):
|
||||
constants_v1.extend(getattr(module, _ESTIMATOR_CONSTANTS_ATTR_V1))
|
||||
return constants_v1
|
||||
|
||||
|
||||
def get_v2_constants(module):
|
||||
constants_v2 = []
|
||||
if hasattr(module, _TENSORFLOW_CONSTANTS_ATTR):
|
||||
constants_v2.extend(getattr(module, _TENSORFLOW_CONSTANTS_ATTR))
|
||||
if hasattr(module, _ESTIMATOR_CONSTANTS_ATTR):
|
||||
constants_v2.extend(getattr(module, _ESTIMATOR_CONSTANTS_ATTR))
|
||||
return constants_v2
|
||||
|
||||
|
||||
def get_canonical_name(v2_names, v1_name):
|
||||
if v2_names:
|
||||
@ -131,7 +79,7 @@ def get_all_v2_names():
|
||||
"""Visitor that collects TF 2.0 names."""
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
api_names_v2 = get_v2_names(attr)
|
||||
api_names_v2 = tf_export.get_v2_names(attr)
|
||||
for name in api_names_v2:
|
||||
v2_names.add(name)
|
||||
|
||||
@ -149,8 +97,8 @@ def collect_constant_renames():
|
||||
"""
|
||||
renames = set()
|
||||
for module in sys.modules.values():
|
||||
constants_v1_list = get_v1_constants(module)
|
||||
constants_v2_list = get_v2_constants(module)
|
||||
constants_v1_list = tf_export.get_v1_constants(module)
|
||||
constants_v2_list = tf_export.get_v2_constants(module)
|
||||
|
||||
# _tf_api_constants attribute contains a list of tuples:
|
||||
# (api_names_list, constant_name)
|
||||
@ -186,8 +134,8 @@ def collect_function_renames():
|
||||
"""Visitor that collects rename strings to add to rename_line_set."""
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
api_names_v1 = get_v1_names(attr)
|
||||
api_names_v2 = get_v2_names(attr)
|
||||
api_names_v1 = tf_export.get_v1_names(attr)
|
||||
api_names_v2 = tf_export.get_v2_names(attr)
|
||||
deprecated_api_names = set(api_names_v1) - set(api_names_v2)
|
||||
for name in deprecated_api_names:
|
||||
renames.add((name, get_canonical_name(api_names_v2, name)))
|
||||
|
@ -64,40 +64,6 @@ from __future__ import print_function
|
||||
|
||||
"""
|
||||
|
||||
_TENSORFLOW_API_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names)
|
||||
_TENSORFLOW_API_ATTR = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names
|
||||
_TENSORFLOW_CONSTANTS_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].constants)
|
||||
_TENSORFLOW_CONSTANTS_ATTR = (
|
||||
tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].constants)
|
||||
|
||||
_ESTIMATOR_API_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].names)
|
||||
_ESTIMATOR_API_ATTR = tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].names
|
||||
_ESTIMATOR_CONSTANTS_ATTR_V1 = (
|
||||
tf_export.API_ATTRS_V1[tf_export.ESTIMATOR_API_NAME].constants)
|
||||
_ESTIMATOR_CONSTANTS_ATTR = (
|
||||
tf_export.API_ATTRS[tf_export.ESTIMATOR_API_NAME].constants)
|
||||
|
||||
|
||||
def get_v1_names(symbol):
|
||||
names_v1 = []
|
||||
if hasattr(symbol, _TENSORFLOW_API_ATTR_V1):
|
||||
names_v1.extend(getattr(symbol, _TENSORFLOW_API_ATTR_V1))
|
||||
if hasattr(symbol, _ESTIMATOR_API_ATTR_V1):
|
||||
names_v1.extend(getattr(symbol, _ESTIMATOR_API_ATTR_V1))
|
||||
return names_v1
|
||||
|
||||
|
||||
def get_v2_names(symbol):
|
||||
names_v2 = []
|
||||
if hasattr(symbol, _TENSORFLOW_API_ATTR):
|
||||
names_v2.extend(getattr(symbol, _TENSORFLOW_API_ATTR))
|
||||
if hasattr(symbol, _ESTIMATOR_API_ATTR):
|
||||
names_v2.extend(getattr(symbol, _ESTIMATOR_API_ATTR))
|
||||
return list(names_v2)
|
||||
|
||||
|
||||
def collect_function_arg_names(function_names):
|
||||
"""Determines argument names for reordered function signatures.
|
||||
@ -115,7 +81,7 @@ def collect_function_arg_names(function_names):
|
||||
"""Visitor that collects arguments for reordered functions."""
|
||||
for child in children:
|
||||
_, attr = tf_decorator.unwrap(child[1])
|
||||
api_names_v1 = get_v1_names(attr)
|
||||
api_names_v1 = tf_export.get_v1_names(attr)
|
||||
api_names_v1 = ['tf.%s' % name for name in api_names_v1]
|
||||
matches_function_names = any(
|
||||
name in function_names for name in api_names_v1)
|
||||
|
Loading…
Reference in New Issue
Block a user