First step towards reducing the size of python/__init__.py.

API generation has been relying on python/__init__.py to look for modules with @tf_export decorators. Now, modules_with_exports.py should be used to specify modules to scan instead.

PiperOrigin-RevId: 314177907
Change-Id: I13460b5db55d4e56dd810106086ecb2aeca69a6b
This commit is contained in:
Anna R 2020-06-01 12:15:28 -07:00 committed by TensorFlower Gardener
parent 25213f58c4
commit 5498a3cb12
7 changed files with 249 additions and 174 deletions

View File

@ -239,6 +239,21 @@ py_library(
], ],
) )
# This target should only be used for API generation.
py_library(
name = "modules_with_exports",
srcs = ["modules_with_exports.py"],
srcs_version = "PY2AND3",
visibility = [
"//tensorflow:__pkg__",
"//tensorflow/python/tools/api/generator:__pkg__",
"//third_party/py/tensorflow_core:__subpackages__",
],
deps = [
":no_contrib",
],
)
# TODO(gunan): Investigate making this action hermetic so we do not need # TODO(gunan): Investigate making this action hermetic so we do not need
# to run it locally. # to run it locally.
tf_py_build_info_genrule( tf_py_build_info_genrule(

View File

@ -30,51 +30,14 @@ import importlib
import sys import sys
import traceback import traceback
# TODO(drpng): write up instructions for editing this file in a doc and point to # We aim to keep this file minimal and ideally remove completely.
# the doc instead. # If you are adding a new file with @tf_export decorators,
# If you want to edit this file to expose modules in public tensorflow API, you # import it in modules_with_exports.py instead.
# need to follow these steps:
# 1. Consult with tensorflow team and get approval for adding a new API to the
# public interface.
# 2. Document the module in the gen_docs_combined.py.
# 3. Import the module in the main tensorflow namespace by adding an import
# statement in this file.
# 4. Sanitize the entry point by making sure that your module does not expose
# transitively imported modules used for implementation, such as os, sys.
# go/tf-wildcard-import # go/tf-wildcard-import
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top # pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
import numpy as np from tensorflow.python.eager import context
from tensorflow.python import pywrap_tensorflow
# Protocol buffers
from tensorflow.core.framework.graph_pb2 import *
from tensorflow.core.framework.node_def_pb2 import *
from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.core.protobuf.config_pb2 import *
from tensorflow.core.protobuf.tensorflow_server_pb2 import *
from tensorflow.core.util.event_pb2 import *
# Framework
from tensorflow.python.framework.framework_lib import * # pylint: disable=redefined-builtin
from tensorflow.python.framework.versions import *
from tensorflow.python.framework import config
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
# Session
from tensorflow.python.client.client_lib import *
# Ops
from tensorflow.python.ops.standard_ops import *
# Namespaces
from tensorflow.python.ops import initializers_ns as initializers
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
@ -152,8 +115,8 @@ from tensorflow.python.framework.ops import enable_eager_execution
# Check whether TF2_BEHAVIOR is turned on. # Check whether TF2_BEHAVIOR is turned on.
from tensorflow.python.eager import monitoring as _monitoring from tensorflow.python.eager import monitoring as _monitoring
from tensorflow.python import tf2 as _tf2 from tensorflow.python import tf2 as _tf2
_tf2_gauge = _monitoring.BoolGauge('/tensorflow/api/tf2_enable', _tf2_gauge = _monitoring.BoolGauge(
'Environment variable TF2_BEHAVIOR is set".') '/tensorflow/api/tf2_enable', 'Environment variable TF2_BEHAVIOR is set".')
_tf2_gauge.get_cell().set(_tf2.enabled()) _tf2_gauge.get_cell().set(_tf2.enabled())
# Necessary for the symbols in this module to be taken into account by # Necessary for the symbols in this module to be taken into account by
@ -186,30 +149,6 @@ nn.bidirectional_dynamic_rnn = rnn.bidirectional_dynamic_rnn
nn.static_state_saving_rnn = rnn.static_state_saving_rnn nn.static_state_saving_rnn = rnn.static_state_saving_rnn
nn.rnn_cell = rnn_cell nn.rnn_cell = rnn_cell
# Export protos
# pylint: disable=undefined-variable
tf_export(v1=['AttrValue'])(AttrValue)
tf_export(v1=['ConfigProto'])(ConfigProto)
tf_export(v1=['Event', 'summary.Event'])(Event)
tf_export(v1=['GPUOptions'])(GPUOptions)
tf_export(v1=['GraphDef'])(GraphDef)
tf_export(v1=['GraphOptions'])(GraphOptions)
tf_export(v1=['HistogramProto'])(HistogramProto)
tf_export(v1=['LogMessage'])(LogMessage)
tf_export(v1=['MetaGraphDef'])(MetaGraphDef)
tf_export(v1=['NameAttrList'])(NameAttrList)
tf_export(v1=['NodeDef'])(NodeDef)
tf_export(v1=['OptimizerOptions'])(OptimizerOptions)
tf_export(v1=['RunMetadata'])(RunMetadata)
tf_export(v1=['RunOptions'])(RunOptions)
tf_export(v1=['SessionLog', 'summary.SessionLog'])(SessionLog)
tf_export(v1=['Summary', 'summary.Summary'])(Summary)
tf_export(v1=['summary.SummaryDescription'])(SummaryDescription)
tf_export(v1=['SummaryMetadata'])(SummaryMetadata)
tf_export(v1=['summary.TaggedRunMetadata'])(TaggedRunMetadata)
tf_export(v1=['TensorInfo'])(TensorInfo)
# pylint: enable=undefined-variable
# Special dunders that we choose to export: # Special dunders that we choose to export:
_exported_dunders = set([ _exported_dunders = set([
'__version__', '__version__',

View File

@ -0,0 +1,78 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Imports modules that should be scanned during API generation.
This file should eventually contain everything we need to scan looking for
tf_export decorators.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
# pylint: disable=unused-import,g-importing-member
# Protocol buffers
from tensorflow.core.framework.graph_pb2 import *
from tensorflow.core.framework.node_def_pb2 import *
from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.core.protobuf.config_pb2 import *
from tensorflow.core.util.event_pb2 import *
# Framework
from tensorflow.python.framework.framework_lib import * # pylint: disable=redefined-builtin
from tensorflow.python.framework.versions import *
from tensorflow.python.framework import config
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
# Session
from tensorflow.python.client.client_lib import *
# Ops
from tensorflow.python.ops.standard_ops import *
# Namespaces
from tensorflow.python.ops import initializers_ns as initializers
from tensorflow.python.util.tf_export import tf_export
# Export protos
# pylint: disable=undefined-variable
tf_export(v1=['AttrValue'])(AttrValue)
tf_export(v1=['ConfigProto'])(ConfigProto)
tf_export(v1=['Event', 'summary.Event'])(Event)
tf_export(v1=['GPUOptions'])(GPUOptions)
tf_export(v1=['GraphDef'])(GraphDef)
tf_export(v1=['GraphOptions'])(GraphOptions)
tf_export(v1=['HistogramProto'])(HistogramProto)
tf_export(v1=['LogMessage'])(LogMessage)
tf_export(v1=['MetaGraphDef'])(MetaGraphDef)
tf_export(v1=['NameAttrList'])(NameAttrList)
tf_export(v1=['NodeDef'])(NodeDef)
tf_export(v1=['OptimizerOptions'])(OptimizerOptions)
tf_export(v1=['RunMetadata'])(RunMetadata)
tf_export(v1=['RunOptions'])(RunOptions)
tf_export(v1=['SessionLog', 'summary.SessionLog'])(SessionLog)
tf_export(v1=['Summary', 'summary.Summary'])(Summary)
tf_export(v1=['summary.SummaryDescription'])(SummaryDescription)
tf_export(v1=['SummaryMetadata'])(SummaryMetadata)
tf_export(v1=['summary.TaggedRunMetadata'])(TaggedRunMetadata)
tf_export(v1=['TensorInfo'])(TensorInfo)
# pylint: enable=undefined-variable

View File

@ -42,8 +42,15 @@ def gen_api_init_files(
api_version = 2, api_version = 2,
compat_api_versions = [], compat_api_versions = [],
compat_init_templates = [], compat_init_templates = [],
packages = ["tensorflow.python", "tensorflow.lite.python.lite"], packages = [
package_deps = ["//tensorflow/python:no_contrib"], "tensorflow.python",
"tensorflow.lite.python.lite",
"tensorflow.python.modules_with_exports",
],
package_deps = [
"//tensorflow/python:no_contrib",
"//tensorflow/python:modules_with_exports",
],
output_package = "tensorflow", output_package = "tensorflow",
output_dir = "", output_dir = "",
root_file_name = "__init__.py"): root_file_name = "__init__.py"):

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================= # =============================================================================
"""Generates and prints out imports and constants for new TensorFlow python api. """Generates and prints out imports and constants for new TensorFlow python api."""
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -85,9 +84,9 @@ def get_canonical_import(import_set):
ordering. ordering.
Args: Args:
import_set: (set) Imports providing the same symbol. This is a set of import_set: (set) Imports providing the same symbol. This is a set of tuples
tuples in the form (import, priority). We want to pick an import in the form (import, priority). We want to pick an import with highest
with highest priority. priority.
Returns: Returns:
A module name to import A module name to import
@ -106,8 +105,10 @@ def get_canonical_import(import_set):
class _ModuleInitCodeBuilder(object): class _ModuleInitCodeBuilder(object):
"""Builds a map from module name to imports included in that module.""" """Builds a map from module name to imports included in that module."""
def __init__( def __init__(self,
self, output_package, api_version, lazy_loading=_LAZY_LOADING, output_package,
api_version,
lazy_loading=_LAZY_LOADING,
use_relative_imports=False): use_relative_imports=False):
self._output_package = output_package self._output_package = output_package
# Maps API module to API symbol name to set of tuples of the form # Maps API module to API symbol name to set of tuples of the form
@ -127,16 +128,13 @@ class _ModuleInitCodeBuilder(object):
def _check_already_imported(self, symbol_id, api_name): def _check_already_imported(self, symbol_id, api_name):
if (api_name in self._dest_import_to_id and if (api_name in self._dest_import_to_id and
symbol_id != self._dest_import_to_id[api_name] and symbol_id != self._dest_import_to_id[api_name] and symbol_id != -1):
symbol_id != -1):
raise SymbolExposedTwiceError( raise SymbolExposedTwiceError(
'Trying to export multiple symbols with same name: %s.' % 'Trying to export multiple symbols with same name: %s.' % api_name)
api_name)
self._dest_import_to_id[api_name] = symbol_id self._dest_import_to_id[api_name] = symbol_id
def add_import( def add_import(self, symbol, source_module_name, source_name,
self, symbol, source_module_name, source_name, dest_module_name, dest_module_name, dest_name):
dest_name):
"""Adds this import to module_imports. """Adds this import to module_imports.
Args: Args:
@ -150,6 +148,10 @@ class _ModuleInitCodeBuilder(object):
SymbolExposedTwiceError: Raised when an import with the same SymbolExposedTwiceError: Raised when an import with the same
dest_name has already been added to dest_module_name. dest_name has already been added to dest_module_name.
""" """
# modules_with_exports.py is only used during API generation and
# won't be available when actually importing tensorflow.
if source_module_name.endswith('python.modules_with_exports'):
source_module_name = symbol.__module__
import_str = self.format_import(source_module_name, source_name, dest_name) import_str = self.format_import(source_module_name, source_name, dest_name)
# Check if we are trying to expose two different symbols with same name. # Check if we are trying to expose two different symbols with same name.
@ -264,8 +266,8 @@ __all__.extend([_s for _s in _names_with_underscore])
if not dest_module.startswith(_COMPAT_MODULE_PREFIX): if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
deprecation = 'True' deprecation = 'True'
# Workaround to make sure not load lite from lite/__init__.py # Workaround to make sure not load lite from lite/__init__.py
if (not dest_module and 'lite' in self._module_imports if (not dest_module and 'lite' in self._module_imports and
and self._lazy_loading): self._lazy_loading):
has_lite = 'True' has_lite = 'True'
if self._lazy_loading: if self._lazy_loading:
public_apis_name = '_PUBLIC_APIS' public_apis_name = '_PUBLIC_APIS'
@ -311,8 +313,8 @@ __all__.extend([_s for _s in _names_with_underscore])
self._module_imports[from_dest_module].copy()) self._module_imports[from_dest_module].copy())
def add_nested_compat_imports( def add_nested_compat_imports(module_builder, compat_api_versions,
module_builder, compat_api_versions, output_package): output_package):
"""Adds compat.vN.compat.vK modules to module builder. """Adds compat.vN.compat.vK modules to module builder.
To avoid circular imports, we want to add __init__.py files under To avoid circular imports, we want to add __init__.py files under
@ -334,8 +336,8 @@ def add_nested_compat_imports(
subcompat_module = _SUBCOMPAT_MODULE_TEMPLATE % (v, sv) subcompat_module = _SUBCOMPAT_MODULE_TEMPLATE % (v, sv)
compat_module = _COMPAT_MODULE_TEMPLATE % sv compat_module = _COMPAT_MODULE_TEMPLATE % sv
module_builder.copy_imports(compat_module, subcompat_module) module_builder.copy_imports(compat_module, subcompat_module)
module_builder.copy_imports( module_builder.copy_imports('%s.compat' % compat_module,
'%s.compat' % compat_module, '%s.compat' % subcompat_module) '%s.compat' % subcompat_module)
# Prefixes of modules under compatibility packages, for e.g. "compat.v1.". # Prefixes of modules under compatibility packages, for e.g. "compat.v1.".
compat_prefixes = tuple( compat_prefixes = tuple(
@ -400,8 +402,7 @@ def _join_modules(module1, module2):
return '%s.%s' % (module1, module2) return '%s.%s' % (module1, module2)
def add_imports_for_symbol( def add_imports_for_symbol(module_code_builder,
module_code_builder,
symbol, symbol,
source_module_name, source_module_name,
source_name, source_name,
@ -432,8 +433,8 @@ def add_imports_for_symbol(
for export in exports: for export in exports:
dest_module, dest_name = _get_name_and_module(export) dest_module, dest_name = _get_name_and_module(export)
dest_module = _join_modules(output_module_prefix, dest_module) dest_module = _join_modules(output_module_prefix, dest_module)
module_code_builder.add_import( module_code_builder.add_import(None, source_module_name, name,
None, source_module_name, name, dest_module, dest_name) dest_module, dest_name)
# If symbol has _tf_api_names attribute, then add import for it. # If symbol has _tf_api_names attribute, then add import for it.
if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__): if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):
@ -442,8 +443,8 @@ def add_imports_for_symbol(
for export in getattr(symbol, names_attr): # pylint: disable=protected-access for export in getattr(symbol, names_attr): # pylint: disable=protected-access
dest_module, dest_name = _get_name_and_module(export) dest_module, dest_name = _get_name_and_module(export)
dest_module = _join_modules(output_module_prefix, dest_module) dest_module = _join_modules(output_module_prefix, dest_module)
module_code_builder.add_import( module_code_builder.add_import(symbol, source_module_name, source_name,
symbol, source_module_name, source_name, dest_module, dest_name) dest_module, dest_name)
def get_api_init_text(packages, def get_api_init_text(packages,
@ -466,8 +467,8 @@ def get_api_init_text(packages,
directory. directory.
lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
produced and if `False`, static imports are used. produced and if `False`, static imports are used.
use_relative_imports: True if we should use relative imports when use_relative_imports: True if we should use relative imports when importing
importing submodules. submodules.
Returns: Returns:
A dictionary where A dictionary where
@ -477,8 +478,10 @@ def get_api_init_text(packages,
""" """
if compat_api_versions is None: if compat_api_versions is None:
compat_api_versions = [] compat_api_versions = []
module_code_builder = _ModuleInitCodeBuilder( module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
output_package, api_version, lazy_loading, use_relative_imports) lazy_loading,
use_relative_imports)
# Traverse over everything imported above. Specifically, # Traverse over everything imported above. Specifically,
# we want to traverse over TensorFlow Python modules. # we want to traverse over TensorFlow Python modules.
@ -496,24 +499,23 @@ def get_api_init_text(packages,
continue continue
for module_contents_name in dir(module): for module_contents_name in dir(module):
if (module.__name__ + '.' + module_contents_name if (module.__name__ + '.' +
in _SYMBOLS_TO_SKIP_EXPLICITLY): module_contents_name in _SYMBOLS_TO_SKIP_EXPLICITLY):
continue continue
attr = getattr(module, module_contents_name) attr = getattr(module, module_contents_name)
_, attr = tf_decorator.unwrap(attr) _, attr = tf_decorator.unwrap(attr)
add_imports_for_symbol( add_imports_for_symbol(module_code_builder, attr, module.__name__,
module_code_builder, attr, module.__name__, module_contents_name, module_contents_name, api_name, api_version)
api_name, api_version)
for compat_api_version in compat_api_versions: for compat_api_version in compat_api_versions:
add_imports_for_symbol( add_imports_for_symbol(module_code_builder, attr, module.__name__,
module_code_builder, attr, module.__name__, module_contents_name, module_contents_name, api_name,
api_name, compat_api_version, compat_api_version,
_COMPAT_MODULE_TEMPLATE % compat_api_version) _COMPAT_MODULE_TEMPLATE % compat_api_version)
if compat_api_versions: if compat_api_versions:
add_nested_compat_imports( add_nested_compat_imports(module_code_builder, compat_api_versions,
module_code_builder, compat_api_versions, output_package) output_package)
return module_code_builder.build() return module_code_builder.build()
@ -545,8 +547,8 @@ def get_module_docstring(module_name, package, api_name):
4. Returns a default docstring. 4. Returns a default docstring.
Args: Args:
module_name: module name relative to tensorflow module_name: module name relative to tensorflow (excluding 'tensorflow.'
(excluding 'tensorflow.' prefix) to get a docstring for. prefix) to get a docstring for.
package: Base python package containing python with target tf_export package: Base python package containing python with target tf_export
decorators. decorators.
api_name: API you want to generate (e.g. `tensorflow` or `estimator`). api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
@ -581,31 +583,37 @@ def get_module_docstring(module_name, package, api_name):
return 'Public API for tf.%s namespace.' % module_name return 'Public API for tf.%s namespace.' % module_name
def create_api_files(output_files, packages, root_init_template, output_dir, def create_api_files(output_files,
output_package, api_name, api_version, packages,
compat_api_versions, compat_init_templates, root_init_template,
lazy_loading=_LAZY_LOADING, use_relative_imports=False): output_dir,
output_package,
api_name,
api_version,
compat_api_versions,
compat_init_templates,
lazy_loading=_LAZY_LOADING,
use_relative_imports=False):
"""Creates __init__.py files for the Python API. """Creates __init__.py files for the Python API.
Args: Args:
output_files: List of __init__.py file paths to create. output_files: List of __init__.py file paths to create.
packages: Base python packages containing python with target tf_export packages: Base python packages containing python with target tf_export
decorators. decorators.
root_init_template: Template for top-level __init__.py file. root_init_template: Template for top-level __init__.py file. "# API IMPORTS
"# API IMPORTS PLACEHOLDER" comment in the template file will be replaced PLACEHOLDER" comment in the template file will be replaced with imports.
with imports.
output_dir: output API root directory. output_dir: output API root directory.
output_package: Base output package where generated API will be added. output_package: Base output package where generated API will be added.
api_name: API you want to generate (e.g. `tensorflow` or `estimator`). api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
api_version: API version to generate (`v1` or `v2`). api_version: API version to generate (`v1` or `v2`).
compat_api_versions: Additional API versions to generate in compat/ compat_api_versions: Additional API versions to generate in compat/
subdirectory. subdirectory.
compat_init_templates: List of templates for top level compat init files compat_init_templates: List of templates for top level compat init files in
in the same order as compat_api_versions. the same order as compat_api_versions.
lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
produced and if `False`, static imports are used. produced and if `False`, static imports are used.
use_relative_imports: True if we should use relative imports when use_relative_imports: True if we should use relative imports when import
import submodules. submodules.
Raises: Raises:
ValueError: if output_files list is missing a required file. ValueError: if output_files list is missing a required file.
@ -645,8 +653,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
for module, text in module_text_map.items(): for module, text in module_text_map.items():
# Make sure genrule output file list is in sync with API exports. # Make sure genrule output file list is in sync with API exports.
if module not in module_name_to_file_path: if module not in module_name_to_file_path:
module_file_path = '"%s/__init__.py"' % ( module_file_path = '"%s/__init__.py"' % (module.replace('.', '/'))
module.replace('.', '/'))
missing_output_files.append(module_file_path) missing_output_files.append(module_file_path)
continue continue
@ -664,8 +671,9 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
contents = contents.replace('# API IMPORTS PLACEHOLDER', text) contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
else: else:
contents = ( contents = (
_GENERATED_FILE_HEADER % get_module_docstring( _GENERATED_FILE_HEADER %
module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER) get_module_docstring(module, packages[0], api_name) + text +
_GENERATED_FILE_FOOTER)
if module in deprecation_footer_map: if module in deprecation_footer_map:
if '# WRAPPER_PLACEHOLDER' in contents: if '# WRAPPER_PLACEHOLDER' in contents:
contents = contents.replace('# WRAPPER_PLACEHOLDER', contents = contents.replace('# WRAPPER_PLACEHOLDER',
@ -680,14 +688,17 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
"""Missing outputs for genrule:\n%s. Be sure to add these targets to """Missing outputs for genrule:\n%s. Be sure to add these targets to
tensorflow/python/tools/api/generator/api_init_files_v1.bzl and tensorflow/python/tools/api/generator/api_init_files_v1.bzl and
tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or
tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)""" tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)""" %
% ',\n'.join(sorted(missing_output_files))) ',\n'.join(sorted(missing_output_files)))
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'outputs', metavar='O', type=str, nargs='+', 'outputs',
metavar='O',
type=str,
nargs='+',
help='If a single file is passed in, then we we assume it contains a ' help='If a single file is passed in, then we we assume it contains a '
'semicolon-separated list of Python files that we expect this script to ' 'semicolon-separated list of Python files that we expect this script to '
'output. If multiple files are passed in, then we assume output files ' 'output. If multiple files are passed in, then we assume output files '
@ -699,36 +710,54 @@ def main():
help='Base packages that import modules containing the target tf_export ' help='Base packages that import modules containing the target tf_export '
'decorators.') 'decorators.')
parser.add_argument( parser.add_argument(
'--root_init_template', default='', type=str, '--root_init_template',
default='',
type=str,
help='Template for top level __init__.py file. ' help='Template for top level __init__.py file. '
'"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.') '"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
parser.add_argument( parser.add_argument(
'--apidir', type=str, required=True, '--apidir',
type=str,
required=True,
help='Directory where generated output files are placed. ' help='Directory where generated output files are placed. '
'gendir should be a prefix of apidir. Also, apidir ' 'gendir should be a prefix of apidir. Also, apidir '
'should be a prefix of every directory in outputs.') 'should be a prefix of every directory in outputs.')
parser.add_argument( parser.add_argument(
'--apiname', required=True, type=str, '--apiname',
required=True,
type=str,
choices=API_ATTRS.keys(), choices=API_ATTRS.keys(),
help='The API you want to generate.') help='The API you want to generate.')
parser.add_argument( parser.add_argument(
'--apiversion', default=2, type=int, '--apiversion',
default=2,
type=int,
choices=_API_VERSIONS, choices=_API_VERSIONS,
help='The API version you want to generate.') help='The API version you want to generate.')
parser.add_argument( parser.add_argument(
'--compat_apiversions', default=[], type=int, action='append', '--compat_apiversions',
default=[],
type=int,
action='append',
help='Additional versions to generate in compat/ subdirectory. ' help='Additional versions to generate in compat/ subdirectory. '
'If set to 0, then no additional version would be generated.') 'If set to 0, then no additional version would be generated.')
parser.add_argument( parser.add_argument(
'--compat_init_templates', default=[], type=str, action='append', '--compat_init_templates',
default=[],
type=str,
action='append',
help='Templates for top-level __init__ files under compat modules. ' help='Templates for top-level __init__ files under compat modules. '
'The list of init file templates must be in the same order as ' 'The list of init file templates must be in the same order as '
'list of versions passed with compat_apiversions.') 'list of versions passed with compat_apiversions.')
parser.add_argument( parser.add_argument(
'--output_package', default='tensorflow', type=str, '--output_package',
default='tensorflow',
type=str,
help='Root output package.') help='Root output package.')
parser.add_argument( parser.add_argument(
'--loading', default='default', type=str, '--loading',
default='default',
type=str,
choices=['lazy', 'static', 'default'], choices=['lazy', 'static', 'default'],
help='Controls how the generated __init__.py file loads the exported ' help='Controls how the generated __init__.py file loads the exported '
'symbols. \'lazy\' means the symbols are loaded when first used. ' 'symbols. \'lazy\' means the symbols are loaded when first used. '
@ -736,7 +765,9 @@ def main():
'__init__.py file. \'default\' uses the value of the ' '__init__.py file. \'default\' uses the value of the '
'_LAZY_LOADING constant in create_python_api.py.') '_LAZY_LOADING constant in create_python_api.py.')
parser.add_argument( parser.add_argument(
'--use_relative_imports', default=False, type=bool, '--use_relative_imports',
default=False,
type=bool,
help='Whether to import submodules using relative imports or absolute ' help='Whether to import submodules using relative imports or absolute '
'imports') 'imports')
args = parser.parse_args() args = parser.parse_args()

View File

@ -34,7 +34,6 @@ import re
import sys import sys
import six import six
from six.moves import range
import tensorflow as tf import tensorflow as tf
from google.protobuf import message from google.protobuf import message
@ -87,9 +86,8 @@ _API_GOLDEN_FOLDER_V2 = None
def _InitPathConstants(): def _InitPathConstants():
global _API_GOLDEN_FOLDER_V1 global _API_GOLDEN_FOLDER_V1
global _API_GOLDEN_FOLDER_V2 global _API_GOLDEN_FOLDER_V2
root_golden_path_v2 = os.path.join( root_golden_path_v2 = os.path.join(resource_loader.get_data_files_path(),
resource_loader.get_data_files_path(), '..', 'golden', 'v2', '..', 'golden', 'v2', 'tensorflow.pbtxt')
'tensorflow.pbtxt')
if FLAGS.update_goldens: if FLAGS.update_goldens:
root_golden_path_v2 = os.path.realpath(root_golden_path_v2) root_golden_path_v2 = os.path.realpath(root_golden_path_v2)
@ -106,7 +104,6 @@ _UPDATE_WARNING_FILE = resource_loader.get_path_to_datafile(
_NON_CORE_PACKAGES = ['estimator'] _NON_CORE_PACKAGES = ['estimator']
# TODO(annarev): remove this once we test with newer version of # TODO(annarev): remove this once we test with newer version of
# estimator that actually has compat v1 version. # estimator that actually has compat v1 version.
if not hasattr(tf.compat.v1, 'estimator'): if not hasattr(tf.compat.v1, 'estimator'):
@ -282,9 +279,6 @@ class ApiCompatibilityTest(test.TestCase):
diff_count = len(diffs) diff_count = len(diffs)
logging.error(self._test_readme_message) logging.error(self._test_readme_message)
logging.error('%d differences found between API and golden.', diff_count) logging.error('%d differences found between API and golden.', diff_count)
messages = verbose_diffs if verbose else diffs
for i in range(diff_count):
print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr)
if update_goldens: if update_goldens:
# Write files if requested. # Write files if requested.
@ -394,11 +388,12 @@ class ApiCompatibilityTest(test.TestCase):
resource_loader.get_root_dir_with_all_resources(), resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version)) _KeyToFilePath('*', api_version))
omit_golden_symbols_map = {} omit_golden_symbols_map = {}
if (api_version == 2 and FLAGS.only_test_core_api if (api_version == 2 and FLAGS.only_test_core_api and
and not _TENSORBOARD_AVAILABLE): not _TENSORBOARD_AVAILABLE):
# In TF 2.0 these summary symbols are imported from TensorBoard. # In TF 2.0 these summary symbols are imported from TensorBoard.
omit_golden_symbols_map['tensorflow.summary'] = [ omit_golden_symbols_map['tensorflow.summary'] = [
'audio', 'histogram', 'image', 'scalar', 'text'] 'audio', 'histogram', 'image', 'scalar', 'text'
]
self._checkBackwardsCompatibility( self._checkBackwardsCompatibility(
tf, tf,
@ -418,7 +413,9 @@ class ApiCompatibilityTest(test.TestCase):
resource_loader.get_root_dir_with_all_resources(), resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version)) _KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility( self._checkBackwardsCompatibility(
tf.compat.v1, golden_file_pattern, api_version, tf.compat.v1,
golden_file_pattern,
api_version,
additional_private_map={ additional_private_map={
'tf': ['pywrap_tensorflow'], 'tf': ['pywrap_tensorflow'],
'tf.compat': ['v1', 'v2'], 'tf.compat': ['v1', 'v2'],
@ -434,7 +431,8 @@ class ApiCompatibilityTest(test.TestCase):
if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE: if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
# In TF 2.0 these summary symbols are imported from TensorBoard. # In TF 2.0 these summary symbols are imported from TensorBoard.
omit_golden_symbols_map['tensorflow.summary'] = [ omit_golden_symbols_map['tensorflow.summary'] = [
'audio', 'histogram', 'image', 'scalar', 'text'] 'audio', 'histogram', 'image', 'scalar', 'text'
]
self._checkBackwardsCompatibility( self._checkBackwardsCompatibility(
tf.compat.v2, tf.compat.v2,
golden_file_pattern, golden_file_pattern,

View File

@ -46,8 +46,15 @@ def _traverse_internal(root, visit, stack, path):
children.append(enum_member) children.append(enum_member)
children = sorted(children) children = sorted(children)
except ImportError: except ImportError:
# On some Python installations, some modules do not support enumerating # Children could be missing for one of two reasons:
# 1. On some Python installations, some modules do not support enumerating
# members (six in particular), leading to import errors. # members (six in particular), leading to import errors.
# 2. Children are lazy-loaded.
try:
children = []
for child_name in root.__all__:
children.append((child_name, getattr(root, child_name)))
except AttributeError:
children = [] children = []
new_stack = stack + [root] new_stack = stack + [root]