First step towards reducing the size of python/__init__.py.

API generation has been relying on python/__init__.py to look for modules with @tf_export decorators. Now, modules_with_exports.py should be used to specify modules to scan instead.

PiperOrigin-RevId: 314177907
Change-Id: I13460b5db55d4e56dd810106086ecb2aeca69a6b
This commit is contained in:
Anna R 2020-06-01 12:15:28 -07:00 committed by TensorFlower Gardener
parent 25213f58c4
commit 5498a3cb12
7 changed files with 249 additions and 174 deletions

View File

@ -239,6 +239,21 @@ py_library(
],
)
# This target should only be used for API generation.
# It bundles the modules that carry @tf_export decorators so the API
# generator can scan them directly instead of scanning all of
# python/__init__.py (see modules_with_exports.py).
py_library(
    name = "modules_with_exports",
    srcs = ["modules_with_exports.py"],
    srcs_version = "PY2AND3",
    # Restricted visibility: only the top-level tensorflow package, the API
    # generator, and the tensorflow_core shim are expected to depend on this.
    visibility = [
        "//tensorflow:__pkg__",
        "//tensorflow/python/tools/api/generator:__pkg__",
        "//third_party/py/tensorflow_core:__subpackages__",
    ],
    deps = [
        # no_contrib provides the transitive Python modules that
        # modules_with_exports.py imports.
        ":no_contrib",
    ],
)
# TODO(gunan): Investigate making this action hermetic so we do not need
# to run it locally.
tf_py_build_info_genrule(

View File

@ -30,51 +30,14 @@ import importlib
import sys
import traceback
# TODO(drpng): write up instructions for editing this file in a doc and point to
# the doc instead.
# If you want to edit this file to expose modules in public tensorflow API, you
# need to follow these steps:
# 1. Consult with tensorflow team and get approval for adding a new API to the
# public interface.
# 2. Document the module in the gen_docs_combined.py.
# 3. Import the module in the main tensorflow namespace by adding an import
# statement in this file.
# 4. Sanitize the entry point by making sure that your module does not expose
# transitively imported modules used for implementation, such as os, sys.
# We aim to keep this file minimal and ideally remove completely.
# If you are adding a new file with @tf_export decorators,
# import it in modules_with_exports.py instead.
# go/tf-wildcard-import
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
import numpy as np
from tensorflow.python import pywrap_tensorflow
# Protocol buffers
from tensorflow.core.framework.graph_pb2 import *
from tensorflow.core.framework.node_def_pb2 import *
from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.core.protobuf.config_pb2 import *
from tensorflow.core.protobuf.tensorflow_server_pb2 import *
from tensorflow.core.util.event_pb2 import *
# Framework
from tensorflow.python.framework.framework_lib import * # pylint: disable=redefined-builtin
from tensorflow.python.framework.versions import *
from tensorflow.python.framework import config
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
# Session
from tensorflow.python.client.client_lib import *
# Ops
from tensorflow.python.ops.standard_ops import *
# Namespaces
from tensorflow.python.ops import initializers_ns as initializers
from tensorflow.python.eager import context
# pylint: enable=wildcard-import
@ -152,8 +115,8 @@ from tensorflow.python.framework.ops import enable_eager_execution
# Check whether TF2_BEHAVIOR is turned on.
from tensorflow.python.eager import monitoring as _monitoring
from tensorflow.python import tf2 as _tf2
_tf2_gauge = _monitoring.BoolGauge('/tensorflow/api/tf2_enable',
'Environment variable TF2_BEHAVIOR is set".')
_tf2_gauge = _monitoring.BoolGauge(
'/tensorflow/api/tf2_enable', 'Environment variable TF2_BEHAVIOR is set".')
_tf2_gauge.get_cell().set(_tf2.enabled())
# Necessary for the symbols in this module to be taken into account by
@ -186,30 +149,6 @@ nn.bidirectional_dynamic_rnn = rnn.bidirectional_dynamic_rnn
nn.static_state_saving_rnn = rnn.static_state_saving_rnn
nn.rnn_cell = rnn_cell
# Export protos
# pylint: disable=undefined-variable
tf_export(v1=['AttrValue'])(AttrValue)
tf_export(v1=['ConfigProto'])(ConfigProto)
tf_export(v1=['Event', 'summary.Event'])(Event)
tf_export(v1=['GPUOptions'])(GPUOptions)
tf_export(v1=['GraphDef'])(GraphDef)
tf_export(v1=['GraphOptions'])(GraphOptions)
tf_export(v1=['HistogramProto'])(HistogramProto)
tf_export(v1=['LogMessage'])(LogMessage)
tf_export(v1=['MetaGraphDef'])(MetaGraphDef)
tf_export(v1=['NameAttrList'])(NameAttrList)
tf_export(v1=['NodeDef'])(NodeDef)
tf_export(v1=['OptimizerOptions'])(OptimizerOptions)
tf_export(v1=['RunMetadata'])(RunMetadata)
tf_export(v1=['RunOptions'])(RunOptions)
tf_export(v1=['SessionLog', 'summary.SessionLog'])(SessionLog)
tf_export(v1=['Summary', 'summary.Summary'])(Summary)
tf_export(v1=['summary.SummaryDescription'])(SummaryDescription)
tf_export(v1=['SummaryMetadata'])(SummaryMetadata)
tf_export(v1=['summary.TaggedRunMetadata'])(TaggedRunMetadata)
tf_export(v1=['TensorInfo'])(TensorInfo)
# pylint: enable=undefined-variable
# Special dunders that we choose to export:
_exported_dunders = set([
'__version__',

View File

@ -0,0 +1,78 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Imports modules that should be scanned during API generation.

This module is only imported by the API generator (it is listed in the
generator's `packages` argument); it is not imported when users import
tensorflow at runtime. If you add a new file containing @tf_export
decorators, import it here so the generator can discover its symbols.

This file should eventually contain everything we need to scan looking for
tf_export decorators.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import,g-bad-import-order,g-import-not-at-top
# pylint: disable=unused-import,g-importing-member
# Protocol buffers
# The wildcard imports below intentionally pull generated proto message
# classes (GraphDef, AttrValue, Summary, ...) into this module's namespace
# so they can be re-exported with tf_export at the bottom of this file.
from tensorflow.core.framework.graph_pb2 import *
from tensorflow.core.framework.node_def_pb2 import *
from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.core.protobuf.config_pb2 import *
from tensorflow.core.util.event_pb2 import *
# Framework
from tensorflow.python.framework.framework_lib import *  # pylint: disable=redefined-builtin
from tensorflow.python.framework.versions import *
from tensorflow.python.framework import config
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
# Session
from tensorflow.python.client.client_lib import *
# Ops
from tensorflow.python.ops.standard_ops import *
# Namespaces
from tensorflow.python.ops import initializers_ns as initializers
from tensorflow.python.util.tf_export import tf_export
# Export protos
# The names below are defined by the proto wildcard imports above, which is
# why the undefined-variable pylint check is suppressed for this section.
# Each call registers the proto class under the given v1 API name(s).
# pylint: disable=undefined-variable
tf_export(v1=['AttrValue'])(AttrValue)
tf_export(v1=['ConfigProto'])(ConfigProto)
tf_export(v1=['Event', 'summary.Event'])(Event)
tf_export(v1=['GPUOptions'])(GPUOptions)
tf_export(v1=['GraphDef'])(GraphDef)
tf_export(v1=['GraphOptions'])(GraphOptions)
tf_export(v1=['HistogramProto'])(HistogramProto)
tf_export(v1=['LogMessage'])(LogMessage)
tf_export(v1=['MetaGraphDef'])(MetaGraphDef)
tf_export(v1=['NameAttrList'])(NameAttrList)
tf_export(v1=['NodeDef'])(NodeDef)
tf_export(v1=['OptimizerOptions'])(OptimizerOptions)
tf_export(v1=['RunMetadata'])(RunMetadata)
tf_export(v1=['RunOptions'])(RunOptions)
tf_export(v1=['SessionLog', 'summary.SessionLog'])(SessionLog)
tf_export(v1=['Summary', 'summary.Summary'])(Summary)
tf_export(v1=['summary.SummaryDescription'])(SummaryDescription)
tf_export(v1=['SummaryMetadata'])(SummaryMetadata)
tf_export(v1=['summary.TaggedRunMetadata'])(TaggedRunMetadata)
tf_export(v1=['TensorInfo'])(TensorInfo)
# pylint: enable=undefined-variable

View File

@ -42,8 +42,15 @@ def gen_api_init_files(
api_version = 2,
compat_api_versions = [],
compat_init_templates = [],
packages = ["tensorflow.python", "tensorflow.lite.python.lite"],
package_deps = ["//tensorflow/python:no_contrib"],
packages = [
"tensorflow.python",
"tensorflow.lite.python.lite",
"tensorflow.python.modules_with_exports",
],
package_deps = [
"//tensorflow/python:no_contrib",
"//tensorflow/python:modules_with_exports",
],
output_package = "tensorflow",
output_dir = "",
root_file_name = "__init__.py"):

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Generates and prints out imports and constants for new TensorFlow python api.
"""
"""Generates and prints out imports and constants for new TensorFlow python api."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -85,9 +84,9 @@ def get_canonical_import(import_set):
ordering.
Args:
import_set: (set) Imports providing the same symbol. This is a set of
tuples in the form (import, priority). We want to pick an import
with highest priority.
import_set: (set) Imports providing the same symbol. This is a set of tuples
in the form (import, priority). We want to pick an import with highest
priority.
Returns:
A module name to import
@ -106,9 +105,11 @@ def get_canonical_import(import_set):
class _ModuleInitCodeBuilder(object):
"""Builds a map from module name to imports included in that module."""
def __init__(
self, output_package, api_version, lazy_loading=_LAZY_LOADING,
use_relative_imports=False):
def __init__(self,
output_package,
api_version,
lazy_loading=_LAZY_LOADING,
use_relative_imports=False):
self._output_package = output_package
# Maps API module to API symbol name to set of tuples of the form
# (module name, priority).
@ -127,16 +128,13 @@ class _ModuleInitCodeBuilder(object):
def _check_already_imported(self, symbol_id, api_name):
if (api_name in self._dest_import_to_id and
symbol_id != self._dest_import_to_id[api_name] and
symbol_id != -1):
symbol_id != self._dest_import_to_id[api_name] and symbol_id != -1):
raise SymbolExposedTwiceError(
'Trying to export multiple symbols with same name: %s.' %
api_name)
'Trying to export multiple symbols with same name: %s.' % api_name)
self._dest_import_to_id[api_name] = symbol_id
def add_import(
self, symbol, source_module_name, source_name, dest_module_name,
dest_name):
def add_import(self, symbol, source_module_name, source_name,
dest_module_name, dest_name):
"""Adds this import to module_imports.
Args:
@ -150,6 +148,10 @@ class _ModuleInitCodeBuilder(object):
SymbolExposedTwiceError: Raised when an import with the same
dest_name has already been added to dest_module_name.
"""
# modules_with_exports.py is only used during API generation and
# won't be available when actually importing tensorflow.
if source_module_name.endswith('python.modules_with_exports'):
source_module_name = symbol.__module__
import_str = self.format_import(source_module_name, source_name, dest_name)
# Check if we are trying to expose two different symbols with same name.
@ -191,7 +193,7 @@ class _ModuleInitCodeBuilder(object):
for submodule_index in range(len(module_split)):
if submodule_index > 0:
submodule = module_split[submodule_index-1]
submodule = module_split[submodule_index - 1]
parent_module += '.' + submodule if parent_module else submodule
import_from = self._output_package
if self._lazy_loading:
@ -264,8 +266,8 @@ __all__.extend([_s for _s in _names_with_underscore])
if not dest_module.startswith(_COMPAT_MODULE_PREFIX):
deprecation = 'True'
# Workaround to make sure not load lite from lite/__init__.py
if (not dest_module and 'lite' in self._module_imports
and self._lazy_loading):
if (not dest_module and 'lite' in self._module_imports and
self._lazy_loading):
has_lite = 'True'
if self._lazy_loading:
public_apis_name = '_PUBLIC_APIS'
@ -311,8 +313,8 @@ __all__.extend([_s for _s in _names_with_underscore])
self._module_imports[from_dest_module].copy())
def add_nested_compat_imports(
module_builder, compat_api_versions, output_package):
def add_nested_compat_imports(module_builder, compat_api_versions,
output_package):
"""Adds compat.vN.compat.vK modules to module builder.
To avoid circular imports, we want to add __init__.py files under
@ -334,8 +336,8 @@ def add_nested_compat_imports(
subcompat_module = _SUBCOMPAT_MODULE_TEMPLATE % (v, sv)
compat_module = _COMPAT_MODULE_TEMPLATE % sv
module_builder.copy_imports(compat_module, subcompat_module)
module_builder.copy_imports(
'%s.compat' % compat_module, '%s.compat' % subcompat_module)
module_builder.copy_imports('%s.compat' % compat_module,
'%s.compat' % subcompat_module)
# Prefixes of modules under compatibility packages, for e.g. "compat.v1.".
compat_prefixes = tuple(
@ -400,14 +402,13 @@ def _join_modules(module1, module2):
return '%s.%s' % (module1, module2)
def add_imports_for_symbol(
module_code_builder,
symbol,
source_module_name,
source_name,
api_name,
api_version,
output_module_prefix=''):
def add_imports_for_symbol(module_code_builder,
symbol,
source_module_name,
source_name,
api_name,
api_version,
output_module_prefix=''):
"""Add imports for the given symbol to `module_code_builder`.
Args:
@ -432,8 +433,8 @@ def add_imports_for_symbol(
for export in exports:
dest_module, dest_name = _get_name_and_module(export)
dest_module = _join_modules(output_module_prefix, dest_module)
module_code_builder.add_import(
None, source_module_name, name, dest_module, dest_name)
module_code_builder.add_import(None, source_module_name, name,
dest_module, dest_name)
# If symbol has _tf_api_names attribute, then add import for it.
if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):
@ -442,8 +443,8 @@ def add_imports_for_symbol(
for export in getattr(symbol, names_attr): # pylint: disable=protected-access
dest_module, dest_name = _get_name_and_module(export)
dest_module = _join_modules(output_module_prefix, dest_module)
module_code_builder.add_import(
symbol, source_module_name, source_name, dest_module, dest_name)
module_code_builder.add_import(symbol, source_module_name, source_name,
dest_module, dest_name)
def get_api_init_text(packages,
@ -466,8 +467,8 @@ def get_api_init_text(packages,
directory.
lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
produced and if `False`, static imports are used.
use_relative_imports: True if we should use relative imports when
importing submodules.
use_relative_imports: True if we should use relative imports when importing
submodules.
Returns:
A dictionary where
@ -477,8 +478,10 @@ def get_api_init_text(packages,
"""
if compat_api_versions is None:
compat_api_versions = []
module_code_builder = _ModuleInitCodeBuilder(
output_package, api_version, lazy_loading, use_relative_imports)
module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
lazy_loading,
use_relative_imports)
# Traverse over everything imported above. Specifically,
# we want to traverse over TensorFlow Python modules.
@ -496,24 +499,23 @@ def get_api_init_text(packages,
continue
for module_contents_name in dir(module):
if (module.__name__ + '.' + module_contents_name
in _SYMBOLS_TO_SKIP_EXPLICITLY):
if (module.__name__ + '.' +
module_contents_name in _SYMBOLS_TO_SKIP_EXPLICITLY):
continue
attr = getattr(module, module_contents_name)
_, attr = tf_decorator.unwrap(attr)
add_imports_for_symbol(
module_code_builder, attr, module.__name__, module_contents_name,
api_name, api_version)
add_imports_for_symbol(module_code_builder, attr, module.__name__,
module_contents_name, api_name, api_version)
for compat_api_version in compat_api_versions:
add_imports_for_symbol(
module_code_builder, attr, module.__name__, module_contents_name,
api_name, compat_api_version,
_COMPAT_MODULE_TEMPLATE % compat_api_version)
add_imports_for_symbol(module_code_builder, attr, module.__name__,
module_contents_name, api_name,
compat_api_version,
_COMPAT_MODULE_TEMPLATE % compat_api_version)
if compat_api_versions:
add_nested_compat_imports(
module_code_builder, compat_api_versions, output_package)
add_nested_compat_imports(module_code_builder, compat_api_versions,
output_package)
return module_code_builder.build()
@ -545,8 +547,8 @@ def get_module_docstring(module_name, package, api_name):
4. Returns a default docstring.
Args:
module_name: module name relative to tensorflow
(excluding 'tensorflow.' prefix) to get a docstring for.
module_name: module name relative to tensorflow (excluding 'tensorflow.'
prefix) to get a docstring for.
package: Base python package containing python with target tf_export
decorators.
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
@ -581,31 +583,37 @@ def get_module_docstring(module_name, package, api_name):
return 'Public API for tf.%s namespace.' % module_name
def create_api_files(output_files, packages, root_init_template, output_dir,
output_package, api_name, api_version,
compat_api_versions, compat_init_templates,
lazy_loading=_LAZY_LOADING, use_relative_imports=False):
def create_api_files(output_files,
packages,
root_init_template,
output_dir,
output_package,
api_name,
api_version,
compat_api_versions,
compat_init_templates,
lazy_loading=_LAZY_LOADING,
use_relative_imports=False):
"""Creates __init__.py files for the Python API.
Args:
output_files: List of __init__.py file paths to create.
packages: Base python packages containing python with target tf_export
decorators.
root_init_template: Template for top-level __init__.py file.
"# API IMPORTS PLACEHOLDER" comment in the template file will be replaced
with imports.
root_init_template: Template for top-level __init__.py file. "# API IMPORTS
PLACEHOLDER" comment in the template file will be replaced with imports.
output_dir: output API root directory.
output_package: Base output package where generated API will be added.
api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
api_version: API version to generate (`v1` or `v2`).
compat_api_versions: Additional API versions to generate in compat/
subdirectory.
compat_init_templates: List of templates for top level compat init files
in the same order as compat_api_versions.
compat_init_templates: List of templates for top level compat init files in
the same order as compat_api_versions.
lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
produced and if `False`, static imports are used.
use_relative_imports: True if we should use relative imports when
import submodules.
use_relative_imports: True if we should use relative imports when import
submodules.
Raises:
ValueError: if output_files list is missing a required file.
@ -645,8 +653,7 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
for module, text in module_text_map.items():
# Make sure genrule output file list is in sync with API exports.
if module not in module_name_to_file_path:
module_file_path = '"%s/__init__.py"' % (
module.replace('.', '/'))
module_file_path = '"%s/__init__.py"' % (module.replace('.', '/'))
missing_output_files.append(module_file_path)
continue
@ -664,8 +671,9 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
else:
contents = (
_GENERATED_FILE_HEADER % get_module_docstring(
module, packages[0], api_name) + text + _GENERATED_FILE_FOOTER)
_GENERATED_FILE_HEADER %
get_module_docstring(module, packages[0], api_name) + text +
_GENERATED_FILE_FOOTER)
if module in deprecation_footer_map:
if '# WRAPPER_PLACEHOLDER' in contents:
contents = contents.replace('# WRAPPER_PLACEHOLDER',
@ -680,14 +688,17 @@ def create_api_files(output_files, packages, root_init_template, output_dir,
"""Missing outputs for genrule:\n%s. Be sure to add these targets to
tensorflow/python/tools/api/generator/api_init_files_v1.bzl and
tensorflow/python/tools/api/generator/api_init_files.bzl (tensorflow repo), or
tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)"""
% ',\n'.join(sorted(missing_output_files)))
tensorflow_estimator/python/estimator/api/api_gen.bzl (estimator repo)""" %
',\n'.join(sorted(missing_output_files)))
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'outputs', metavar='O', type=str, nargs='+',
'outputs',
metavar='O',
type=str,
nargs='+',
help='If a single file is passed in, then we we assume it contains a '
'semicolon-separated list of Python files that we expect this script to '
'output. If multiple files are passed in, then we assume output files '
@ -699,46 +710,66 @@ def main():
help='Base packages that import modules containing the target tf_export '
'decorators.')
parser.add_argument(
'--root_init_template', default='', type=str,
'--root_init_template',
default='',
type=str,
help='Template for top level __init__.py file. '
'"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
'"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
parser.add_argument(
'--apidir', type=str, required=True,
'--apidir',
type=str,
required=True,
help='Directory where generated output files are placed. '
'gendir should be a prefix of apidir. Also, apidir '
'should be a prefix of every directory in outputs.')
'gendir should be a prefix of apidir. Also, apidir '
'should be a prefix of every directory in outputs.')
parser.add_argument(
'--apiname', required=True, type=str,
'--apiname',
required=True,
type=str,
choices=API_ATTRS.keys(),
help='The API you want to generate.')
parser.add_argument(
'--apiversion', default=2, type=int,
'--apiversion',
default=2,
type=int,
choices=_API_VERSIONS,
help='The API version you want to generate.')
parser.add_argument(
'--compat_apiversions', default=[], type=int, action='append',
'--compat_apiversions',
default=[],
type=int,
action='append',
help='Additional versions to generate in compat/ subdirectory. '
'If set to 0, then no additional version would be generated.')
'If set to 0, then no additional version would be generated.')
parser.add_argument(
'--compat_init_templates', default=[], type=str, action='append',
'--compat_init_templates',
default=[],
type=str,
action='append',
help='Templates for top-level __init__ files under compat modules. '
'The list of init file templates must be in the same order as '
'list of versions passed with compat_apiversions.')
'The list of init file templates must be in the same order as '
'list of versions passed with compat_apiversions.')
parser.add_argument(
'--output_package', default='tensorflow', type=str,
'--output_package',
default='tensorflow',
type=str,
help='Root output package.')
parser.add_argument(
'--loading', default='default', type=str,
'--loading',
default='default',
type=str,
choices=['lazy', 'static', 'default'],
help='Controls how the generated __init__.py file loads the exported '
'symbols. \'lazy\' means the symbols are loaded when first used. '
'\'static\' means all exported symbols are loaded in the '
'__init__.py file. \'default\' uses the value of the '
'_LAZY_LOADING constant in create_python_api.py.')
'symbols. \'lazy\' means the symbols are loaded when first used. '
'\'static\' means all exported symbols are loaded in the '
'__init__.py file. \'default\' uses the value of the '
'_LAZY_LOADING constant in create_python_api.py.')
parser.add_argument(
'--use_relative_imports', default=False, type=bool,
'--use_relative_imports',
default=False,
type=bool,
help='Whether to import submodules using relative imports or absolute '
'imports')
'imports')
args = parser.parse_args()
if len(args.outputs) == 1:

View File

@ -34,7 +34,6 @@ import re
import sys
import six
from six.moves import range
import tensorflow as tf
from google.protobuf import message
@ -87,9 +86,8 @@ _API_GOLDEN_FOLDER_V2 = None
def _InitPathConstants():
global _API_GOLDEN_FOLDER_V1
global _API_GOLDEN_FOLDER_V2
root_golden_path_v2 = os.path.join(
resource_loader.get_data_files_path(), '..', 'golden', 'v2',
'tensorflow.pbtxt')
root_golden_path_v2 = os.path.join(resource_loader.get_data_files_path(),
'..', 'golden', 'v2', 'tensorflow.pbtxt')
if FLAGS.update_goldens:
root_golden_path_v2 = os.path.realpath(root_golden_path_v2)
@ -106,7 +104,6 @@ _UPDATE_WARNING_FILE = resource_loader.get_path_to_datafile(
_NON_CORE_PACKAGES = ['estimator']
# TODO(annarev): remove this once we test with newer version of
# estimator that actually has compat v1 version.
if not hasattr(tf.compat.v1, 'estimator'):
@ -282,9 +279,6 @@ class ApiCompatibilityTest(test.TestCase):
diff_count = len(diffs)
logging.error(self._test_readme_message)
logging.error('%d differences found between API and golden.', diff_count)
messages = verbose_diffs if verbose else diffs
for i in range(diff_count):
print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr)
if update_goldens:
# Write files if requested.
@ -394,11 +388,12 @@ class ApiCompatibilityTest(test.TestCase):
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
omit_golden_symbols_map = {}
if (api_version == 2 and FLAGS.only_test_core_api
and not _TENSORBOARD_AVAILABLE):
if (api_version == 2 and FLAGS.only_test_core_api and
not _TENSORBOARD_AVAILABLE):
# In TF 2.0 these summary symbols are imported from TensorBoard.
omit_golden_symbols_map['tensorflow.summary'] = [
'audio', 'histogram', 'image', 'scalar', 'text']
'audio', 'histogram', 'image', 'scalar', 'text'
]
self._checkBackwardsCompatibility(
tf,
@ -418,7 +413,9 @@ class ApiCompatibilityTest(test.TestCase):
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
tf.compat.v1, golden_file_pattern, api_version,
tf.compat.v1,
golden_file_pattern,
api_version,
additional_private_map={
'tf': ['pywrap_tensorflow'],
'tf.compat': ['v1', 'v2'],
@ -434,7 +431,8 @@ class ApiCompatibilityTest(test.TestCase):
if FLAGS.only_test_core_api and not _TENSORBOARD_AVAILABLE:
# In TF 2.0 these summary symbols are imported from TensorBoard.
omit_golden_symbols_map['tensorflow.summary'] = [
'audio', 'histogram', 'image', 'scalar', 'text']
'audio', 'histogram', 'image', 'scalar', 'text'
]
self._checkBackwardsCompatibility(
tf.compat.v2,
golden_file_pattern,

View File

@ -46,9 +46,16 @@ def _traverse_internal(root, visit, stack, path):
children.append(enum_member)
children = sorted(children)
except ImportError:
# On some Python installations, some modules do not support enumerating
# members (six in particular), leading to import errors.
children = []
# Children could be missing for one of two reasons:
# 1. On some Python installations, some modules do not support enumerating
# members (six in particular), leading to import errors.
# 2. Children are lazy-loaded.
try:
children = []
for child_name in root.__all__:
children.append((child_name, getattr(root, child_name)))
except AttributeError:
children = []
new_stack = stack + [root]
visit(path, root, children)