Enable all tests in api_compatibility_test.py in both v1 and v2.

Preferably import symbols from the files where they were defined when generating imports in create_python_api.py.

PiperOrigin-RevId: 254610838
This commit is contained in:
Anna R 2019-06-23 00:30:24 -07:00 committed by TensorFlower Gardener
parent eeda570f69
commit a45a9d2ac8
2 changed files with 33 additions and 20 deletions

View File

@ -94,10 +94,13 @@ def get_canonical_import(import_set):
One symbol might come from multiple places as it is being imported and
reexported. To simplify API changes, we always use the same import for the
same module, and give preference to imports coming from main tensorflow code.
same module, and give preference based on higher priority and alphabetical
ordering.
Args:
import_set: (set) Imports providing the same symbol
import_set: (set) Imports providing the same symbol. This is a set of
tuples in the form (import, priority). We want to pick an import
with highest priority.
Returns:
A module name to import
@ -105,9 +108,12 @@ def get_canonical_import(import_set):
# We use the fact that list sorting is stable, so first we convert the set to
# a sorted list of the names and then we resort this list to move elements
# not in core tensorflow to the end.
import_list = sorted(import_set)
import_list.sort(key=lambda x: 'lite' in x)
return import_list[0]
# Here we sort by priority (higher preferred) and then alphabetically by
# import string.
import_list = sorted(
import_set,
key=lambda imp_and_priority: (-imp_and_priority[1], imp_and_priority[0]))
return import_list[0][0]
class _ModuleInitCodeBuilder(object):
@ -115,10 +121,12 @@ class _ModuleInitCodeBuilder(object):
def __init__(self, output_package, api_version):
self._output_package = output_package
# Maps API module to API symbol name to set of tuples of the form
# (module name, priority).
# The same symbol can be imported from multiple locations. Higher
# "priority" indicates that import location is preferred over others.
self._module_imports = collections.defaultdict(
lambda: collections.defaultdict(set))
self._deprecated_module_imports = collections.defaultdict(
lambda: collections.defaultdict(set))
self._dest_import_to_id = collections.defaultdict(int)
# Names that start with underscore in the root module.
self._underscore_names_in_root = []
@ -134,15 +142,15 @@ class _ModuleInitCodeBuilder(object):
self._dest_import_to_id[api_name] = symbol_id
def add_import(
self, symbol_id, dest_module_name, source_module_name, source_name,
self, symbol, source_module_name, source_name, dest_module_name,
dest_name):
"""Adds this import to module_imports.
Args:
symbol_id: (number) Unique identifier of the symbol to import.
dest_module_name: (string) Module name to add import to.
symbol: TensorFlow Python symbol.
source_module_name: (string) Module to import from.
source_name: (string) Name of the symbol to import.
dest_module_name: (string) Module name to add import to.
dest_name: (string) Import the symbol using this name.
Raises:
@ -155,6 +163,7 @@ class _ModuleInitCodeBuilder(object):
full_api_name = dest_name
if dest_module_name:
full_api_name = dest_module_name + '.' + full_api_name
symbol_id = -1 if not symbol else id(symbol)
self._check_already_imported(symbol_id, full_api_name)
if not dest_module_name and dest_name.startswith('_'):
@ -163,7 +172,13 @@ class _ModuleInitCodeBuilder(object):
# The same symbol can be available in multiple modules.
# We store all possible ways of importing this symbol and later pick just
# one.
self._module_imports[dest_module_name][full_api_name].add(import_str)
priority = 0
if symbol and hasattr(symbol, '__module__'):
# Give higher priority to source module if it matches
# symbol's original module.
priority = int(source_module_name == symbol.__module__)
self._module_imports[dest_module_name][full_api_name].add(
(import_str, priority))
def _import_submodules(self):
"""Add imports for all destination modules in self._module_imports."""
@ -171,8 +186,6 @@ class _ModuleInitCodeBuilder(object):
# For e.g. if we import 'foo.bar.Value'. Then, we also
# import 'bar' in 'foo'.
imported_modules = set(self._module_imports.keys())
imported_modules = imported_modules.union(
set(self._deprecated_module_imports.keys()))
for module in imported_modules:
if not module:
continue
@ -187,8 +200,8 @@ class _ModuleInitCodeBuilder(object):
if submodule_index > 0:
import_from += '.' + '.'.join(module_split[:submodule_index])
self.add_import(
-1, parent_module, import_from,
module_split[submodule_index], module_split[submodule_index])
None, import_from, module_split[submodule_index],
parent_module, module_split[submodule_index])
def build(self):
"""Get a map from destination module to __init__.py code for that module.
@ -296,7 +309,7 @@ def add_imports_for_symbol(
dest_module, dest_name = _get_name_and_module(export)
dest_module = _join_modules(output_module_prefix, dest_module)
module_code_builder.add_import(
-1, dest_module, source_module_name, name, dest_name)
None, source_module_name, name, dest_module, dest_name)
# If symbol has _tf_api_names attribute, then add import for it.
if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):
@ -306,7 +319,7 @@ def add_imports_for_symbol(
dest_module, dest_name = _get_name_and_module(export)
dest_module = _join_modules(output_module_prefix, dest_module)
module_code_builder.add_import(
id(symbol), dest_module, source_module_name, source_name, dest_name)
symbol, source_module_name, source_name, dest_module, dest_name)
def get_api_init_text(packages,

View File

@ -38,7 +38,6 @@ import tensorflow as tf
from google.protobuf import message
from google.protobuf import text_format
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
@ -355,7 +354,6 @@ class ApiCompatibilityTest(test.TestCase):
update_goldens=FLAGS.update_goldens,
api_version=api_version)
@test_util.run_v1_only('b/120545219')
def testAPIBackwardsCompatibility(self):
api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1
golden_file_pattern = os.path.join(
@ -378,10 +376,12 @@ class ApiCompatibilityTest(test.TestCase):
# Also check that V1 API has contrib
self.assertTrue(
api_version == 2 or
'tensorflow.python.util.lazy_loader.LazyLoader'
in str(type(tf.contrib)))
# Check that V2 API does not have contrib
self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib'))
@test_util.run_v1_only('b/120545219')
def testAPIBackwardsCompatibilityV1(self):
api_version = 1
golden_file_pattern = os.path.join(