Revert "Revert "[r2.1 cherry-pick] Fix pip package API generation""
This commit is contained in:
parent
c49396cf71
commit
82e2d5a454
@ -119,11 +119,11 @@ def _running_from_pip_package():
|
|||||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||||
|
|
||||||
if _running_from_pip_package():
|
if _running_from_pip_package():
|
||||||
for s in _site_packages_dirs:
|
for _s in _site_packages_dirs:
|
||||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||||
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
|
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||||
if _fi.file_exists(plugin_dir):
|
if _fi.file_exists(_plugin_dir):
|
||||||
_ll.load_library(plugin_dir)
|
_ll.load_library(_plugin_dir)
|
||||||
|
|
||||||
# Add module aliases
|
# Add module aliases
|
||||||
if hasattr(_current_module, 'keras'):
|
if hasattr(_current_module, 'keras'):
|
||||||
@ -136,3 +136,5 @@ if hasattr(_current_module, 'keras'):
|
|||||||
setattr(_current_module, "optimizers", optimizers)
|
setattr(_current_module, "optimizers", optimizers)
|
||||||
setattr(_current_module, "initializers", initializers)
|
setattr(_current_module, "initializers", initializers)
|
||||||
# pylint: enable=undefined-variable
|
# pylint: enable=undefined-variable
|
||||||
|
|
||||||
|
# __all__ PLACEHOLDER
|
||||||
|
@ -132,9 +132,10 @@ def _running_from_pip_package():
|
|||||||
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
|
||||||
|
|
||||||
if _running_from_pip_package():
|
if _running_from_pip_package():
|
||||||
for s in _site_packages_dirs:
|
for _s in _site_packages_dirs:
|
||||||
# TODO(gunan): Add sanity checks to loaded modules here.
|
# TODO(gunan): Add sanity checks to loaded modules here.
|
||||||
plugin_dir = _os.path.join(s, 'tensorflow-plugins')
|
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
|
||||||
if _fi.file_exists(plugin_dir):
|
if _fi.file_exists(_plugin_dir):
|
||||||
_ll.load_library(plugin_dir)
|
_ll.load_library(_plugin_dir)
|
||||||
|
|
||||||
|
# __all__ PLACEHOLDER
|
||||||
|
@ -243,11 +243,12 @@ class _ModuleInitCodeBuilder(object):
|
|||||||
# from it using * import. Don't need this for lazy_loading because the
|
# from it using * import. Don't need this for lazy_loading because the
|
||||||
# underscore symbols are already included in __all__ when passed in and
|
# underscore symbols are already included in __all__ when passed in and
|
||||||
# handled by TFModuleWrapper.
|
# handled by TFModuleWrapper.
|
||||||
|
root_module_footer = ''
|
||||||
if not self._lazy_loading:
|
if not self._lazy_loading:
|
||||||
underscore_names_str = ', '.join(
|
underscore_names_str = ', '.join(
|
||||||
'\'%s\'' % name for name in self._underscore_names_in_root)
|
'\'%s\'' % name for name in self._underscore_names_in_root)
|
||||||
|
|
||||||
module_text_map[''] = module_text_map.get('', '') + '''
|
root_module_footer = '''
|
||||||
_names_with_underscore = [%s]
|
_names_with_underscore = [%s]
|
||||||
__all__ = [_s for _s in dir() if not _s.startswith('_')]
|
__all__ = [_s for _s in dir() if not _s.startswith('_')]
|
||||||
__all__.extend([_s for _s in _names_with_underscore])
|
__all__.extend([_s for _s in _names_with_underscore])
|
||||||
@ -273,7 +274,7 @@ __all__.extend([_s for _s in _names_with_underscore])
|
|||||||
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
|
footer_text_map[dest_module] = _DEPRECATION_FOOTER % (
|
||||||
dest_module, public_apis_name, deprecation, has_lite)
|
dest_module, public_apis_name, deprecation, has_lite)
|
||||||
|
|
||||||
return module_text_map, footer_text_map
|
return module_text_map, footer_text_map, root_module_footer
|
||||||
|
|
||||||
def format_import(self, source_module_name, source_name, dest_name):
|
def format_import(self, source_module_name, source_name, dest_name):
|
||||||
"""Formats import statement.
|
"""Formats import statement.
|
||||||
@ -620,7 +621,11 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
|
|||||||
os.makedirs(os.path.dirname(file_path))
|
os.makedirs(os.path.dirname(file_path))
|
||||||
open(file_path, 'a').close()
|
open(file_path, 'a').close()
|
||||||
|
|
||||||
module_text_map, deprecation_footer_map = get_api_init_text(
|
(
|
||||||
|
module_text_map,
|
||||||
|
deprecation_footer_map,
|
||||||
|
root_module_footer,
|
||||||
|
) = get_api_init_text(
|
||||||
packages, output_package, api_name,
|
packages, output_package, api_name,
|
||||||
api_version, compat_api_versions, lazy_loading, use_relative_imports)
|
api_version, compat_api_versions, lazy_loading, use_relative_imports)
|
||||||
|
|
||||||
@ -652,6 +657,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
|
|||||||
with open(root_init_template, 'r') as root_init_template_file:
|
with open(root_init_template, 'r') as root_init_template_file:
|
||||||
contents = root_init_template_file.read()
|
contents = root_init_template_file.read()
|
||||||
contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
|
contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
|
||||||
|
contents = contents.replace('# __all__ PLACEHOLDER', root_module_footer)
|
||||||
elif module in compat_module_to_template:
|
elif module in compat_module_to_template:
|
||||||
# Read base init file for compat module
|
# Read base init file for compat module
|
||||||
with open(compat_module_to_template[module], 'r') as init_template_file:
|
with open(compat_module_to_template[module], 'r') as init_template_file:
|
||||||
|
@ -62,7 +62,7 @@ class CreatePythonApiTest(test.TestCase):
|
|||||||
del sys.modules[_MODULE_NAME]
|
del sys.modules[_MODULE_NAME]
|
||||||
|
|
||||||
def testFunctionImportIsAdded(self):
|
def testFunctionImportIsAdded(self):
|
||||||
imports, _ = create_python_api.get_api_init_text(
|
imports, _, _ = create_python_api.get_api_init_text(
|
||||||
packages=[create_python_api._DEFAULT_PACKAGE],
|
packages=[create_python_api._DEFAULT_PACKAGE],
|
||||||
output_package='tensorflow',
|
output_package='tensorflow',
|
||||||
api_name='tensorflow',
|
api_name='tensorflow',
|
||||||
@ -97,7 +97,7 @@ class CreatePythonApiTest(test.TestCase):
|
|||||||
msg='compat.v1 in %s' % str(imports.keys()))
|
msg='compat.v1 in %s' % str(imports.keys()))
|
||||||
|
|
||||||
def testClassImportIsAdded(self):
|
def testClassImportIsAdded(self):
|
||||||
imports, _ = create_python_api.get_api_init_text(
|
imports, _, _ = create_python_api.get_api_init_text(
|
||||||
packages=[create_python_api._DEFAULT_PACKAGE],
|
packages=[create_python_api._DEFAULT_PACKAGE],
|
||||||
output_package='tensorflow',
|
output_package='tensorflow',
|
||||||
api_name='tensorflow',
|
api_name='tensorflow',
|
||||||
@ -116,7 +116,7 @@ class CreatePythonApiTest(test.TestCase):
|
|||||||
msg='%s not in %s' % (expected_import, str(imports)))
|
msg='%s not in %s' % (expected_import, str(imports)))
|
||||||
|
|
||||||
def testConstantIsAdded(self):
|
def testConstantIsAdded(self):
|
||||||
imports, _ = create_python_api.get_api_init_text(
|
imports, _, _ = create_python_api.get_api_init_text(
|
||||||
packages=[create_python_api._DEFAULT_PACKAGE],
|
packages=[create_python_api._DEFAULT_PACKAGE],
|
||||||
output_package='tensorflow',
|
output_package='tensorflow',
|
||||||
api_name='tensorflow',
|
api_name='tensorflow',
|
||||||
@ -132,7 +132,7 @@ class CreatePythonApiTest(test.TestCase):
|
|||||||
msg='%s not in %s' % (expected, str(imports)))
|
msg='%s not in %s' % (expected, str(imports)))
|
||||||
|
|
||||||
def testCompatModuleIsAdded(self):
|
def testCompatModuleIsAdded(self):
|
||||||
imports, _ = create_python_api.get_api_init_text(
|
imports, _, _ = create_python_api.get_api_init_text(
|
||||||
packages=[create_python_api._DEFAULT_PACKAGE],
|
packages=[create_python_api._DEFAULT_PACKAGE],
|
||||||
output_package='tensorflow',
|
output_package='tensorflow',
|
||||||
api_name='tensorflow',
|
api_name='tensorflow',
|
||||||
@ -144,7 +144,7 @@ class CreatePythonApiTest(test.TestCase):
|
|||||||
msg='compat.v1.test not in %s' % str(imports.keys()))
|
msg='compat.v1.test not in %s' % str(imports.keys()))
|
||||||
|
|
||||||
def testNestedCompatModulesAreAdded(self):
|
def testNestedCompatModulesAreAdded(self):
|
||||||
imports, _ = create_python_api.get_api_init_text(
|
imports, _, _ = create_python_api.get_api_init_text(
|
||||||
packages=[create_python_api._DEFAULT_PACKAGE],
|
packages=[create_python_api._DEFAULT_PACKAGE],
|
||||||
output_package='tensorflow',
|
output_package='tensorflow',
|
||||||
api_name='tensorflow',
|
api_name='tensorflow',
|
||||||
|
@ -132,7 +132,4 @@ try:
|
|||||||
except NameError:
|
except NameError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Manually patch keras and estimator so tf.keras and tf.estimator work
|
|
||||||
keras = _sys.modules["tensorflow.keras"]
|
|
||||||
if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"]
|
|
||||||
# LINT.ThenChange(//tensorflow/virtual_root_template_v2.__init__.py.oss)
|
# LINT.ThenChange(//tensorflow/virtual_root_template_v2.__init__.py.oss)
|
||||||
|
@ -126,14 +126,4 @@ try:
|
|||||||
except NameError:
|
except NameError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# TODO(mihaimaruseac): Revisit all of this once we release 2.1
|
|
||||||
# Manually patch keras and estimator so tf.keras and tf.estimator work
|
|
||||||
keras = _sys.modules["tensorflow.keras"]
|
|
||||||
if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"]
|
|
||||||
# Also import module aliases
|
|
||||||
try:
|
|
||||||
from tensorflow_core import losses, metrics, initializers, optimizers
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)
|
# LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)
|
||||||
|
Loading…
Reference in New Issue
Block a user