Enable all tests in api_compatibility_test.py in both v1 and v2.
Preferably import symbols from the files where they were defined when generating imports in create_python_api.py. PiperOrigin-RevId: 254610838
This commit is contained in:
		
							parent
							
								
									eeda570f69
								
							
						
					
					
						commit
						a45a9d2ac8
					
				| @ -94,10 +94,13 @@ def get_canonical_import(import_set): | |||||||
| 
 | 
 | ||||||
|   One symbol might come from multiple places as it is being imported and |   One symbol might come from multiple places as it is being imported and | ||||||
|   reexported. To simplify API changes, we always use the same import for the |   reexported. To simplify API changes, we always use the same import for the | ||||||
|   same module, and give preference to imports coming from main tensorflow code. |   same module, and give preference based on higher priority and alphabetical | ||||||
|  |   ordering. | ||||||
| 
 | 
 | ||||||
|   Args: |   Args: | ||||||
|     import_set: (set) Imports providing the same symbol |     import_set: (set) Imports providing the same symbol. This is a set of | ||||||
|  |       tuples in the form (import, priority). We want to pick an import | ||||||
|  |       with highest priority. | ||||||
| 
 | 
 | ||||||
|   Returns: |   Returns: | ||||||
|     A module name to import |     A module name to import | ||||||
| @ -105,9 +108,12 @@ def get_canonical_import(import_set): | |||||||
|   # We use the fact that list sorting is stable, so first we convert the set to |   # We use the fact that list sorting is stable, so first we convert the set to | ||||||
|   # a sorted list of the names and then we resort this list to move elements |   # a sorted list of the names and then we resort this list to move elements | ||||||
|   # not in core tensorflow to the end. |   # not in core tensorflow to the end. | ||||||
|   import_list = sorted(import_set) |   # Here we sort by priority (higher preferred) and then  alphabetically by | ||||||
|   import_list.sort(key=lambda x: 'lite' in x) |   # import string. | ||||||
|   return import_list[0] |   import_list = sorted( | ||||||
|  |       import_set, | ||||||
|  |       key=lambda imp_and_priority: (-imp_and_priority[1], imp_and_priority[0])) | ||||||
|  |   return import_list[0][0] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class _ModuleInitCodeBuilder(object): | class _ModuleInitCodeBuilder(object): | ||||||
| @ -115,10 +121,12 @@ class _ModuleInitCodeBuilder(object): | |||||||
| 
 | 
 | ||||||
|   def __init__(self, output_package, api_version): |   def __init__(self, output_package, api_version): | ||||||
|     self._output_package = output_package |     self._output_package = output_package | ||||||
|  |     # Maps API module to API symbol name to set of tuples of the form | ||||||
|  |     # (module name, priority). | ||||||
|  |     # The same symbol can be imported from multiple locations. Higher | ||||||
|  |     # "priority" indicates that import location is preferred over others. | ||||||
|     self._module_imports = collections.defaultdict( |     self._module_imports = collections.defaultdict( | ||||||
|         lambda: collections.defaultdict(set)) |         lambda: collections.defaultdict(set)) | ||||||
|     self._deprecated_module_imports = collections.defaultdict( |  | ||||||
|         lambda: collections.defaultdict(set)) |  | ||||||
|     self._dest_import_to_id = collections.defaultdict(int) |     self._dest_import_to_id = collections.defaultdict(int) | ||||||
|     # Names that start with underscore in the root module. |     # Names that start with underscore in the root module. | ||||||
|     self._underscore_names_in_root = [] |     self._underscore_names_in_root = [] | ||||||
| @ -134,15 +142,15 @@ class _ModuleInitCodeBuilder(object): | |||||||
|     self._dest_import_to_id[api_name] = symbol_id |     self._dest_import_to_id[api_name] = symbol_id | ||||||
| 
 | 
 | ||||||
|   def add_import( |   def add_import( | ||||||
|       self, symbol_id, dest_module_name, source_module_name, source_name, |       self, symbol, source_module_name, source_name, dest_module_name, | ||||||
|       dest_name): |       dest_name): | ||||||
|     """Adds this import to module_imports. |     """Adds this import to module_imports. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|       symbol_id: (number) Unique identifier of the symbol to import. |       symbol: TensorFlow Python symbol. | ||||||
|       dest_module_name: (string) Module name to add import to. |  | ||||||
|       source_module_name: (string) Module to import from. |       source_module_name: (string) Module to import from. | ||||||
|       source_name: (string) Name of the symbol to import. |       source_name: (string) Name of the symbol to import. | ||||||
|  |       dest_module_name: (string) Module name to add import to. | ||||||
|       dest_name: (string) Import the symbol using this name. |       dest_name: (string) Import the symbol using this name. | ||||||
| 
 | 
 | ||||||
|     Raises: |     Raises: | ||||||
| @ -155,6 +163,7 @@ class _ModuleInitCodeBuilder(object): | |||||||
|     full_api_name = dest_name |     full_api_name = dest_name | ||||||
|     if dest_module_name: |     if dest_module_name: | ||||||
|       full_api_name = dest_module_name + '.' + full_api_name |       full_api_name = dest_module_name + '.' + full_api_name | ||||||
|  |     symbol_id = -1 if not symbol else id(symbol) | ||||||
|     self._check_already_imported(symbol_id, full_api_name) |     self._check_already_imported(symbol_id, full_api_name) | ||||||
| 
 | 
 | ||||||
|     if not dest_module_name and dest_name.startswith('_'): |     if not dest_module_name and dest_name.startswith('_'): | ||||||
| @ -163,7 +172,13 @@ class _ModuleInitCodeBuilder(object): | |||||||
|     # The same symbol can be available in multiple modules. |     # The same symbol can be available in multiple modules. | ||||||
|     # We store all possible ways of importing this symbol and later pick just |     # We store all possible ways of importing this symbol and later pick just | ||||||
|     # one. |     # one. | ||||||
|     self._module_imports[dest_module_name][full_api_name].add(import_str) |     priority = 0 | ||||||
|  |     if symbol and hasattr(symbol, '__module__'): | ||||||
|  |       # Give higher priority to source module if it matches | ||||||
|  |       # symbol's original module. | ||||||
|  |       priority = int(source_module_name == symbol.__module__) | ||||||
|  |     self._module_imports[dest_module_name][full_api_name].add( | ||||||
|  |         (import_str, priority)) | ||||||
| 
 | 
 | ||||||
|   def _import_submodules(self): |   def _import_submodules(self): | ||||||
|     """Add imports for all destination modules in self._module_imports.""" |     """Add imports for all destination modules in self._module_imports.""" | ||||||
| @ -171,8 +186,6 @@ class _ModuleInitCodeBuilder(object): | |||||||
|     # For e.g. if we import 'foo.bar.Value'. Then, we also |     # For e.g. if we import 'foo.bar.Value'. Then, we also | ||||||
|     # import 'bar' in 'foo'. |     # import 'bar' in 'foo'. | ||||||
|     imported_modules = set(self._module_imports.keys()) |     imported_modules = set(self._module_imports.keys()) | ||||||
|     imported_modules = imported_modules.union( |  | ||||||
|         set(self._deprecated_module_imports.keys())) |  | ||||||
|     for module in imported_modules: |     for module in imported_modules: | ||||||
|       if not module: |       if not module: | ||||||
|         continue |         continue | ||||||
| @ -187,8 +200,8 @@ class _ModuleInitCodeBuilder(object): | |||||||
|         if submodule_index > 0: |         if submodule_index > 0: | ||||||
|           import_from += '.' + '.'.join(module_split[:submodule_index]) |           import_from += '.' + '.'.join(module_split[:submodule_index]) | ||||||
|         self.add_import( |         self.add_import( | ||||||
|             -1, parent_module, import_from, |             None, import_from, module_split[submodule_index], | ||||||
|             module_split[submodule_index], module_split[submodule_index]) |             parent_module, module_split[submodule_index]) | ||||||
| 
 | 
 | ||||||
|   def build(self): |   def build(self): | ||||||
|     """Get a map from destination module to __init__.py code for that module. |     """Get a map from destination module to __init__.py code for that module. | ||||||
| @ -296,7 +309,7 @@ def add_imports_for_symbol( | |||||||
|         dest_module, dest_name = _get_name_and_module(export) |         dest_module, dest_name = _get_name_and_module(export) | ||||||
|         dest_module = _join_modules(output_module_prefix, dest_module) |         dest_module = _join_modules(output_module_prefix, dest_module) | ||||||
|         module_code_builder.add_import( |         module_code_builder.add_import( | ||||||
|             -1, dest_module, source_module_name, name, dest_name) |             None, source_module_name, name, dest_module, dest_name) | ||||||
| 
 | 
 | ||||||
|   # If symbol has _tf_api_names attribute, then add import for it. |   # If symbol has _tf_api_names attribute, then add import for it. | ||||||
|   if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__): |   if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__): | ||||||
| @ -306,7 +319,7 @@ def add_imports_for_symbol( | |||||||
|       dest_module, dest_name = _get_name_and_module(export) |       dest_module, dest_name = _get_name_and_module(export) | ||||||
|       dest_module = _join_modules(output_module_prefix, dest_module) |       dest_module = _join_modules(output_module_prefix, dest_module) | ||||||
|       module_code_builder.add_import( |       module_code_builder.add_import( | ||||||
|           id(symbol), dest_module, source_module_name, source_name, dest_name) |           symbol, source_module_name, source_name, dest_module, dest_name) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_api_init_text(packages, | def get_api_init_text(packages, | ||||||
|  | |||||||
| @ -38,7 +38,6 @@ import tensorflow as tf | |||||||
| from google.protobuf import message | from google.protobuf import message | ||||||
| from google.protobuf import text_format | from google.protobuf import text_format | ||||||
| 
 | 
 | ||||||
| from tensorflow.python.framework import test_util |  | ||||||
| from tensorflow.python.lib.io import file_io | from tensorflow.python.lib.io import file_io | ||||||
| from tensorflow.python.platform import resource_loader | from tensorflow.python.platform import resource_loader | ||||||
| from tensorflow.python.platform import test | from tensorflow.python.platform import test | ||||||
| @ -355,7 +354,6 @@ class ApiCompatibilityTest(test.TestCase): | |||||||
|         update_goldens=FLAGS.update_goldens, |         update_goldens=FLAGS.update_goldens, | ||||||
|         api_version=api_version) |         api_version=api_version) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_v1_only('b/120545219') |  | ||||||
|   def testAPIBackwardsCompatibility(self): |   def testAPIBackwardsCompatibility(self): | ||||||
|     api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1 |     api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1 | ||||||
|     golden_file_pattern = os.path.join( |     golden_file_pattern = os.path.join( | ||||||
| @ -378,10 +376,12 @@ class ApiCompatibilityTest(test.TestCase): | |||||||
| 
 | 
 | ||||||
|     # Also check that V1 API has contrib |     # Also check that V1 API has contrib | ||||||
|     self.assertTrue( |     self.assertTrue( | ||||||
|  |         api_version == 2 or | ||||||
|         'tensorflow.python.util.lazy_loader.LazyLoader' |         'tensorflow.python.util.lazy_loader.LazyLoader' | ||||||
|         in str(type(tf.contrib))) |         in str(type(tf.contrib))) | ||||||
|  |     # Check that V2 API does not have contrib | ||||||
|  |     self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib')) | ||||||
| 
 | 
 | ||||||
|   @test_util.run_v1_only('b/120545219') |  | ||||||
|   def testAPIBackwardsCompatibilityV1(self): |   def testAPIBackwardsCompatibilityV1(self): | ||||||
|     api_version = 1 |     api_version = 1 | ||||||
|     golden_file_pattern = os.path.join( |     golden_file_pattern = os.path.join( | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user