Extending the default PyModule_Type as a FastModuleType and use it for TFModuleWrapper. This extended type has improved attribute lookup mechanism to shorten tf api lookup speed.
PiperOrigin-RevId: 351187183 Change-Id: I02fe01dfff4f0669c7f5fbf0c1e7f3641a891e6d
This commit is contained in:
parent
8a7c07385c
commit
a17e1fc7ef
tensorflow/python
tools/api/generator
util
@ -24,6 +24,7 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/tools/api/generator:doc_srcs",
|
||||
"//tensorflow/python/util:fast_module_type",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -575,12 +575,35 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "fast_module_type",
|
||||
srcs = ["fast_module_type.cc"],
|
||||
module_name = "fast_module_type",
|
||||
deps = [
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "fast_module_type_test",
|
||||
srcs = ["fast_module_type_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":fast_module_type",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "module_wrapper_test",
|
||||
size = "small",
|
||||
srcs = ["module_wrapper_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":fast_module_type",
|
||||
":util",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/tools/compatibility:all_renames_v2",
|
||||
|
292
tensorflow/python/util/fast_module_type.cc
Normal file
292
tensorflow/python/util/fast_module_type.cc
Normal file
@ -0,0 +1,292 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <Python.h>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
constexpr int PY_MODULE_TYPE_TP_BASIC_SIZE = 56;
|
||||
|
||||
struct FastModuleObject {
|
||||
// A dummy array that ensures enough size is reserved for FastModuleObject,
|
||||
// because it's inherited from PyModuleObject.
|
||||
const std::array<char, PY_MODULE_TYPE_TP_BASIC_SIZE> opaque_base_fields;
|
||||
// A cache that helps reduce attribute lookup overhead.
|
||||
absl::flat_hash_map<PyObject *, PyObject *> attr_map;
|
||||
// pointer to the external getattribute function
|
||||
PyObject *cb_getattribute = nullptr;
|
||||
// pointer to the external getattr function
|
||||
PyObject *cb_getattr = nullptr;
|
||||
// static PyTypeObject type;
|
||||
|
||||
FastModuleObject() = delete;
|
||||
~FastModuleObject() = delete;
|
||||
static FastModuleObject *UncheckedCast(PyObject *obj);
|
||||
};
|
||||
|
||||
static int FastModule_init(FastModuleObject *self, PyObject *args,
|
||||
PyObject *kwds) {
|
||||
DCHECK_EQ(PY_MODULE_TYPE_TP_BASIC_SIZE, PyModule_Type.tp_basicsize);
|
||||
if (PyModule_Type.tp_init(reinterpret_cast<PyObject *>(self), args, kwds) < 0)
|
||||
return -1;
|
||||
new (&(self->attr_map)) absl::flat_hash_map<PyObject *, PyObject *>();
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Parses the input as a callable and checks the result.
|
||||
static PyObject *ParseFunc(PyObject *args) {
|
||||
PyObject *func;
|
||||
if (!PyArg_ParseTuple(args, "O:set_callback", &func)) return nullptr;
|
||||
if (!PyCallable_Check(func)) {
|
||||
PyErr_SetString(PyExc_TypeError, "input args must be callable");
|
||||
return nullptr;
|
||||
}
|
||||
Py_INCREF(func); // Add a reference to new callback
|
||||
return func;
|
||||
}
|
||||
|
||||
// Sets the pointer 'cb_getattribute' in the FastModuleObject object
|
||||
// corresponding to 'self'.
|
||||
static PyObject *SetGetattributeCallback(PyObject *self, PyObject *args) {
|
||||
PyObject *func = ParseFunc(args);
|
||||
// Dispose of previous callback
|
||||
Py_XDECREF(FastModuleObject::UncheckedCast(self)->cb_getattribute);
|
||||
// Remember new callback
|
||||
FastModuleObject::UncheckedCast(self)->cb_getattribute = func;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
// Sets the pointer 'cb_getattr' in the FastModuleObject object
|
||||
// corresponding to 'self'.
|
||||
static PyObject *SetGetattrCallback(PyObject *self, PyObject *args) {
|
||||
PyObject *func = ParseFunc(args);
|
||||
// Dispose of previous callback
|
||||
Py_XDECREF(FastModuleObject::UncheckedCast(self)->cb_getattr);
|
||||
// Remember new callback
|
||||
FastModuleObject::UncheckedCast(self)->cb_getattr = func;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
// Inserts or updates a key-value pair in the cache 'attr_map'
|
||||
// of the FastModuleObject object corresponding to 'self'.
|
||||
static PyObject *FastDictInsert(FastModuleObject *self, PyObject *args) {
|
||||
PyObject *name, *value;
|
||||
if (!PyArg_ParseTuple(args, "OO", &name, &value)) {
|
||||
PyErr_SetString(PyExc_TypeError, "_fastdict_insert: incorrect inputs");
|
||||
return nullptr;
|
||||
}
|
||||
auto &attr_map = self->attr_map;
|
||||
if (attr_map.find(name) != attr_map.end()) {
|
||||
Py_DECREF(name);
|
||||
Py_DECREF(value);
|
||||
}
|
||||
attr_map.insert_or_assign(name, value);
|
||||
// Increment the reference count
|
||||
Py_INCREF(name);
|
||||
Py_INCREF(value);
|
||||
// Properly handle returning Py_None
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
// Gets a value from a key in the cache 'attr_map'
|
||||
// of the FastModuleObject object corresponding to 'self'.
|
||||
static PyObject *FastDictGet(FastModuleObject *self, PyObject *args) {
|
||||
PyObject *name;
|
||||
if (!PyArg_ParseTuple(args, "O", &name)) {
|
||||
PyErr_SetString(PyExc_TypeError, "_fastdict_get: incorrect inputs");
|
||||
return nullptr;
|
||||
}
|
||||
auto &attr_map = self->attr_map;
|
||||
auto result = attr_map.find(name);
|
||||
if (result != attr_map.end()) {
|
||||
PyObject *value = result->second;
|
||||
Py_INCREF(value);
|
||||
return value;
|
||||
}
|
||||
// Copied from CPython's moduleobject.c
|
||||
PyErr_Format(PyExc_KeyError, "module has no attribute '%U'", name);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Returns true if a key exists in the cache 'attr_map'
|
||||
// of the FastModuleObject object corresponding to 'self',
|
||||
// otherwise returns false.
|
||||
static PyObject *FastDictContains(FastModuleObject *self, PyObject *args) {
|
||||
PyObject *name;
|
||||
if (!PyArg_ParseTuple(args, "O", &name)) {
|
||||
PyErr_SetString(PyExc_TypeError, "_fastdict_key_in: incorrect inputs");
|
||||
return nullptr;
|
||||
}
|
||||
const auto &attr_map = self->attr_map;
|
||||
const auto result = attr_map.contains(name);
|
||||
if (result) {
|
||||
// Properly handle returning Py_True
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
// Properly handle returning Py_False
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
|
||||
// Calls a function 'func' with inputs 'self' and 'args'.
|
||||
static PyObject *CallFunc(FastModuleObject *self, PyObject *args,
|
||||
PyObject *func) {
|
||||
if (func == nullptr) {
|
||||
PyErr_SetString(PyExc_NameError,
|
||||
"Attempting to call a callback that was not defined");
|
||||
return nullptr;
|
||||
}
|
||||
PyObject *name;
|
||||
if (!PyArg_ParseTuple(args, "O", &name)) {
|
||||
PyErr_SetString(PyExc_TypeError, "CallFunc: incorrect inputs");
|
||||
return nullptr;
|
||||
}
|
||||
PyObject *arglist = Py_BuildValue("(OO)", self, name);
|
||||
auto result = PyObject_CallObject(func, arglist);
|
||||
Py_DECREF(arglist);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyMethodDef FastModule_methods[] = {
|
||||
{"_fastdict_insert", reinterpret_cast<PyCFunction>(FastDictInsert),
|
||||
METH_VARARGS, "Registers a method to the fast lookup table."},
|
||||
{"_fastdict_get", reinterpret_cast<PyCFunction>(FastDictGet), METH_VARARGS,
|
||||
"Gets a method from the fast lookup table."},
|
||||
{"_fastdict_key_in", reinterpret_cast<PyCFunction>(FastDictContains),
|
||||
METH_VARARGS, "Checks if a method exists in the fast lookup table."},
|
||||
{"set_getattribute_callback", SetGetattributeCallback, METH_VARARGS,
|
||||
"Defines the callback function to replace __getattribute__"},
|
||||
{"set_getattr_callback", SetGetattrCallback, METH_VARARGS,
|
||||
"Defines the callback function to replace __getattr__"},
|
||||
{nullptr, nullptr, 0, nullptr},
|
||||
};
|
||||
|
||||
// Attempts to get the attribute based on 'name' as the key in cache 'attr_map'
|
||||
// of the FastModuleObject object corresponding to 'module'.
|
||||
// If the lookup fails in the cache, either uses
|
||||
// a user-defined callback 'cb_getattribute'
|
||||
// or the default 'tp_getattro' function to look for the attribute.
|
||||
static PyObject *FastTpGetattro(PyObject *module, PyObject *name) {
|
||||
FastModuleObject *fast_module = FastModuleObject::UncheckedCast(module);
|
||||
auto &attr_map = fast_module->attr_map;
|
||||
auto it = attr_map.find(name);
|
||||
// If the attribute lookup is successful in the cache, directly return it.
|
||||
if (it != attr_map.end()) {
|
||||
PyObject *value = it->second;
|
||||
Py_INCREF(value);
|
||||
return value;
|
||||
}
|
||||
PyObject *arglist = Py_BuildValue("(O)", name);
|
||||
PyObject *result;
|
||||
// Prefer the customized callback function over the default function.
|
||||
if (fast_module->cb_getattribute != nullptr) {
|
||||
result = CallFunc(fast_module, arglist, fast_module->cb_getattribute);
|
||||
} else {
|
||||
result = PyModule_Type.tp_getattro(module, name);
|
||||
}
|
||||
// Return result if it's found
|
||||
if (result != nullptr) {
|
||||
return result;
|
||||
}
|
||||
// If the default lookup fails and an AttributeError is raised,
|
||||
// clear the error status before using the __getattr__ callback function.
|
||||
auto is_error = PyErr_Occurred();
|
||||
if (is_error && PyErr_ExceptionMatches(PyExc_AttributeError) &&
|
||||
fast_module->cb_getattr != nullptr) {
|
||||
PyErr_Clear();
|
||||
return CallFunc(fast_module, arglist, fast_module->cb_getattr);
|
||||
}
|
||||
// If all options were used up
|
||||
return result;
|
||||
}
|
||||
|
||||
// Customized destructor for FastModuleType.tp_dealloc
|
||||
// In addition to default behavior it also clears up the contents in attr_map.
|
||||
static void FastModuleObjectDealloc(PyObject *module) {
|
||||
auto &attr_map = FastModuleObject::UncheckedCast(module)->attr_map;
|
||||
for (auto &it : attr_map) {
|
||||
Py_DECREF(it.first);
|
||||
Py_DECREF(it.second);
|
||||
}
|
||||
attr_map.~flat_hash_map<PyObject *, PyObject *>();
|
||||
Py_TYPE(module)->tp_free(module);
|
||||
}
|
||||
|
||||
static PyTypeObject FastModuleType = []() {
|
||||
PyTypeObject obj = {PyVarObject_HEAD_INIT(&PyType_Type, 0)};
|
||||
obj.tp_name = "fast_module_type.FastModuleType";
|
||||
obj.tp_basicsize = sizeof(FastModuleObject);
|
||||
obj.tp_itemsize = 0;
|
||||
obj.tp_dealloc = FastModuleObjectDealloc;
|
||||
obj.tp_getattro = FastTpGetattro;
|
||||
obj.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
|
||||
obj.tp_doc = "FastModuleType objects";
|
||||
obj.tp_methods = FastModule_methods;
|
||||
obj.tp_init = reinterpret_cast<initproc>(FastModule_init);
|
||||
return obj;
|
||||
}();
|
||||
|
||||
// Returns true if the type of 'obj' or any of its parent class
|
||||
// is equal to 'target'. Otherwise returns false.
|
||||
bool IsAnyBaseSameType(const PyObject *obj, const PyTypeObject *target) {
|
||||
auto *tp = Py_TYPE(obj);
|
||||
while (true) {
|
||||
if (tp == target) return true;
|
||||
// If the default type is found, there is no need to search further
|
||||
if (tp == &PyBaseObject_Type) break;
|
||||
tp = tp->tp_base;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Casts 'obj' to 'FastModuleObject *'.
|
||||
// Conducts a check only in non-optimized builds.
|
||||
FastModuleObject *FastModuleObject::UncheckedCast(PyObject *obj) {
|
||||
DCHECK(IsAnyBaseSameType(obj, &FastModuleType));
|
||||
return reinterpret_cast<FastModuleObject *>(obj);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(fast_module_type, m) {
|
||||
FastModuleType.tp_base = &PyModule_Type;
|
||||
FastModuleType.tp_setattro = [](PyObject *module, PyObject *name,
|
||||
PyObject *value) -> int {
|
||||
auto &attr_map = FastModuleObject::UncheckedCast(module)->attr_map;
|
||||
if (attr_map.find(name) != attr_map.end()) {
|
||||
Py_DECREF(name);
|
||||
Py_DECREF(value);
|
||||
}
|
||||
attr_map.insert_or_assign(name, value);
|
||||
// Increment the reference count
|
||||
Py_INCREF(name);
|
||||
Py_INCREF(value);
|
||||
PyObject_GenericSetAttr(module, name, value);
|
||||
return 0;
|
||||
};
|
||||
|
||||
m.doc() = R"pbdoc(
|
||||
fast_module_type
|
||||
-----
|
||||
)pbdoc";
|
||||
// Use getter function to hold attributes rather than pybind11's m.attr due to
|
||||
// b/145559202.
|
||||
m.def(
|
||||
"get_fast_module_type_class",
|
||||
[]() {
|
||||
return py::cast<py::object>(
|
||||
reinterpret_cast<PyObject *>(&FastModuleType));
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
}
|
71
tensorflow/python/util/fast_module_type_test.py
Normal file
71
tensorflow/python/util/fast_module_type_test.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.util.fast_module_type."""
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import fast_module_type
|
||||
FastModuleType = fast_module_type.get_fast_module_type_class()
|
||||
|
||||
|
||||
class ChildFastModule(FastModuleType):
|
||||
|
||||
def _getattribute1(self, name): # pylint: disable=unused-argument
|
||||
return 2
|
||||
|
||||
def _getattribute2(self, name): # pylint: disable=unused-argument
|
||||
raise AttributeError("Pass to getattr")
|
||||
|
||||
def _getattr(self, name): # pylint: disable=unused-argument
|
||||
return 3
|
||||
|
||||
|
||||
class FastModuleTypeTest(test.TestCase):
|
||||
|
||||
def testBaseGetattribute(self):
|
||||
# Tests that the default attribute lookup works.
|
||||
module = ChildFastModule("test")
|
||||
module.foo = 1
|
||||
self.assertEqual(1, module.foo)
|
||||
|
||||
def testGetattributeCallback(self):
|
||||
# Tests that functionality of __getattribute__ can be set as a callback.
|
||||
module = ChildFastModule("test")
|
||||
FastModuleType.set_getattribute_callback(module,
|
||||
ChildFastModule._getattribute1)
|
||||
self.assertEqual(2, module.foo)
|
||||
|
||||
def testGetattrCallback(self):
|
||||
# Tests that functionality of __getattr__ can be set as a callback.
|
||||
module = ChildFastModule("test")
|
||||
FastModuleType.set_getattribute_callback(module,
|
||||
ChildFastModule._getattribute2)
|
||||
FastModuleType.set_getattr_callback(module, ChildFastModule._getattr)
|
||||
self.assertEqual(3, module.foo)
|
||||
|
||||
def testFastdictApis(self):
|
||||
module = ChildFastModule("test")
|
||||
# At first "bar" does not exist in the module's attributes
|
||||
self.assertFalse(module._fastdict_key_in("bar"))
|
||||
with self.assertRaisesRegex(KeyError, "module has no attribute 'bar'"):
|
||||
module._fastdict_get("bar")
|
||||
|
||||
module._fastdict_insert("bar", 1)
|
||||
# After _fastdict_insert() the attribute is added.
|
||||
self.assertTrue(module._fastdict_key_in("bar"))
|
||||
self.assertEqual(1, module.bar)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -20,14 +20,14 @@ from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import types
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import fast_module_type
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.tools.compatibility import all_renames_v2
|
||||
|
||||
|
||||
FastModuleType = fast_module_type.get_fast_module_type_class()
|
||||
_PER_MODULE_WARNING_LIMIT = 1
|
||||
|
||||
|
||||
@ -50,8 +50,7 @@ def _call_location():
|
||||
|
||||
|
||||
def contains_deprecation_decorator(decorators):
|
||||
return any(
|
||||
d.decorator_name == 'deprecated' for d in decorators)
|
||||
return any(d.decorator_name == 'deprecated' for d in decorators)
|
||||
|
||||
|
||||
def has_deprecation_decorator(symbol):
|
||||
@ -79,7 +78,7 @@ def has_deprecation_decorator(symbol):
|
||||
return contains_deprecation_decorator(init_decorators)
|
||||
|
||||
|
||||
class TFModuleWrapper(types.ModuleType):
|
||||
class TFModuleWrapper(FastModuleType):
|
||||
"""Wrapper for TF modules to support deprecation messages and lazyloading."""
|
||||
|
||||
def __init__( # pylint: disable=super-on-old-class
|
||||
@ -90,8 +89,9 @@ class TFModuleWrapper(types.ModuleType):
|
||||
deprecation=True,
|
||||
has_lite=False): # pylint: enable=super-on-old-class
|
||||
super(TFModuleWrapper, self).__init__(wrapped.__name__)
|
||||
# A cache for all members which do not print deprecations (any more).
|
||||
self._tfmw_attr_map = {}
|
||||
FastModuleType.set_getattr_callback(self, TFModuleWrapper._getattr)
|
||||
FastModuleType.set_getattribute_callback(self,
|
||||
TFModuleWrapper._getattribute)
|
||||
self.__dict__.update(wrapped.__dict__)
|
||||
# Prefix all local attributes with _tfmw_ so that we can
|
||||
# handle them differently in attribute access methods.
|
||||
@ -142,6 +142,7 @@ class TFModuleWrapper(types.ModuleType):
|
||||
return False
|
||||
|
||||
def _tfmw_import_module(self, name):
|
||||
"""Lazily loading the modules."""
|
||||
symbol_loc_info = self._tfmw_public_apis[name]
|
||||
if symbol_loc_info[0]:
|
||||
module = importlib.import_module(symbol_loc_info[0])
|
||||
@ -150,51 +151,67 @@ class TFModuleWrapper(types.ModuleType):
|
||||
attr = importlib.import_module(symbol_loc_info[1])
|
||||
setattr(self._tfmw_wrapped_module, name, attr)
|
||||
self.__dict__[name] = attr
|
||||
# Cache the pair
|
||||
self._fastdict_insert(name, attr)
|
||||
return attr
|
||||
|
||||
def __getattribute__(self, name): # pylint: disable=super-on-old-class
|
||||
# Handle edge case where we unpickle and the object is not initialized yet
|
||||
# and does not have _tfmw_attr_map attribute. Otherwise, calling
|
||||
# __getattribute__ on __setstate__ will result in infinite recursion where
|
||||
# we keep trying to get _tfmw_wrapped_module in __getattr__.
|
||||
try:
|
||||
attr_map = object.__getattribute__(self, '_tfmw_attr_map')
|
||||
except AttributeError:
|
||||
self._tfmw_attr_map = attr_map = {}
|
||||
def _getattribute(self, name):
|
||||
# pylint: disable=g-doc-return-or-yield,g-doc-args
|
||||
"""Imports and caches pre-defined API.
|
||||
|
||||
try:
|
||||
# Use cached attrs if available
|
||||
return attr_map[name]
|
||||
except KeyError:
|
||||
# Make sure we do not import from tensorflow/lite/__init__.py
|
||||
if name == 'lite':
|
||||
if self._tfmw_has_lite:
|
||||
attr = self._tfmw_import_module(name)
|
||||
setattr(self._tfmw_wrapped_module, 'lite', attr)
|
||||
attr_map[name] = attr
|
||||
return attr
|
||||
Warns if necessary.
|
||||
|
||||
# Placeholder for Google-internal contrib error
|
||||
This method is a replacement for __getattribute__(). It will be added into
|
||||
the extended python module as a callback to reduce API overhead.
|
||||
"""
|
||||
# Avoid infinite recursions
|
||||
func__fastdict_insert = object.__getattribute__(self, '_fastdict_insert')
|
||||
|
||||
attr = super(TFModuleWrapper, self).__getattribute__(name)
|
||||
|
||||
# Return and cache dunders and our own members.
|
||||
if name.startswith('__') or name.startswith('_tfmw_'):
|
||||
attr_map[name] = attr
|
||||
# Make sure we do not import from tensorflow/lite/__init__.py
|
||||
if name == 'lite':
|
||||
if self._tfmw_has_lite:
|
||||
attr = self._tfmw_import_module(name)
|
||||
setattr(self._tfmw_wrapped_module, 'lite', attr)
|
||||
func__fastdict_insert(name, attr)
|
||||
return attr
|
||||
# Placeholder for Google-internal contrib error
|
||||
|
||||
# Print deprecations, only cache functions after deprecation warnings have
|
||||
# stopped.
|
||||
if not (self._tfmw_print_deprecation_warnings and
|
||||
self._tfmw_add_deprecation_warning(name, attr)):
|
||||
attr_map[name] = attr
|
||||
attr = object.__getattribute__(self, name)
|
||||
|
||||
# Return and cache dunders and our own members.
|
||||
# This is necessary to guarantee successful construction.
|
||||
# In addition, all the accessed attributes used during the construction must
|
||||
# begin with "__" or "_tfmw" or "_fastdict_".
|
||||
if name.startswith('__') or name.startswith('_tfmw_') or name.startswith(
|
||||
'_fastdict_'):
|
||||
func__fastdict_insert(name, attr)
|
||||
return attr
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Print deprecations, only cache functions after deprecation warnings have
|
||||
# stopped.
|
||||
if not (self._tfmw_print_deprecation_warnings and
|
||||
self._tfmw_add_deprecation_warning(name, attr)):
|
||||
func__fastdict_insert(name, attr)
|
||||
|
||||
return attr
|
||||
|
||||
def _getattr(self, name):
|
||||
# pylint: disable=g-doc-return-or-yield,g-doc-args
|
||||
"""Imports and caches pre-defined API.
|
||||
|
||||
Warns if necessary.
|
||||
|
||||
This method is a replacement for __getattr__(). It will be added into the
|
||||
extended python module as a callback to reduce API overhead. Instead of
|
||||
relying on implicit AttributeError handling, this added callback function
|
||||
will
|
||||
be called explicitly from the extended C API if the default attribute lookup
|
||||
fails.
|
||||
"""
|
||||
try:
|
||||
attr = getattr(self._tfmw_wrapped_module, name)
|
||||
except AttributeError:
|
||||
# Placeholder for Google-internal contrib error
|
||||
# Placeholder for Google-internal contrib error
|
||||
|
||||
if not self._tfmw_public_apis:
|
||||
raise
|
||||
@ -212,8 +229,9 @@ class TFModuleWrapper(types.ModuleType):
|
||||
self.__dict__[arg] = val
|
||||
if arg not in self.__all__ and arg != '__all__':
|
||||
self.__all__.append(arg)
|
||||
if arg in self._tfmw_attr_map:
|
||||
self._tfmw_attr_map[arg] = val
|
||||
# Update the cache
|
||||
if self._fastdict_key_in(arg):
|
||||
self._fastdict_insert(arg, val)
|
||||
super(TFModuleWrapper, self).__setattr__(arg, val)
|
||||
|
||||
def __dir__(self):
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -77,8 +77,13 @@ class LazyLoadingWrapperTest(test.TestCase):
|
||||
module, 'test', public_apis=apis, deprecation=False)
|
||||
import cmd as _cmd # pylint: disable=g-import-not-at-top
|
||||
from abc import ABCMeta as _ABCMeta # pylint: disable=g-import-not-at-top, g-importing-member
|
||||
self.assertFalse(wrapped_module._fastdict_key_in('cmd'))
|
||||
self.assertEqual(wrapped_module.cmd, _cmd)
|
||||
# Verify that the APIs are added to the cache of FastModuleType object
|
||||
self.assertTrue(wrapped_module._fastdict_key_in('cmd'))
|
||||
self.assertFalse(wrapped_module._fastdict_key_in('ABCMeta'))
|
||||
self.assertEqual(wrapped_module.ABCMeta, _ABCMeta)
|
||||
self.assertTrue(wrapped_module._fastdict_key_in('ABCMeta'))
|
||||
|
||||
def testLazyLoadLocalOverride(self):
|
||||
# Test that we can override and add fields to the wrapped module.
|
||||
@ -91,7 +96,11 @@ class LazyLoadingWrapperTest(test.TestCase):
|
||||
setattr(wrapped_module, 'cmd', 1)
|
||||
setattr(wrapped_module, 'cgi', 2)
|
||||
self.assertEqual(wrapped_module.cmd, 1) # override
|
||||
# Verify that the values are also updated in the cache
|
||||
# of the FastModuleType object
|
||||
self.assertEqual(wrapped_module._fastdict_get('cmd'), 1)
|
||||
self.assertEqual(wrapped_module.cgi, 2) # add
|
||||
self.assertEqual(wrapped_module._fastdict_get('cgi'), 2)
|
||||
|
||||
def testLazyLoadDict(self):
|
||||
# Test that we can override and add fields to the wrapped module.
|
||||
@ -131,6 +140,13 @@ class LazyLoadingWrapperTest(test.TestCase):
|
||||
module, 'test', public_apis=apis, deprecation=False, has_lite=True)
|
||||
self.assertEqual(wrapped_module.lite, _cmd)
|
||||
|
||||
def testInitCachesAttributes(self):
|
||||
module = MockModule('test')
|
||||
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
|
||||
self.assertTrue(wrapped_module._fastdict_key_in('_fastdict_key_in'))
|
||||
self.assertTrue(wrapped_module._fastdict_key_in('_tfmw_module_name'))
|
||||
self.assertTrue(wrapped_module._fastdict_key_in('__all__'))
|
||||
|
||||
|
||||
class PickleTest(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user