First step towards reducing the size of python/__init__.py.
API generation has been relying on python/__init__.py to look for modules with @tf_export decorators. Now, modules_with_exports.py should be used to specify modules to scan instead. PiperOrigin-RevId: 314177907 Change-Id: I13460b5db55d4e56dd810106086ecb2aeca69a6b
This commit is contained in:
parent
25213f58c4
commit
5498a3cb12
@ -239,6 +239,21 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This target should only be used for API generation.
|
||||||
|
py_library(
|
||||||
|
name = "modules_with_exports",
|
||||||
|
srcs = ["modules_with_exports.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:__pkg__",
|
||||||
|
"//tensorflow/python/tools/api/generator:__pkg__",
|
||||||
|
"//third_party/py/tensorflow_core:__subpackages__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":no_contrib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(gunan): Investigate making this action hermetic so we do not need
|
# TODO(gunan): Investigate making this action hermetic so we do not need
|
||||||
# to run it locally.
|
# to run it locally.
|
||||||
tf_py_build_info_genrule(
|
tf_py_build_info_genrule(
|
||||||
|
|||||||
@ -30,51 +30,14 @@ import importlib
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
# TODO(drpng): write up instructions for editing this file in a doc and point to
|
# We aim to keep this file minimal and ideally remove completely.
|
||||||
# the doc instead.
|
# If you are adding a new file with @tf_export decorators,
|
||||||
# If you want to edit this file to expose modules in public tensorflow API, you
|
# import it in modules_with_exports.py instead.
|
||||||
# need to follow these steps:
|
|
||||||
# 1. Consult with tensorflow team and get approval for adding a new API to the
|
|
||||||
# public interface.
|
|
||||||
# 2. Document the module in the gen_docs_combined.py.
|
|
||||||
# 3. Import the module in the main tensorflow namespace by adding an import
|
|
||||||
# statement in this file.
|
|
||||||
# 4. Sanitize the entry point by making sure that your module does not expose
|
|
||||||
# transitively imported modules used for implementation, such as os, sys.
|
|
||||||
|
|
||||||
# go/tf-wildcard-import
|
# go/tf-wildcard-import
|
||||||
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
|
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
|
||||||
|
|
||||||
import numpy as np
|
from tensorflow.python.eager import context
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
|
|
||||||
# Protocol buffers
|
|
||||||
from tensorflow.core.framework.graph_pb2 import *
|
|
||||||
from tensorflow.core.framework.node_def_pb2 import *
|
|
||||||
from tensorflow.core.framework.summary_pb2 import *
|
|
||||||
from tensorflow.core.framework.attr_value_pb2 import *
|
|
||||||
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
|
|
||||||
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
|
|
||||||
from tensorflow.core.protobuf.config_pb2 import *
|
|
||||||
from tensorflow.core.protobuf.tensorflow_server_pb2 import *
|
|
||||||
from tensorflow.core.util.event_pb2 import *
|
|
||||||
|
|
||||||
# Framework
|
|
||||||
from tensorflow.python.framework.framework_lib import * # pylint: disable=redefined-builtin
|
|
||||||
from tensorflow.python.framework.versions import *
|
|
||||||
from tensorflow.python.framework import config
|
|
||||||
from tensorflow.python.framework import errors
|
|
||||||
from tensorflow.python.framework import graph_util
|
|
||||||
|
|
||||||
# Session
|
|
||||||
from tensorflow.python.client.client_lib import *
|
|
||||||
|
|
||||||
# Ops
|
|
||||||
from tensorflow.python.ops.standard_ops import *
|
|
||||||
|
|
||||||
# Namespaces
|
|
||||||
from tensorflow.python.ops import initializers_ns as initializers
|
|
||||||
|
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
|
|
||||||
@ -152,8 +115,8 @@ from tensorflow.python.framework.ops import enable_eager_execution
|
|||||||
# Check whether TF2_BEHAVIOR is turned on.
|
# Check whether TF2_BEHAVIOR is turned on.
|
||||||
from tensorflow.python.eager import monitoring as _monitoring
|
from tensorflow.python.eager import monitoring as _monitoring
|
||||||
from tensorflow.python import tf2 as _tf2
|
from tensorflow.python import tf2 as _tf2
|
||||||
_tf2_gauge = _monitoring.BoolGauge('/tensorflow/api/tf2_enable',
|
_tf2_gauge = _monitoring.BoolGauge(
|
||||||
'Environment variable TF2_BEHAVIOR is set".')
|
'/tensorflow/api/tf2_enable', 'Environment variable TF2_BEHAVIOR is set".')
|
||||||
_tf2_gauge.get_cell().set(_tf2.enabled())
|
_tf2_gauge.get_cell().set(_tf2.enabled())
|
||||||
|
|
||||||
# Necessary for the symbols in this module to be taken into account by
|
# Necessary for the symbols in this module to be taken into account by
|
||||||
@ -186,30 +149,6 @@ nn.bidirectional_dynamic_rnn = rnn.bidirectional_dynamic_rnn
|
|||||||
nn.static_state_saving_rnn = rnn.static_state_saving_rnn
|
nn.static_state_saving_rnn = rnn.static_state_saving_rnn
|
||||||
nn.rnn_cell = rnn_cell
|
nn.rnn_cell = rnn_cell
|
||||||
|
|
||||||
# Export protos
|
|
||||||
# pylint: disable=undefined-variable
|
|
||||||
tf_export(v1=['AttrValue'])(AttrValue)
|
|
||||||
tf_export(v1=['ConfigProto'])(ConfigProto)
|
|
||||||
tf_export(v1=['Event', 'summary.Event'])(Event)
|
|
||||||
tf_export(v1=['GPUOptions'])(GPUOptions)
|
|
||||||
tf_export(v1=['GraphDef'])(GraphDef)
|
|
||||||
tf_export(v1=['GraphOptions'])(GraphOptions)
|
|
||||||
tf_export(v1=['HistogramProto'])(HistogramProto)
|
|
||||||
tf_export(v1=['LogMessage'])(LogMessage)
|
|
||||||
tf_export(v1=['MetaGraphDef'])(MetaGraphDef)
|
|
||||||
tf_export(v1=['NameAttrList'])(NameAttrList)
|
|
||||||
tf_export(v1=['NodeDef'])(NodeDef)
|
|
||||||
tf_export(v1=['OptimizerOptions'])(OptimizerOptions)
|
|
||||||
tf_export(v1=['RunMetadata'])(RunMetadata)
|
|
||||||
tf_export(v1=['RunOptions'])(RunOptions)
|
|
||||||
tf_export(v1=['SessionLog', 'summary.SessionLog'])(SessionLog)
|
|
||||||
tf_export(v1=['Summary', 'summary.Summary'])(Summary)
|
|
||||||
tf_export(v1=['summary.SummaryDescription'])(SummaryDescription)
|
|
||||||
tf_export(v1=['SummaryMetadata'])(SummaryMetadata)
|
|
||||||
tf_export(v1=['summary.TaggedRunMetadata'])(TaggedRunMetadata)
|
|
||||||
tf_export(v1=['TensorInfo'])(TensorInfo)
|
|
||||||
# pylint: enable=undefined-variable
|
|
||||||
|
|
||||||
# Special dunders that we choose to export:
|
# Special dunders that we choose to export:
|
||||||
_exported_dunders = set([
|
_exported_dunders = set([
|
||||||
'__version__',
|
'__version__',
|
||||||
|
|||||||
78
tensorflow/python/modules_with_exports.py
Normal file
78
tensorflow/python/modules_with_exports.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Imports modules that should be scanned during API generation.
|
||||||
|
|
||||||
|
This file should eventually contain everything we need to scan looking for
|
||||||
|
tf_export decorators.
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# go/tf-wildcard-import
|
||||||
|
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
|
||||||
|
# pylint: disable=unused-import,g-importing-member
|
||||||
|
|
||||||
|
# Protocol buffers
|
||||||
|
from tensorflow.core.framework.graph_pb2 import *
|
||||||
|
from tensorflow.core.framework.node_def_pb2 import *
|
||||||
|
from tensorflow.core.framework.summary_pb2 import *
|
||||||
|
from tensorflow.core.framework.attr_value_pb2 import *
|
||||||
|
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
|
||||||
|
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
|
||||||
|
from tensorflow.core.protobuf.config_pb2 import *
|
||||||
|
from tensorflow.core.util.event_pb2 import *
|
||||||
|
|
||||||
|
# Framework
|
||||||
|
from tensorflow.python.framework.framework_lib import * # pylint: disable=redefined-builtin
|
||||||
|
from tensorflow.python.framework.versions import *
|
||||||
|
from tensorflow.python.framework import config
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import graph_util
|
||||||
|
|
||||||
|
# Session
|
||||||
|
from tensorflow.python.client.client_lib import *
|
||||||
|
|
||||||
|
# Ops
|
||||||
|
from tensorflow.python.ops.standard_ops import *
|
||||||
|
|
||||||
|
# Namespaces
|
||||||
|
from tensorflow.python.ops import initializers_ns as initializers
|
||||||
|
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
# Export protos
|
||||||
|
# pylint: disable=undefined-variable
|
||||||
|
tf_export(v1=['AttrValue'])(AttrValue)
|
||||||
|
tf_export(v1=['ConfigProto'])(ConfigProto)
|
||||||
|
tf_export(v1=['Event', 'summary.Event'])(Event)
|
||||||
|
tf_export(v1=['GPUOptions'])(GPUOptions)
|
||||||
|
tf_export(v1=['GraphDef'])(GraphDef)
|
||||||
|
tf_export(v1=['GraphOptions'])(GraphOptions)
|
||||||
|
tf_export(v1=['HistogramProto'])(HistogramProto)
|
||||||
|
tf_export(v1=['LogMessage'])(LogMessage)
|
||||||
|
tf_export(v1=['MetaGraphDef'])(MetaGraphDef)
|
||||||
|
tf_export(v1=['NameAttrList'])(NameAttrList)
|
||||||
|
tf_export(v1=['NodeDef'])(NodeDef)
|
||||||
|
tf_export(v1=['OptimizerOptions'])(OptimizerOptions)
|
||||||
|
tf_export(v1=['RunMetadata'])(RunMetadata)
|
||||||
|
tf_export(v1=['RunOptions'])(RunOptions)
|
||||||
|
tf_export(v1=['SessionLog', 'summary.SessionLog'])(SessionLog)
|
||||||
|
tf_export(v1=['Summary', 'summary.Summary'])(Summary)
|
||||||
|
tf_export(v1=['summary.SummaryDescription'])(SummaryDescription)
|
||||||
|
tf_export(v1=['SummaryMetadata'])(SummaryMetadata)
|
||||||
|
tf_export(v1=['summary.TaggedRunMetadata'])(TaggedRunMetadata)
|
||||||
|
tf_export(v1=['TensorInfo'])(TensorInfo)
|
||||||
|
# pylint: enable=undefined-variable
|
||||||
@ -42,8 +42,15 @@ def gen_api_init_files(
|
|||||||
api_version = 2,
|
api_version = 2,
|
||||||
compat_api_versions = [],
|
compat_api_versions = [],
|
||||||
compat_init_templates = [],
|
compat_init_templates = [],
|
||||||
packages = ["tensorflow.python", "tensorflow.lite.python.lite"],
|
packages = [
|
||||||
package_deps = ["//tensorflow/python:no_contrib"],
|
"tensorflow.python",
|
||||||
|
"tensorflow.lite.python.lite",
|
||||||
|
"tensorflow.python.modules_with_exports",
|
||||||
|
],
|
||||||
|
package_deps = [
|
||||||
|
"//tensorflow/python:no_contrib",
|
||||||
|
"//tensorflow/python:modules_with_exports",
|
||||||
|
],
|
||||||
output_package = "tensorflow",
|
output_package = "tensorflow",
|
||||||
output_dir = "",
|
output_dir = "",
|
||||||
root_file_name = "__init__.py"):
|
root_file_name = "__init__.py"):
|
||||||
|
|||||||
@ -12,8 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
"""Generates and prints out imports and constants for new TensorFlow python api.
|
"""Generates and prints out imports and constants for new TensorFlow python api."""
|
||||||
"""
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -85,9 +84,9 @@ def get_canonical_import(import_set):
|
|||||||
ordering.
|
ordering.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
import_set: (set) Imports providing the same symbol. This is a set of
|
import_set: (set) Imports providing the same symbol. This is a set of tuples
|
||||||
tuples in the form (import, priority). We want to pick an import
|
in the form (import, priority). We want to pick an import with highest
|
||||||
with highest priority.
|
priority.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A module name to import
|
A module name to import
|
||||||
@ -106,8 +105,10 @@ def get_canonical_import(import_set):
|
|||||||
class _ModuleInitCodeBuilder(object):
|
class _ModuleInitCodeBuilder(object):
|
||||||
"""Builds a map from module name to imports included in that module."""
|
"""Builds a map from module name to imports included in that module."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self, output_package, api_version, lazy_loading=_LAZY_LOADING,
|
output_package,
|
||||||
|
api_version,
|
||||||
|
lazy_loading=_LAZY_LOADING,
|
||||||
use_relative_imports=False):
|
use_relative_imports=False):
|
||||||
self._output_package = output_package
|
self._output_package = output_package
|
||||||
# Maps API module to API symbol name to set of tuples of the form
|
# Maps API module to API symbol name to set of tuples of the form
|
||||||
@ -127,16 +128,13 @@ class _ModuleInitCodeBuilder(object):
|
|||||||
|
|
||||||
def _check_already_imported(self, symbol_id, api_name):
|
def _check_already_imported(self, symbol_id, api_name):
|
||||||
if (api_name in self._dest_import_to_id and
|
if (api_name in self._dest_import_to_id and
|
||||||
symbol_id != self._dest_import_to_id[api_name] and
|
symbol_id != self._dest_import_to_id[api_name] and symbol_id != -1):
|
||||||
symbol_id != -1):
|
|
||||||
raise SymbolExposedTwiceError(
|
raise SymbolExposedTwiceError(
|
||||||
'Trying to export multiple symbols with same name: %s.' %
|
'Trying to export multiple symbols with same name: %s.' % api_name)
|
||||||
api_name)
|
|
||||||
self._dest_import_to_id[api_name] = symbol_id
|
self._dest_import_to_id[api_name] = symbol_id
|
||||||
|
|
||||||
def add_import(
|
def add_import(self, symbol, source_module_name, source_name,
|
||||||
self, symbol, source_module_name, source_name, dest_module_name,
|
dest_module_name, dest_name):
|
||||||
dest_name):
|
|
||||||
"""Adds this import to module_imports.
|
"""Adds this import to module_imports.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -150,6 +148,10 @@ class _ModuleInitCodeBuilder(object):
|
|||||||
SymbolExposedTwiceError: Raised when an import with the same
|
SymbolExposedTwiceError: Raised when an import with the same
|
||||||
dest_name has already been added to dest_module_name.
|
dest_name has already been added to dest_module_name.
|
||||||
"""
|
"""
|
||||||
|
# modules_with_exports.py is only used during API generation and
|
||||||
|
# won't be available when actually importing tensorflow.
|
||||||
|
if source_module_name.endswith('python.modules_with_exports'):
|
||||||
|
source_module_name = symbol.__module__
|
||||||
import_str = self.format_import(source_module_name, source_name, dest_name)
|
import_str = self.format_import(source_module_name, source_name, dest_name)
|
||||||
|
|
||||||
# Check if we are trying to expose two different symbols with same name.
|
# Check if we are trying to expose two different symbols with same name.
|
||||||
@ -264,8 +266,8 @@ __all__.extend([_s for _s in _names_with_underscore])
|
|||||||
if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
|
if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
|
||||||
deprecation = 'True'
|
deprecation = 'True'
|
||||||
# Workaround to make sure not load lite from lite/__init__.py
|
# Workaround to make sure not load lite from lite/__init__.py
|
||||||
if (not dest_module and 'lite' in self._module_imports
|
if (not dest_module and 'lite' in self._module_imports and
|
||||||
and self._lazy_loading):
|
self._lazy_loading):
|
||||||
has_lite = 'True'
|
has_lite = 'True'
|
||||||
if self._lazy_loading:
|
if self._lazy_loading:
|
||||||
public_apis_name = '_PUBLIC_APIS'
|
public_apis_name = '_PUBLIC_APIS'
|
||||||
@ -311,8 +313,8 @@ __all__.extend([_s for _s in _names_with_underscore])
|
|||||||
self._module_imports[from_dest_module].copy())
|
self._module_imports[from_dest_module].copy())
|
||||||
|
|
||||||
|
|
||||||
def add_nested_compat_imports(
|
def add_nested_compat_imports(module_builder, compat_api_versions,
|
||||||
module_builder, compat_api_versions, output_package):
|
output_package):
|
||||||
"""Adds compat.vN.compat.vK modules to module builder.
|
"""Adds compat.vN.compat.vK modules to module builder.
|
||||||
|
|
||||||
To avoid circular imports, we want to add __init__.py files under
|
To avoid circular imports, we want to add __init__.py files under
|
||||||
@ -334,8 +336,8 @@ def add_nested_compat_imports(
|
|||||||
subcompat_module = _SUBCOMPAT_MODULE_TEMPLATE % (v, sv)
|
subcompat_module = _SUBCOMPAT_MODULE_TEMPLATE % (v, sv)
|
||||||
compat_module = _COMPAT_MODULE_TEMPLATE % sv
|
compat_module = _COMPAT_MODULE_TEMPLATE % sv
|
||||||
module_builder.copy_imports(compat_module, subcompat_module)
|
module_builder.copy_imports(compat_module, subcompat_module)
|
||||||
module_builder.copy_imports(
|
module_builder.copy_imports('%s.compat' % compat_module,
|
||||||
'%s.compat' % compat_module, '%s.compat' % subcompat_module)
|
'%s.compat' % subcompat_module)
|
||||||
|
|
||||||
# Prefixes of modules under compatibility packages, for e.g. "compat.v1.".
|
# Prefixes of modules under compatibility packages, for e.g. "compat.v1.".
|
||||||
compat_prefixes = tuple(
|
compat_prefixes = tuple(
|
||||||
@ -400,8 +402,7 @@ def _join_modules(module1, module2):
|
|||||||
return '%s.%s' % (module1, module2)
|
return '%s.%s' % (module1, module2)
|
||||||
|
|
||||||
|
|
||||||
def add_imports_for_symbol(
|
def add_imports_for_symbol(module_code_builder,
|
||||||
module_code_builder,
|
|
||||||
symbol,
|
symbol,
|
||||||
source_module_name,
|
source_module_name,
|
||||||
source_name,
|
source_name,
|
||||||
@ -432,8 +433,8 @@ def add_imports_for_symbol(
|
|||||||
for export in exports:
|
for export in exports:
|
||||||
dest_module, dest_name = _get_name_and_module(export)
|
dest_module, dest_name = _get_name_and_module(export)
|
||||||
dest_module = _join_modules(output_module_prefix, dest_module)
|
dest_module = _join_modules(output_module_prefix, dest_module)
|
||||||
module_code_builder.add_import(
|
module_code_builder.add_import(None, source_module_name, name,
|
||||||
None, source_module_name, name, dest_module, dest_name)
|
dest_module, dest_name)
|
||||||
|
|
||||||
# If symbol has _tf_api_names attribute, then add import for it.
|
# If symbol has _tf_api_names attribute, then add import for it.
|
||||||
if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):
|
if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):
|
||||||
@ -442,8 +443,8 @@ def add_imports_for_symbol(
|
|||||||
for export in getattr(symbol, names_attr): # pylint: disable=protected-access
|
for export in getattr(symbol, names_attr): # pylint: disable=protected-access
|
||||||
dest_module, dest_name = _get_name_and_module(export)
|
dest_module, dest_name = _get_name_and_module(export)
|
||||||
dest_module = _join_modules(output_module_prefix, dest_module)
|
dest_module = _join_modules(output_module_prefix, dest_module)
|
||||||
module_code_builder.add_import(
|
module_code_builder.add_import(symbol, source_module_name, source_name,
|
||||||
symbol, source_module_name, source_name, dest_module, dest_name)
|
dest_module, dest_name)
|
||||||
|
|
||||||
|
|
||||||
def get_api_init_text(packages,
|
def get_api_init_text(packages,
|
||||||
@ -466,8 +467,8 @@ def get_api_init_text(packages,
|
|||||||
directory.
|
directory.
|
||||||
lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
|
lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
|
||||||
produced and if `False`, static imports are used.
|
produced and if `False`, static imports are used.
|
||||||
use_relative_imports: True if we should use relative imports when
|
use_relative_imports: True if we should use relative imports when importing
|
||||||
importing submodules.
|
submodules.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary where
|
A dictionary where
|
||||||
@ -477,8 +478,10 @@ def get_api_init_text(packages,
|
|||||||
"""
|
"""
|
||||||
if compat_api_versions is None:
|
if compat_api_versions is None:
|
||||||
compat_api_versions = []
|
compat_api_versions = []
|
||||||
module_code_builder = _ModuleInitCodeBuilder(
|
module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
|
||||||
output_package, api_version, lazy_loading, use_relative_imports)
|
lazy_loading,
|
||||||
|
use_relative_imports)
|
||||||
|
|
||||||
# Traverse over everything imported above. Specifically,
|
# Traverse over everything imported above. Specifically,
|
||||||
# we want to traverse over TensorFlow Python modules.
|
# we want to traverse over TensorFlow Python modules.
|
||||||
|
|
||||||
@ -496,24 +499,23 @@ def get_api_init_text(packages,
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for module_contents_name in dir(module):
|
for module_contents_name in dir(module):
|
||||||
if (module.__name__ + '.' + module_contents_name
|
if (module.__name__ + '.' +
|
||||||
in _SYMBOLS_TO_SKIP_EXPLICITLY):
|
module_contents_name in _SYMBOLS_TO_SKIP_EXPLICITLY):
|
||||||
continue
|
continue
|
||||||
attr = getattr(module, module_contents_name)
|
attr = getattr(module, module_contents_name)
|
||||||
_, attr = tf_decorator.unwrap(attr)
|
_, attr = tf_decorator.unwrap(attr)
|
||||||
|
|
||||||
add_imports_for_symbol(
|
add_imports_for_symbol(module_code_builder, attr, module.__name__,
|
||||||
module_code_builder, attr, module.__name__, module_contents_name,
|
module_contents_name, api_name, api_version)
|
||||||
api_name, api_version)
|
|
||||||
for compat_api_version in compat_api_versions:
|
for compat_api_version in compat_api_versions:
|
||||||
add_imports_for_symbol(
|
add_imports_for_symbol(module_code_builder, attr, module.__name__,
|
||||||
module_code_builder, attr, module.__name__, module_contents_name,
|
module_contents_name, api_name,
|
||||||
api_name, compat_api_version,
|
compat_api_version,
|
||||||
_COMPAT_MODULE_TEMPLATE % compat_api_version)
|
_COMPAT_MODULE_TEMPLATE % compat_api_version)
|
||||||
|
|
||||||
if compat_api_versions:
|
if compat_api_versions:
|
||||||
add_nested_compat_imports(
|
add_nested_compat_imports(module_code_builder, compat_api_versions,
|
||||||
module_code_builder, compat_api_versions, output_package)
|
output_package)
|
||||||
return module_code_builder.build()
|
return module_code_builder.build()
|
||||||
|
|
||||||
|
|
||||||
@ -545,8 +547,8 @@ def get_module_docstring(module_name, package, api_name):
|
|||||||
4. Returns a default docstring.
|
4. Returns a default docstring.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module_name: module name relative to tensorflow
|
module_name: module name relative to tensorflow (excluding 'tensorflow.'
|
||||||
(excluding 'tensorflow.' prefix) to get a docstring for.
|
prefix) to get a docstring for.
|
||||||
package: Base python package containing python with target tf_export
|
package: Base python package containing python with target tf_export
|
||||||
decorators.
|
decorators.
|
||||||
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
|
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
|
||||||
@ -581,31 +583,37 @@ def get_module_docstring(module_name, package, api_name):
|
|||||||
return 'Public API for tf.%s namespace.' % module_name
|
return 'Public API for tf.%s namespace.' % module_name
|
||||||
|
|
||||||
|
|
||||||
def create_api_files(output_files, packages, root_init_template, output_dir,
|
def create_api_files(output_files,
|
||||||
output_package, api_name, api_version,
|
packages,
|
||||||
compat_api_versions, compat_init_templates,
|
root_init_template,
|
||||||
lazy_loading=_LAZY_LOADING, use_relative_imports=False):
|
output_dir,
|
||||||
|
output_package,
|
||||||
|
api_name,
|
||||||
|
api_version,
|
||||||
|
compat_api_versions,
|
||||||
|
compat_init_templates,
|
||||||
|
lazy_loading=_LAZY_LOADING,
|
||||||
|
use_relative_imports=False):
|
||||||
"""Creates __init__.py files for the Python API.
|
"""Creates __init__.py files for the Python API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_files: List of __init__.py file paths to create.
|
output_files: List of __init__.py file paths to create.
|
||||||
packages: Base python packages containing python with target tf_export
|
packages: Base python packages containing python with target tf_export
|
||||||
decorators.
|
decorators.
|
||||||
root_init_template: Template for top-level __init__.py file.
|
root_init_template: Template for top-level __init__.py file. "# API IMPORTS
|
||||||
"# API IMPORTS PLACEHOLDER" comment in the template file will be replaced
|
PLACEHOLDER" comment in the template file will be replaced with imports.
|
||||||
with imports.
|
|
||||||
output_dir: output API root directory.
|
output_dir: output API root directory.
|
||||||
output_package: Base output package where generated API will be added.
|
output_package: Base output package where generated API will be added.
|
||||||
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
|
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
|
||||||
api_version: API version to generate (`v1` or `v2`).
|
api_version: API version to generate (`v1` or `v2`).
|
||||||
compat_api_versions: Additional API versions to generate in compat/
|
compat_api_versions: Additional API versions to generate in compat/
|
||||||
subdirectory.
|
subdirectory.
|
||||||
compat_init_templates: List of templates for top level compat init files
|
compat_init_templates: List of templates for top level compat init files in
|
||||||
in the same order as compat_api_versions.
|
the same order as compat_api_versions.
|
||||||
lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
|
lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
|
||||||
produced and if `False`, static imports are used.
|
produced and if `False`, static imports are used.
|
||||||
use_relative_imports: True if we should use relative imports when
|
use_relative_imports: True if we should use relative imports when import
|
||||||
import submodules.
|
submodules.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if output_files list is missing a required file.
|
ValueError: if output_files list is missing a required file.
|
||||||
@ -645,8 +653,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
|
|||||||
for module, text in module_text_map.items():
|
for module, text in module_text_map.items():
|
||||||
# Make sure genrule output file list is in sync with API exports.
|
# Make sure genrule output file list is in sync with API exports.
|
||||||
if module not in module_name_to_file_path:
|
if module not in module_name_to_file_path:
|
||||||
module_file_path = '"%s/__init__.py"' % (
|
module_file_path = '"%s/__init__.py"' % (module.replace('.', '/'))
|
||||||
module.replace('.', '/'))
|
|
||||||
missing_output_files.append(module_file_path)
|
missing_output_files.append(module_file_path)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -664,8 +671,9 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
|
|||||||
contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
|
contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
|
||||||
else:
|
else:
|
||||||
contents = (
|
contents = (
|
||||||
_GENERATED_FILE_HEADER % get_module_docstring(
|
_GENERATED_FILE_HEADER %
|
||||||
module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER)
|
get_module_docstring(module, packages[0], api_name) + text +
|
||||||
|
_GENERATED_FILE_FOOTER)
|
||||||
if module in deprecation_footer_map:
|
if module in deprecation_footer_map:
|
||||||
if '# WRAPPER_PLACEHOLDER' in contents:
|
if '# WRAPPER_PLACEHOLDER' in contents:
|
||||||
contents = contents.replace('# WRAPPER_PLACEHOLDER',
|
contents = contents.replace('# WRAPPER_PLACEHOLDER',
|
||||||
@ -680,14 +688,17 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
|
|||||||
"""Missing outputs for genrule:\n%s. Be sure to add these targets to
|
"""Missing outputs for genrule:\n%s. Be sure to add these targets to
|
||||||
tensorflow/python/tools/api/generator/api_init_files_v1.bzl and
|
tensorflow/python/tools/api/generator/api_init_files_v1.bzl and
|
||||||
tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or
|
tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or
|
||||||
tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)"""
|
tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)""" %
|
||||||
% ',\n'.join(sorted(missing_output_files)))
|
',\n'.join(sorted(missing_output_files)))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'outputs', metavar='O', type=str, nargs='+',
|
'outputs',
|
||||||
|
metavar='O',
|
||||||
|
type=str,
|
||||||
|
nargs='+',
|
||||||
help='If a single file is passed in, then we we assume it contains a '
|
help='If a single file is passed in, then we we assume it contains a '
|
||||||
'semicolon-separated list of Python files that we expect this script to '
|
'semicolon-separated list of Python files that we expect this script to '
|
||||||
'output. If multiple files are passed in, then we assume output files '
|
'output. If multiple files are passed in, then we assume output files '
|
||||||
@ -699,36 +710,54 @@ def main():
|
|||||||
help='Base packages that import modules containing the target tf_export '
|
help='Base packages that import modules containing the target tf_export '
|
||||||
'decorators.')
|
'decorators.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--root_init_template', default='', type=str,
|
'--root_init_template',
|
||||||
|
default='',
|
||||||
|
type=str,
|
||||||
help='Template for top level __init__.py file. '
|
help='Template for top level __init__.py file. '
|
||||||
'"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
|
'"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--apidir', type=str, required=True,
|
'--apidir',
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
help='Directory where generated output files are placed. '
|
help='Directory where generated output files are placed. '
|
||||||
'gendir should be a prefix of apidir. Also, apidir '
|
'gendir should be a prefix of apidir. Also, apidir '
|
||||||
'should be a prefix of every directory in outputs.')
|
'should be a prefix of every directory in outputs.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--apiname', required=True, type=str,
|
'--apiname',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
choices=API_ATTRS.keys(),
|
choices=API_ATTRS.keys(),
|
||||||
help='The API you want to generate.')
|
help='The API you want to generate.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--apiversion', default=2, type=int,
|
'--apiversion',
|
||||||
|
default=2,
|
||||||
|
type=int,
|
||||||
choices=_API_VERSIONS,
|
choices=_API_VERSIONS,
|
||||||
help='The API version you want to generate.')
|
help='The API version you want to generate.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--compat_apiversions', default=[], type=int, action='append',
|
'--compat_apiversions',
|
||||||
|
default=[],
|
||||||
|
type=int,
|
||||||
|
action='append',
|
||||||
help='Additional versions to generate in compat/ subdirectory. '
|
help='Additional versions to generate in compat/ subdirectory. '
|
||||||
'If set to 0, then no additional version would be generated.')
|
'If set to 0, then no additional version would be generated.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--compat_init_templates', default=[], type=str, action='append',
|
'--compat_init_templates',
|
||||||
|
default=[],
|
||||||
|
type=str,
|
||||||
|
action='append',
|
||||||
help='Templates for top-level __init__ files under compat modules. '
|
help='Templates for top-level __init__ files under compat modules. '
|
||||||
'The list of init file templates must be in the same order as '
|
'The list of init file templates must be in the same order as '
|
||||||
'list of versions passed with compat_apiversions.')
|
'list of versions passed with compat_apiversions.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--output_package', default='tensorflow', type=str,
|
'--output_package',
|
||||||
|
default='tensorflow',
|
||||||
|
type=str,
|
||||||
help='Root output package.')
|
help='Root output package.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--loading', default='default', type=str,
|
'--loading',
|
||||||
|
default='default',
|
||||||
|
type=str,
|
||||||
choices=['lazy', 'static', 'default'],
|
choices=['lazy', 'static', 'default'],
|
||||||
help='Controls how the generated __init__.py file loads the exported '
|
help='Controls how the generated __init__.py file loads the exported '
|
||||||
'symbols. \'lazy\' means the symbols are loaded when first used. '
|
'symbols. \'lazy\' means the symbols are loaded when first used. '
|
||||||
@ -736,7 +765,9 @@ def main():
|
|||||||
'__init__.py file. \'default\' uses the value of the '
|
'__init__.py file. \'default\' uses the value of the '
|
||||||
'_LAZY_LOADING constant in create_python_api.py.')
|
'_LAZY_LOADING constant in create_python_api.py.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--use_relative_imports', default=False, type=bool,
|
'--use_relative_imports',
|
||||||
|
default=False,
|
||||||
|
type=bool,
|
||||||
help='Whether to import submodules using relative imports or absolute '
|
help='Whether to import submodules using relative imports or absolute '
|
||||||
'imports')
|
'imports')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -34,7 +34,6 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from six.moves import range
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
@ -87,9 +86,8 @@ _API_GOLDEN_FOLDER_V2 = None
|
|||||||
def _InitPathConstants():
|
def _InitPathConstants():
|
||||||
global _API_GOLDEN_FOLDER_V1
|
global _API_GOLDEN_FOLDER_V1
|
||||||
global _API_GOLDEN_FOLDER_V2
|
global _API_GOLDEN_FOLDER_V2
|
||||||
root_golden_path_v2 = os.path.join(
|
root_golden_path_v2 = os.path.join(resource_loader.get_data_files_path(),
|
||||||
resource_loader.get_data_files_path(), '..', 'golden', 'v2',
|
'..', 'golden', 'v2', 'tensorflow.pbtxt')
|
||||||
'tensorflow.pbtxt')
|
|
||||||
|
|
||||||
if FLAGS.update_goldens:
|
if FLAGS.update_goldens:
|
||||||
root_golden_path_v2 = os.path.realpath(root_golden_path_v2)
|
root_golden_path_v2 = os.path.realpath(root_golden_path_v2)
|
||||||
@ -106,7 +104,6 @@ _UPDATE_WARNING_FILE = resource_loader.get_path_to_datafile(
|
|||||||
|
|
||||||
_NON_CORE_PACKAGES = ['estimator']
|
_NON_CORE_PACKAGES = ['estimator']
|
||||||
|
|
||||||
|
|
||||||
# TODO(annarev): remove this once we test with newer version of
|
# TODO(annarev): remove this once we test with newer version of
|
||||||
# estimator that actually has compat v1 version.
|
# estimator that actually has compat v1 version.
|
||||||
if not hasattr(tf.compat.v1, 'estimator'):
|
if not hasattr(tf.compat.v1, 'estimator'):
|
||||||
@ -282,9 +279,6 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
diff_count = len(diffs)
|
diff_count = len(diffs)
|
||||||
logging.error(self._test_readme_message)
|
logging.error(self._test_readme_message)
|
||||||
logging.error('%d differences found between API and golden.', diff_count)
|
logging.error('%d differences found between API and golden.', diff_count)
|
||||||
messages = verbose_diffs if verbose else diffs
|
|
||||||
for i in range(diff_count):
|
|
||||||
print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr)
|
|
||||||
|
|
||||||
if update_goldens:
|
if update_goldens:
|
||||||
# Write files if requested.
|
# Write files if requested.
|
||||||
@ -394,11 +388,12 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
resource_loader.get_root_dir_with_all_resources(),
|
resource_loader.get_root_dir_with_all_resources(),
|
||||||
_KeyToFilePath('*', api_version))
|
_KeyToFilePath('*', api_version))
|
||||||
omit_golden_symbols_map = {}
|
omit_golden_symbols_map = {}
|
||||||
if (api_version == 2 and FLAGS.only_test_core_api
|
if (api_version == 2 and FLAGS.only_test_core_api and
|
||||||
and not _TENSORBOARD_AVAILABLE):
|
not _TENSORBOARD_AVAILABLE):
|
||||||
# In TF 2.0 these summary symbols are imported from TensorBoard.
|
# In TF 2.0 these summary symbols are imported from TensorBoard.
|
||||||
omit_golden_symbols_map['tensorflow.summary'] = [
|
omit_golden_symbols_map['tensorflow.summary'] = [
|
||||||
'audio', 'histogram', 'image', 'scalar', 'text']
|
'audio', 'histogram', 'image', 'scalar', 'text'
|
||||||
|
]
|
||||||
|
|
||||||
self._checkBackwardsCompatibility(
|
self._checkBackwardsCompatibility(
|
||||||
tf,
|
tf,
|
||||||
@ -418,7 +413,9 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
resource_loader.get_root_dir_with_all_resources(),
|
resource_loader.get_root_dir_with_all_resources(),
|
||||||
_KeyToFilePath('*', api_version))
|
_KeyToFilePath('*', api_version))
|
||||||
self._checkBackwardsCompatibility(
|
self._checkBackwardsCompatibility(
|
||||||
tf.compat.v1, golden_file_pattern, api_version,
|
tf.compat.v1,
|
||||||
|
golden_file_pattern,
|
||||||
|
api_version,
|
||||||
additional_private_map={
|
additional_private_map={
|
||||||
'tf': ['pywrap_tensorflow'],
|
'tf': ['pywrap_tensorflow'],
|
||||||
'tf.compat': ['v1', 'v2'],
|
'tf.compat': ['v1', 'v2'],
|
||||||
@ -434,7 +431,8 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
|
if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
|
||||||
# In TF 2.0 these summary symbols are imported from TensorBoard.
|
# In TF 2.0 these summary symbols are imported from TensorBoard.
|
||||||
omit_golden_symbols_map['tensorflow.summary'] = [
|
omit_golden_symbols_map['tensorflow.summary'] = [
|
||||||
'audio', 'histogram', 'image', 'scalar', 'text']
|
'audio', 'histogram', 'image', 'scalar', 'text'
|
||||||
|
]
|
||||||
self._checkBackwardsCompatibility(
|
self._checkBackwardsCompatibility(
|
||||||
tf.compat.v2,
|
tf.compat.v2,
|
||||||
golden_file_pattern,
|
golden_file_pattern,
|
||||||
|
|||||||
@ -46,8 +46,15 @@ def _traverse_internal(root, visit, stack, path):
|
|||||||
children.append(enum_member)
|
children.append(enum_member)
|
||||||
children = sorted(children)
|
children = sorted(children)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# On some Python installations, some modules do not support enumerating
|
# Children could be missing for one of two reasons:
|
||||||
|
# 1. On some Python installations, some modules do not support enumerating
|
||||||
# members (six in particular), leading to import errors.
|
# members (six in particular), leading to import errors.
|
||||||
|
# 2. Children are lazy-loaded.
|
||||||
|
try:
|
||||||
|
children = []
|
||||||
|
for child_name in root.__all__:
|
||||||
|
children.append((child_name, getattr(root, child_name)))
|
||||||
|
except AttributeError:
|
||||||
children = []
|
children = []
|
||||||
|
|
||||||
new_stack = stack + [root]
|
new_stack = stack + [root]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user