From 82e2d5a454137d9a96a57f4bef6f0e0ebabc8f03 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Thu, 19 Dec 2019 08:53:06 -0800 Subject: [PATCH] Revert "Revert "[r2.1 cherry-pick] Fix pip package API generation"" --- tensorflow/api_template.__init__.py | 10 ++++++---- tensorflow/api_template_v1.__init__.py | 9 +++++---- .../python/tools/api/generator/create_python_api.py | 12 +++++++++--- .../tools/api/generator/create_python_api_test.py | 10 +++++----- tensorflow/virtual_root_template_v1.__init__.py | 3 --- tensorflow/virtual_root_template_v2.__init__.py | 10 ---------- 6 files changed, 25 insertions(+), 29 deletions(-) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 56d65d45faf..c515cc76b9a 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -119,11 +119,11 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): - for s in _site_packages_dirs: + for _s in _site_packages_dirs: # TODO(gunan): Add sanity checks to loaded modules here. - plugin_dir = _os.path.join(s, 'tensorflow-plugins') - if _fi.file_exists(plugin_dir): - _ll.load_library(plugin_dir) + _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') + if _fi.file_exists(_plugin_dir): + _ll.load_library(_plugin_dir) # Add module aliases if hasattr(_current_module, 'keras'): @@ -136,3 +136,5 @@ if hasattr(_current_module, 'keras'): setattr(_current_module, "optimizers", optimizers) setattr(_current_module, "initializers", initializers) # pylint: enable=undefined-variable + +# __all__ PLACEHOLDER diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 97478a18b8a..2b2899c3fe0 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -132,9 +132,10 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): - for s in _site_packages_dirs: + for _s in _site_packages_dirs: # TODO(gunan): Add sanity checks to loaded modules here. - plugin_dir = _os.path.join(s, 'tensorflow-plugins') - if _fi.file_exists(plugin_dir): - _ll.load_library(plugin_dir) + _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') + if _fi.file_exists(_plugin_dir): + _ll.load_library(_plugin_dir) +# __all__ PLACEHOLDER diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py index 3af677322d6..80f663683c3 100644 --- a/tensorflow/python/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -243,11 +243,12 @@ class _ModuleInitCodeBuilder(object): # from it using * import. Don't need this for lazy_loading because the # underscore symbols are already included in __all__ when passed in and # handled by TFModuleWrapper. + root_module_footer = '' if not self._lazy_loading: underscore_names_str = ', '.join( '\'%s\'' % name for name in self._underscore_names_in_root) - module_text_map[''] = module_text_map.get('', '') + ''' + root_module_footer = ''' _names_with_underscore = [%s] __all__ = [_s for _s in dir() if not _s.startswith('_')] __all__.extend([_s for _s in _names_with_underscore]) @@ -273,7 +274,7 @@ __all__.extend([_s for _s in _names_with_underscore]) footer_text_map[dest_module] = _DEPRECATION_FOOTER % ( dest_module, public_apis_name, deprecation, has_lite) - return module_text_map, footer_text_map + return module_text_map, footer_text_map, root_module_footer def format_import(self, source_module_name, source_name, dest_name): """Formats import statement. @@ -620,7 +621,11 @@ def create_api_files(output_files, packages, root_init_template, output_dir, os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - module_text_map, deprecation_footer_map = get_api_init_text( + ( + module_text_map, + deprecation_footer_map, + root_module_footer, + ) = get_api_init_text( packages, output_package, api_name, api_version, compat_api_versions, lazy_loading, use_relative_imports) @@ -652,6 +657,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir, with open(root_init_template, 'r') as root_init_template_file: contents = root_init_template_file.read() contents = contents.replace('# API IMPORTS PLACEHOLDER', text) + contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer) elif module in compat_module_to_template: # Read base init file for compat module with open(compat_module_to_template[module], 'r') as init_template_file: diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py index 010f189dcb2..76404d6c82b 100644 --- a/tensorflow/python/tools/api/generator/create_python_api_test.py +++ b/tensorflow/python/tools/api/generator/create_python_api_test.py @@ -62,7 +62,7 @@ class CreatePythonApiTest(test.TestCase): del sys.modules[_MODULE_NAME] def testFunctionImportIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -97,7 +97,7 @@ class CreatePythonApiTest(test.TestCase): msg='compat.v1 in %s' % str(imports.keys())) def testClassImportIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -116,7 +116,7 @@ class CreatePythonApiTest(test.TestCase): msg='%s not in %s' % (expected_import, str(imports))) def testConstantIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -132,7 +132,7 @@ class CreatePythonApiTest(test.TestCase): msg='%s not in %s' % (expected, str(imports))) def testCompatModuleIsAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -144,7 +144,7 @@ class CreatePythonApiTest(test.TestCase): msg='compat.v1.test not in %s' % str(imports.keys())) def testNestedCompatModulesAreAdded(self): - imports, _ = create_python_api.get_api_init_text( + imports, _, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', diff --git a/tensorflow/virtual_root_template_v1.__init__.py b/tensorflow/virtual_root_template_v1.__init__.py index 236e9f52258..9a45bc0355d 100644 --- a/tensorflow/virtual_root_template_v1.__init__.py +++ b/tensorflow/virtual_root_template_v1.__init__.py @@ -132,7 +132,4 @@ try: except NameError: pass -# Manually patch keras and estimator so tf.keras and tf.estimator work -keras = _sys.modules["tensorflow.keras"] -if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] # LINT.ThenChange(//tensorflow/virtual_root_template_v2.__init__.py.oss) diff --git a/tensorflow/virtual_root_template_v2.__init__.py b/tensorflow/virtual_root_template_v2.__init__.py index 83c020182a8..bd8c903e455 100644 --- a/tensorflow/virtual_root_template_v2.__init__.py +++ b/tensorflow/virtual_root_template_v2.__init__.py @@ -126,14 +126,4 @@ try: except NameError: pass -# TODO(mihaimaruseac): Revisit all of this once we release 2.1 -# Manually patch keras and estimator so tf.keras and tf.estimator work -keras = _sys.modules["tensorflow.keras"] -if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] -# Also import module aliases -try: - from tensorflow_core import losses, metrics, initializers, optimizers -except ImportError: - pass - # LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)