Generalize LazyLoader for use by ffmpeg
Add __dir__ method so the docs generator doesn't need to do anything special to activate the loading Change: 153583515
This commit is contained in:
parent
8b6c3c8e88
commit
95c5d7e880
@ -24,19 +24,9 @@ from __future__ import print_function
|
||||
from tensorflow.python import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
# Lazily import the `tf.contrib` module. This avoids loading all of the
|
||||
# dependencies of `tf.contrib` at `import tensorflow` time.
|
||||
class _LazyContribLoader(object):
|
||||
|
||||
def __getattr__(self, item):
|
||||
global contrib
|
||||
# Replace the lazy loader with the imported module itself.
|
||||
import importlib # pylint: disable=g-import-not-at-top
|
||||
contrib = importlib.import_module('tensorflow.contrib')
|
||||
return getattr(contrib, item)
|
||||
|
||||
|
||||
contrib = _LazyContribLoader()
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
|
||||
del LazyLoader
|
||||
|
||||
del absolute_import
|
||||
del division
|
||||
|
@ -67,6 +67,11 @@ from tensorflow.contrib import util
|
||||
from tensorflow.contrib.ndlstm import python as ndlstm
|
||||
from tensorflow.contrib.specs import python as specs
|
||||
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
ffmpeg = LazyLoader("ffmpeg", globals(),
|
||||
"tensorflow.contrib.ffmpeg")
|
||||
del LazyLoader
|
||||
|
||||
del absolute_import
|
||||
del division
|
||||
del print_function
|
||||
|
58
tensorflow/python/util/lazy_loader.py
Normal file
58
tensorflow/python/util/lazy_loader.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A LazyLoader class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import types
|
||||
|
||||
|
||||
class LazyLoader(types.ModuleType):
|
||||
"""Lazily import a module, mainly to avoid pulling in large dependancies.
|
||||
|
||||
`contrib`, and `ffmpeg` are examples of modules that are large and not always
|
||||
needed, and this allows them to only be loaded when they are used.
|
||||
"""
|
||||
|
||||
# The lint error here is incorrect.
|
||||
def __init__(self, local_name, parent_module_globals, name): # pylint: disable=super-on-old-class
|
||||
self._local_name = local_name
|
||||
self._parent_module_globals = parent_module_globals
|
||||
|
||||
super(LazyLoader, self).__init__(name)
|
||||
|
||||
def _load(self):
|
||||
# Import the target module and insert it into the parent's namespace
|
||||
module = importlib.import_module(self.__name__)
|
||||
self._parent_module_globals[self._local_name] = module
|
||||
|
||||
# Update this object's dict so that if someone keeps a reference to the
|
||||
# LazyLoader, lookups are efficient (__getattr__ is only called on lookups
|
||||
# that fail).
|
||||
self.__dict__.update(module.__dict__)
|
||||
|
||||
return module
|
||||
|
||||
def __getattr__(self, item):
|
||||
module = self._load()
|
||||
return getattr(module, item)
|
||||
|
||||
def __dir__(self):
|
||||
module = self._load()
|
||||
return dir(module)
|
@ -278,7 +278,7 @@ tf_module {
|
||||
}
|
||||
member {
|
||||
name: "contrib"
|
||||
mtype: "<class \'_LazyContribLoader\'>"
|
||||
mtype: "<class \'tensorflow.python.util.lazy_loader.LazyLoader\'>"
|
||||
}
|
||||
member {
|
||||
name: "double"
|
||||
|
@ -27,9 +27,7 @@ from tensorflow.tools.api.lib import api_objects_pb2
|
||||
|
||||
# Following object need to be handled individually.
|
||||
_CORNER_CASES = {
|
||||
'': {'contrib': {'name': 'contrib',
|
||||
'mtype': '<class \'_LazyContribLoader\'>'},
|
||||
'tools': {}},
|
||||
'': {'tools': {}},
|
||||
'test.TestCase': {},
|
||||
'test.TestCase.failureException': {},
|
||||
}
|
||||
|
@ -196,7 +196,11 @@ class ApiCompatibilityTest(test.TestCase):
|
||||
def testAPIBackwardsCompatibility(self):
|
||||
# Extract all API stuff.
|
||||
visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
|
||||
traverse.traverse(tf, public_api.PublicAPIVisitor(visitor))
|
||||
|
||||
public_api_visitor = public_api.PublicAPIVisitor(visitor)
|
||||
public_api_visitor.do_not_descend_map[''].append('contrib')
|
||||
traverse.traverse(tf, public_api_visitor)
|
||||
|
||||
proto_dict = visitor.GetProtos()
|
||||
|
||||
# Read all golden files.
|
||||
|
@ -36,29 +36,35 @@ class PublicAPIVisitor(object):
|
||||
"""
|
||||
self._visitor = visitor
|
||||
|
||||
# Modules/classes we do not want to descend into if we hit them. Usually,
|
||||
# sytem modules exposed through platforms for compatibility reasons.
|
||||
# Each entry maps a module path to a name to ignore in traversal.
|
||||
_do_not_descend_map = {
|
||||
'': [
|
||||
'core',
|
||||
'examples',
|
||||
'flags', # Don't add flags
|
||||
'platform', # TODO(drpng): This can be removed once sealed off.
|
||||
'pywrap_tensorflow', # TODO(drpng): This can be removed once sealed.
|
||||
'user_ops', # TODO(drpng): This can be removed once sealed.
|
||||
'python',
|
||||
'tools',
|
||||
'tensorboard',
|
||||
],
|
||||
# Modules/classes we do not want to descend into if we hit them. Usually,
|
||||
# sytem modules exposed through platforms for compatibility reasons.
|
||||
# Each entry maps a module path to a name to ignore in traversal.
|
||||
self._do_not_descend_map = {
|
||||
'': [
|
||||
'core',
|
||||
'examples',
|
||||
'flags', # Don't add flags
|
||||
# TODO(drpng): This can be removed once sealed off.
|
||||
'platform',
|
||||
# TODO(drpng): This can be removed once sealed.
|
||||
'pywrap_tensorflow',
|
||||
# TODO(drpng): This can be removed once sealed.
|
||||
'user_ops',
|
||||
'python',
|
||||
'tools',
|
||||
'tensorboard',
|
||||
],
|
||||
|
||||
# Some implementations have this internal module that we shouldn't expose.
|
||||
'flags': ['cpp_flags'],
|
||||
# Some implementations have this internal module that we shouldn't
|
||||
# expose.
|
||||
'flags': ['cpp_flags'],
|
||||
|
||||
# Everything below here is legitimate.
|
||||
'app': ['flags'], # It'll stay, but it's not officially part of the API.
|
||||
'test': ['mock'], # Imported for compatibility between py2/3.
|
||||
}
|
||||
## Everything below here is legitimate.
|
||||
# It'll stay, but it's not officially part of the API.
|
||||
'app': ['flags'],
|
||||
# Imported for compatibility between py2/3.
|
||||
'test': ['mock'],
|
||||
}
|
||||
|
||||
@property
|
||||
def do_not_descend_map(self):
|
||||
|
@ -40,7 +40,6 @@ class BuildDocsTest(googletest.TestCase):
|
||||
doc_generator = generate_lib.DocGenerator()
|
||||
|
||||
doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)])
|
||||
doc_generator.load_contrib()
|
||||
|
||||
status = doc_generator.build(Flags())
|
||||
|
||||
|
@ -46,6 +46,5 @@ if __name__ == '__main__':
|
||||
|
||||
# tf_debug is not imported with tf, it's a separate module altogether
|
||||
doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)])
|
||||
doc_generator.load_contrib()
|
||||
|
||||
sys.exit(doc_generator.build(flags))
|
||||
|
@ -47,7 +47,6 @@ if __name__ == '__main__':
|
||||
# tf_debug is not imported with tf, it's a separate module altogether
|
||||
doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)])
|
||||
|
||||
doc_generator.load_contrib()
|
||||
doc_generator.set_do_not_descend_map({
|
||||
'': ['cli', 'lib', 'wrappers'],
|
||||
'contrib': [
|
||||
|
@ -439,19 +439,6 @@ class DocGenerator(object):
|
||||
def set_py_modules(self, py_modules):
|
||||
self._py_modules = py_modules
|
||||
|
||||
def load_contrib(self):
|
||||
"""Access something in contrib so tf.contrib is properly loaded."""
|
||||
# Without this, it ends up hidden behind lazy loading. Requires
|
||||
# that the caller has already called set_py_modules().
|
||||
if self._py_modules is None:
|
||||
raise RuntimeError(
|
||||
'Must call set_py_modules() before running load_contrib().')
|
||||
for name, module in self._py_modules:
|
||||
if name == 'tf':
|
||||
_ = module.contrib.__name__
|
||||
return True
|
||||
return False
|
||||
|
||||
def py_module_names(self):
|
||||
if self._py_modules is None:
|
||||
raise RuntimeError(
|
||||
|
@ -56,7 +56,7 @@ class GenerateTest(googletest.TestCase):
|
||||
|
||||
def test_extraction(self):
|
||||
py_modules = [('tf', tf), ('tfdbg', tf_debug)]
|
||||
_ = tf.contrib.__name__ # Trigger loading of tf.contrib
|
||||
|
||||
try:
|
||||
generate_lib.extract(
|
||||
py_modules, generate_lib._get_default_do_not_descend_map())
|
||||
|
Loading…
Reference in New Issue
Block a user