Enable all tests in api_compatibility_test.py in both v1 and v2.
Preferably import symbols from the files where they were defined when generating imports in create_python_api.py. PiperOrigin-RevId: 254610838
This commit is contained in:
parent
eeda570f69
commit
a45a9d2ac8
@ -94,10 +94,13 @@ def get_canonical_import(import_set):
|
||||
|
||||
One symbol might come from multiple places as it is being imported and
|
||||
reexported. To simplify API changes, we always use the same import for the
|
||||
same module, and give preference to imports coming from main tensorflow code.
|
||||
same module, and give preference based on higher priority and alphabetical
|
||||
ordering.
|
||||
|
||||
Args:
|
||||
import_set: (set) Imports providing the same symbol
|
||||
import_set: (set) Imports providing the same symbol. This is a set of
|
||||
tuples in the form (import, priority). We want to pick an import
|
||||
with highest priority.
|
||||
|
||||
Returns:
|
||||
A module name to import
|
||||
@ -105,9 +108,12 @@ def get_canonical_import(import_set):
|
||||
# We use the fact that list sorting is stable, so first we convert the set to
|
||||
# a sorted list of the names and then we resort this list to move elements
|
||||
# not in core tensorflow to the end.
|
||||
import_list = sorted(import_set)
|
||||
import_list.sort(key=lambda x: 'lite' in x)
|
||||
return import_list[0]
|
||||
# Here we sort by priority (higher preferred) and then alphabetically by
|
||||
# import string.
|
||||
import_list = sorted(
|
||||
import_set,
|
||||
key=lambda imp_and_priority: (-imp_and_priority[1], imp_and_priority[0]))
|
||||
return import_list[0][0]
|
||||
|
||||
|
||||
class _ModuleInitCodeBuilder(object):
|
||||
@ -115,10 +121,12 @@ class _ModuleInitCodeBuilder(object):
|
||||
|
||||
def __init__(self, output_package, api_version):
|
||||
self._output_package = output_package
|
||||
# Maps API module to API symbol name to set of tuples of the form
|
||||
# (module name, priority).
|
||||
# The same symbol can be imported from multiple locations. Higher
|
||||
# "priority" indicates that import location is preferred over others.
|
||||
self._module_imports = collections.defaultdict(
|
||||
lambda: collections.defaultdict(set))
|
||||
self._deprecated_module_imports = collections.defaultdict(
|
||||
lambda: collections.defaultdict(set))
|
||||
self._dest_import_to_id = collections.defaultdict(int)
|
||||
# Names that start with underscore in the root module.
|
||||
self._underscore_names_in_root = []
|
||||
@ -134,15 +142,15 @@ class _ModuleInitCodeBuilder(object):
|
||||
self._dest_import_to_id[api_name] = symbol_id
|
||||
|
||||
def add_import(
|
||||
self, symbol_id, dest_module_name, source_module_name, source_name,
|
||||
self, symbol, source_module_name, source_name, dest_module_name,
|
||||
dest_name):
|
||||
"""Adds this import to module_imports.
|
||||
|
||||
Args:
|
||||
symbol_id: (number) Unique identifier of the symbol to import.
|
||||
dest_module_name: (string) Module name to add import to.
|
||||
symbol: TensorFlow Python symbol.
|
||||
source_module_name: (string) Module to import from.
|
||||
source_name: (string) Name of the symbol to import.
|
||||
dest_module_name: (string) Module name to add import to.
|
||||
dest_name: (string) Import the symbol using this name.
|
||||
|
||||
Raises:
|
||||
@ -155,6 +163,7 @@ class _ModuleInitCodeBuilder(object):
|
||||
full_api_name = dest_name
|
||||
if dest_module_name:
|
||||
full_api_name = dest_module_name + '.' + full_api_name
|
||||
symbol_id = -1 if not symbol else id(symbol)
|
||||
self._check_already_imported(symbol_id, full_api_name)
|
||||
|
||||
if not dest_module_name and dest_name.startswith('_'):
|
||||
@ -163,7 +172,13 @@ class _ModuleInitCodeBuilder(object):
|
||||
# The same symbol can be available in multiple modules.
|
||||
# We store all possible ways of importing this symbol and later pick just
|
||||
# one.
|
||||
self._module_imports[dest_module_name][full_api_name].add(import_str)
|
||||
priority = 0
|
||||
if symbol and hasattr(symbol, '__module__'):
|
||||
# Give higher priority to source module if it matches
|
||||
# symbol's original module.
|
||||
priority = int(source_module_name == symbol.__module__)
|
||||
self._module_imports[dest_module_name][full_api_name].add(
|
||||
(import_str, priority))
|
||||
|
||||
def _import_submodules(self):
|
||||
"""Add imports for all destination modules in self._module_imports."""
|
||||
@ -171,8 +186,6 @@ class _ModuleInitCodeBuilder(object):
|
||||
# For e.g. if we import 'foo.bar.Value'. Then, we also
|
||||
# import 'bar' in 'foo'.
|
||||
imported_modules = set(self._module_imports.keys())
|
||||
imported_modules = imported_modules.union(
|
||||
set(self._deprecated_module_imports.keys()))
|
||||
for module in imported_modules:
|
||||
if not module:
|
||||
continue
|
||||
@ -187,8 +200,8 @@ class _ModuleInitCodeBuilder(object):
|
||||
if submodule_index > 0:
|
||||
import_from += '.' + '.'.join(module_split[:submodule_index])
|
||||
self.add_import(
|
||||
-1, parent_module, import_from,
|
||||
module_split[submodule_index], module_split[submodule_index])
|
||||
None, import_from, module_split[submodule_index],
|
||||
parent_module, module_split[submodule_index])
|
||||
|
||||
def build(self):
|
||||
"""Get a map from destination module to __init__.py code for that module.
|
||||
@ -296,7 +309,7 @@ def add_imports_for_symbol(
|
||||
dest_module, dest_name = _get_name_and_module(export)
|
||||
dest_module = _join_modules(output_module_prefix, dest_module)
|
||||
module_code_builder.add_import(
|
||||
-1, dest_module, source_module_name, name, dest_name)
|
||||
None, source_module_name, name, dest_module, dest_name)
|
||||
|
||||
# If symbol has _tf_api_names attribute, then add import for it.
|
||||
if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):
|
||||
@ -306,7 +319,7 @@ def add_imports_for_symbol(
|
||||
dest_module, dest_name = _get_name_and_module(export)
|
||||
dest_module = _join_modules(output_module_prefix, dest_module)
|
||||
module_code_builder.add_import(
|
||||
id(symbol), dest_module, source_module_name, source_name, dest_name)
|
||||
symbol, source_module_name, source_name, dest_module, dest_name)
|
||||
|
||||
|
||||
def get_api_init_text(packages,
|
||||
|
@ -38,7 +38,6 @@ import tensorflow as tf
|
||||
from google.protobuf import message
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
@ -355,7 +354,6 @@ class ApiCompatibilityTest(test.TestCase):
|
||||
update_goldens=FLAGS.update_goldens,
|
||||
api_version=api_version)
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
def testAPIBackwardsCompatibility(self):
|
||||
api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1
|
||||
golden_file_pattern = os.path.join(
|
||||
@ -378,10 +376,12 @@ class ApiCompatibilityTest(test.TestCase):
|
||||
|
||||
# Also check that V1 API has contrib
|
||||
self.assertTrue(
|
||||
api_version == 2 or
|
||||
'tensorflow.python.util.lazy_loader.LazyLoader'
|
||||
in str(type(tf.contrib)))
|
||||
# Check that V2 API does not have contrib
|
||||
self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib'))
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
def testAPIBackwardsCompatibilityV1(self):
|
||||
api_version = 1
|
||||
golden_file_pattern = os.path.join(
|
||||
|
Loading…
x
Reference in New Issue
Block a user