diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index c515cc76b9a..56d65d45faf 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -119,11 +119,11 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): - for _s in _site_packages_dirs: + for s in _site_packages_dirs: # TODO(gunan): Add sanity checks to loaded modules here. - _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') - if _fi.file_exists(_plugin_dir): - _ll.load_library(_plugin_dir) + plugin_dir = _os.path.join(s, 'tensorflow-plugins') + if _fi.file_exists(plugin_dir): + _ll.load_library(plugin_dir) # Add module aliases if hasattr(_current_module, 'keras'): @@ -136,5 +136,3 @@ if hasattr(_current_module, 'keras'): setattr(_current_module, "optimizers", optimizers) setattr(_current_module, "initializers", initializers) # pylint: enable=undefined-variable - -# __all__ PLACEHOLDER diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 2b2899c3fe0..97478a18b8a 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -132,10 +132,9 @@ def _running_from_pip_package(): _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) if _running_from_pip_package(): - for _s in _site_packages_dirs: + for s in _site_packages_dirs: # TODO(gunan): Add sanity checks to loaded modules here. - _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') - if _fi.file_exists(_plugin_dir): - _ll.load_library(_plugin_dir) + plugin_dir = _os.path.join(s, 'tensorflow-plugins') + if _fi.file_exists(plugin_dir): + _ll.load_library(plugin_dir) -# __all__ PLACEHOLDER diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py index 80f663683c3..3af677322d6 100644 --- a/tensorflow/python/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -243,12 +243,11 @@ class _ModuleInitCodeBuilder(object): # from it using * import. Don't need this for lazy_loading because the # underscore symbols are already included in __all__ when passed in and # handled by TFModuleWrapper. - root_module_footer = '' if not self._lazy_loading: underscore_names_str = ', '.join( '\'%s\'' % name for name in self._underscore_names_in_root) - root_module_footer = ''' + module_text_map[''] = module_text_map.get('', '') + ''' _names_with_underscore = [%s] __all__ = [_s for _s in dir() if not _s.startswith('_')] __all__.extend([_s for _s in _names_with_underscore]) @@ -274,7 +273,7 @@ __all__.extend([_s for _s in _names_with_underscore]) footer_text_map[dest_module] = _DEPRECATION_FOOTER % ( dest_module, public_apis_name, deprecation, has_lite) - return module_text_map, footer_text_map, root_module_footer + return module_text_map, footer_text_map def format_import(self, source_module_name, source_name, dest_name): """Formats import statement. @@ -621,11 +620,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir, os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - ( - module_text_map, - deprecation_footer_map, - root_module_footer, - ) = get_api_init_text( + module_text_map, deprecation_footer_map = get_api_init_text( packages, output_package, api_name, api_version, compat_api_versions, lazy_loading, use_relative_imports) @@ -657,7 +652,6 @@ def create_api_files(output_files, packages, root_init_template, output_dir, with open(root_init_template, 'r') as root_init_template_file: contents = root_init_template_file.read() contents = contents.replace('# API IMPORTS PLACEHOLDER', text) - contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer) elif module in compat_module_to_template: # Read base init file for compat module with open(compat_module_to_template[module], 'r') as init_template_file: diff --git a/tensorflow/python/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py index 76404d6c82b..010f189dcb2 100644 --- a/tensorflow/python/tools/api/generator/create_python_api_test.py +++ b/tensorflow/python/tools/api/generator/create_python_api_test.py @@ -62,7 +62,7 @@ class CreatePythonApiTest(test.TestCase): del sys.modules[_MODULE_NAME] def testFunctionImportIsAdded(self): - imports, _, _ = create_python_api.get_api_init_text( + imports, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -97,7 +97,7 @@ class CreatePythonApiTest(test.TestCase): msg='compat.v1 in %s' % str(imports.keys())) def testClassImportIsAdded(self): - imports, _, _ = create_python_api.get_api_init_text( + imports, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -116,7 +116,7 @@ class CreatePythonApiTest(test.TestCase): msg='%s not in %s' % (expected_import, str(imports))) def testConstantIsAdded(self): - imports, _, _ = create_python_api.get_api_init_text( + imports, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -132,7 +132,7 @@ class CreatePythonApiTest(test.TestCase): msg='%s not in %s' % (expected, str(imports))) def testCompatModuleIsAdded(self): - imports, _, _ = create_python_api.get_api_init_text( + imports, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', @@ -144,7 +144,7 @@ class CreatePythonApiTest(test.TestCase): msg='compat.v1.test not in %s' % str(imports.keys())) def testNestedCompatModulesAreAdded(self): - imports, _, _ = create_python_api.get_api_init_text( + imports, _ = create_python_api.get_api_init_text( packages=[create_python_api._DEFAULT_PACKAGE], output_package='tensorflow', api_name='tensorflow', diff --git a/tensorflow/virtual_root_template_v1.__init__.py b/tensorflow/virtual_root_template_v1.__init__.py index 9a45bc0355d..236e9f52258 100644 --- a/tensorflow/virtual_root_template_v1.__init__.py +++ b/tensorflow/virtual_root_template_v1.__init__.py @@ -132,4 +132,7 @@ try: except NameError: pass +# Manually patch keras and estimator so tf.keras and tf.estimator work +keras = _sys.modules["tensorflow.keras"] +if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] # LINT.ThenChange(//tensorflow/virtual_root_template_v2.__init__.py.oss) diff --git a/tensorflow/virtual_root_template_v2.__init__.py b/tensorflow/virtual_root_template_v2.__init__.py index bd8c903e455..83c020182a8 100644 --- a/tensorflow/virtual_root_template_v2.__init__.py +++ b/tensorflow/virtual_root_template_v2.__init__.py @@ -126,4 +126,14 @@ try: except NameError: pass +# TODO(mihaimaruseac): Revisit all of this once we release 2.1 +# Manually patch keras and estimator so tf.keras and tf.estimator work +keras = _sys.modules["tensorflow.keras"] +if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] +# Also import module aliases +try: + from tensorflow_core import losses, metrics, initializers, optimizers +except ImportError: + pass + # LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)