From a45a9d2ac8862b8f7162f6189ab567ed194c02b7 Mon Sep 17 00:00:00 2001 From: Anna R Date: Sun, 23 Jun 2019 00:30:24 -0700 Subject: [PATCH] Enable all tests in api_compatibility_test.py in both v1 and v2. Preferably import symbols from the files where they were defined when generating imports in create_python_api.py. PiperOrigin-RevId: 254610838 --- .../tools/api/generator/create_python_api.py | 47 ++++++++++++------- .../tools/api/tests/api_compatibility_test.py | 6 +-- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py index 1bae5ef406a..7dd3f97b79d 100644 --- a/tensorflow/python/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -94,10 +94,13 @@ def get_canonical_import(import_set): One symbol might come from multiple places as it is being imported and reexported. To simplify API changes, we always use the same import for the - same module, and give preference to imports coming from main tensorflow code. + same module, and give preference based on higher priority and alphabetical + ordering. Args: - import_set: (set) Imports providing the same symbol + import_set: (set) Imports providing the same symbol. This is a set of + tuples in the form (import, priority). We want to pick an import + with highest priority. Returns: A module name to import @@ -105,9 +108,12 @@ def get_canonical_import(import_set): # We use the fact that list sorting is stable, so first we convert the set to # a sorted list of the names and then we resort this list to move elements # not in core tensorflow to the end. - import_list = sorted(import_set) - import_list.sort(key=lambda x: 'lite' in x) - return import_list[0] + # Here we sort by priority (higher preferred) and then alphabetically by + # import string. + import_list = sorted( + import_set, + key=lambda imp_and_priority: (-imp_and_priority[1], imp_and_priority[0])) + return import_list[0][0] class _ModuleInitCodeBuilder(object): @@ -115,10 +121,12 @@ class _ModuleInitCodeBuilder(object): def __init__(self, output_package, api_version): self._output_package = output_package + # Maps API module to API symbol name to set of tuples of the form + # (module name, priority). + # The same symbol can be imported from multiple locations. Higher + # "priority" indicates that import location is preferred over others. self._module_imports = collections.defaultdict( lambda: collections.defaultdict(set)) - self._deprecated_module_imports = collections.defaultdict( - lambda: collections.defaultdict(set)) self._dest_import_to_id = collections.defaultdict(int) # Names that start with underscore in the root module. self._underscore_names_in_root = [] @@ -134,15 +142,15 @@ class _ModuleInitCodeBuilder(object): self._dest_import_to_id[api_name] = symbol_id def add_import( - self, symbol_id, dest_module_name, source_module_name, source_name, + self, symbol, source_module_name, source_name, dest_module_name, dest_name): """Adds this import to module_imports. Args: - symbol_id: (number) Unique identifier of the symbol to import. - dest_module_name: (string) Module name to add import to. + symbol: TensorFlow Python symbol. source_module_name: (string) Module to import from. source_name: (string) Name of the symbol to import. + dest_module_name: (string) Module name to add import to. dest_name: (string) Import the symbol using this name. Raises: @@ -155,6 +163,7 @@ class _ModuleInitCodeBuilder(object): full_api_name = dest_name if dest_module_name: full_api_name = dest_module_name + '.' + full_api_name + symbol_id = -1 if not symbol else id(symbol) self._check_already_imported(symbol_id, full_api_name) if not dest_module_name and dest_name.startswith('_'): @@ -163,7 +172,13 @@ class _ModuleInitCodeBuilder(object): # The same symbol can be available in multiple modules. # We store all possible ways of importing this symbol and later pick just # one. - self._module_imports[dest_module_name][full_api_name].add(import_str) + priority = 0 + if symbol and hasattr(symbol, '__module__'): + # Give higher priority to source module if it matches + # symbol's original module. + priority = int(source_module_name == symbol.__module__) + self._module_imports[dest_module_name][full_api_name].add( + (import_str, priority)) def _import_submodules(self): """Add imports for all destination modules in self._module_imports.""" @@ -171,8 +186,6 @@ class _ModuleInitCodeBuilder(object): # For e.g. if we import 'foo.bar.Value'. Then, we also # import 'bar' in 'foo'. imported_modules = set(self._module_imports.keys()) - imported_modules = imported_modules.union( - set(self._deprecated_module_imports.keys())) for module in imported_modules: if not module: continue @@ -187,8 +200,8 @@ class _ModuleInitCodeBuilder(object): if submodule_index > 0: import_from += '.' + '.'.join(module_split[:submodule_index]) self.add_import( - -1, parent_module, import_from, - module_split[submodule_index], module_split[submodule_index]) + None, import_from, module_split[submodule_index], + parent_module, module_split[submodule_index]) def build(self): """Get a map from destination module to __init__.py code for that module. @@ -296,7 +309,7 @@ def add_imports_for_symbol( dest_module, dest_name = _get_name_and_module(export) dest_module = _join_modules(output_module_prefix, dest_module) module_code_builder.add_import( - -1, dest_module, source_module_name, name, dest_name) + None, source_module_name, name, dest_module, dest_name) # If symbol has _tf_api_names attribute, then add import for it. if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__): @@ -306,7 +319,7 @@ def add_imports_for_symbol( dest_module, dest_name = _get_name_and_module(export) dest_module = _join_modules(output_module_prefix, dest_module) module_code_builder.add_import( - id(symbol), dest_module, source_module_name, source_name, dest_name) + symbol, source_module_name, source_name, dest_module, dest_name) def get_api_init_text(packages, diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index b1529370df9..9f4d27afbb5 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -38,7 +38,6 @@ import tensorflow as tf from google.protobuf import message from google.protobuf import text_format -from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -355,7 +354,6 @@ class ApiCompatibilityTest(test.TestCase): update_goldens=FLAGS.update_goldens, api_version=api_version) - @test_util.run_v1_only('b/120545219') def testAPIBackwardsCompatibility(self): api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1 golden_file_pattern = os.path.join( @@ -378,10 +376,12 @@ class ApiCompatibilityTest(test.TestCase): # Also check that V1 API has contrib self.assertTrue( + api_version == 2 or 'tensorflow.python.util.lazy_loader.LazyLoader' in str(type(tf.contrib))) + # Check that V2 API does not have contrib + self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib')) - @test_util.run_v1_only('b/120545219') def testAPIBackwardsCompatibilityV1(self): api_version = 1 golden_file_pattern = os.path.join(