Extending the default PyModule_Type as a FastModuleType and use it for TFModuleWrapper. This extended type has improved attribute lookup mechanism to shorten tf api lookup speed.

PiperOrigin-RevId: 351187183
Change-Id: I02fe01dfff4f0669c7f5fbf0c1e7f3641a891e6d
This commit is contained in:
A. Unique TensorFlower 2021-01-11 10:42:25 -08:00 committed by TensorFlower Gardener
parent 8a7c07385c
commit a17e1fc7ef
6 changed files with 465 additions and 44 deletions

View File

@ -24,6 +24,7 @@ py_library(
deps = [
"//tensorflow/python:util",
"//tensorflow/python/tools/api/generator:doc_srcs",
"//tensorflow/python/util:fast_module_type",
],
)

View File

@ -575,12 +575,35 @@ tf_py_test(
],
)
tf_python_pybind_extension(
name = "fast_module_type",
srcs = ["fast_module_type.cc"],
module_name = "fast_module_type",
deps = [
"//tensorflow/core/platform:logging",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/container:flat_hash_map",
"@pybind11",
],
)
tf_py_test(
name = "fast_module_type_test",
srcs = ["fast_module_type_test.py"],
python_version = "PY3",
deps = [
":fast_module_type",
"//tensorflow/python:platform",
],
)
tf_py_test(
name = "module_wrapper_test",
size = "small",
srcs = ["module_wrapper_test.py"],
python_version = "PY3",
deps = [
":fast_module_type",
":util",
"//tensorflow/python:client_testlib",
"//tensorflow/tools/compatibility:all_renames_v2",

View File

@ -0,0 +1,292 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <Python.h>
#include "absl/container/flat_hash_map.h"
#include "pybind11/pybind11.h"
#include "tensorflow/core/platform/logging.h"
namespace py = pybind11;
constexpr int PY_MODULE_TYPE_TP_BASIC_SIZE = 56;
struct FastModuleObject {
// A dummy array that ensures enough size is reserved for FastModuleObject,
// because it's inherited from PyModuleObject.
const std::array<char, PY_MODULE_TYPE_TP_BASIC_SIZE> opaque_base_fields;
// A cache that helps reduce attribute lookup overhead.
absl::flat_hash_map<PyObject *, PyObject *> attr_map;
// pointer to the external getattribute function
PyObject *cb_getattribute = nullptr;
// pointer to the external getattr function
PyObject *cb_getattr = nullptr;
// static PyTypeObject type;
FastModuleObject() = delete;
~FastModuleObject() = delete;
static FastModuleObject *UncheckedCast(PyObject *obj);
};
static int FastModule_init(FastModuleObject *self, PyObject *args,
PyObject *kwds) {
DCHECK_EQ(PY_MODULE_TYPE_TP_BASIC_SIZE, PyModule_Type.tp_basicsize);
if (PyModule_Type.tp_init(reinterpret_cast<PyObject *>(self), args, kwds) < 0)
return -1;
new (&(self->attr_map)) absl::flat_hash_map<PyObject *, PyObject *>();
return 0;
}
// Parses the input as a callable and checks the result.
static PyObject *ParseFunc(PyObject *args) {
PyObject *func;
if (!PyArg_ParseTuple(args, "O:set_callback", &func)) return nullptr;
if (!PyCallable_Check(func)) {
PyErr_SetString(PyExc_TypeError, "input args must be callable");
return nullptr;
}
Py_INCREF(func); // Add a reference to new callback
return func;
}
// Sets the pointer 'cb_getattribute' in the FastModuleObject object
// corresponding to 'self'.
static PyObject *SetGetattributeCallback(PyObject *self, PyObject *args) {
PyObject *func = ParseFunc(args);
// Dispose of previous callback
Py_XDECREF(FastModuleObject::UncheckedCast(self)->cb_getattribute);
// Remember new callback
FastModuleObject::UncheckedCast(self)->cb_getattribute = func;
Py_RETURN_NONE;
}
// Sets the pointer 'cb_getattr' in the FastModuleObject object
// corresponding to 'self'.
static PyObject *SetGetattrCallback(PyObject *self, PyObject *args) {
PyObject *func = ParseFunc(args);
// Dispose of previous callback
Py_XDECREF(FastModuleObject::UncheckedCast(self)->cb_getattr);
// Remember new callback
FastModuleObject::UncheckedCast(self)->cb_getattr = func;
Py_RETURN_NONE;
}
// Inserts or updates a key-value pair in the cache 'attr_map'
// of the FastModuleObject object corresponding to 'self'.
static PyObject *FastDictInsert(FastModuleObject *self, PyObject *args) {
PyObject *name, *value;
if (!PyArg_ParseTuple(args, "OO", &name, &value)) {
PyErr_SetString(PyExc_TypeError, "_fastdict_insert: incorrect inputs");
return nullptr;
}
auto &attr_map = self->attr_map;
if (attr_map.find(name) != attr_map.end()) {
Py_DECREF(name);
Py_DECREF(value);
}
attr_map.insert_or_assign(name, value);
// Increment the reference count
Py_INCREF(name);
Py_INCREF(value);
// Properly handle returning Py_None
Py_RETURN_NONE;
}
// Gets a value from a key in the cache 'attr_map'
// of the FastModuleObject object corresponding to 'self'.
static PyObject *FastDictGet(FastModuleObject *self, PyObject *args) {
PyObject *name;
if (!PyArg_ParseTuple(args, "O", &name)) {
PyErr_SetString(PyExc_TypeError, "_fastdict_get: incorrect inputs");
return nullptr;
}
auto &attr_map = self->attr_map;
auto result = attr_map.find(name);
if (result != attr_map.end()) {
PyObject *value = result->second;
Py_INCREF(value);
return value;
}
// Copied from CPython's moduleobject.c
PyErr_Format(PyExc_KeyError, "module has no attribute '%U'", name);
return nullptr;
}
// Returns true if a key exists in the cache 'attr_map'
// of the FastModuleObject object corresponding to 'self',
// otherwise returns false.
static PyObject *FastDictContains(FastModuleObject *self, PyObject *args) {
PyObject *name;
if (!PyArg_ParseTuple(args, "O", &name)) {
PyErr_SetString(PyExc_TypeError, "_fastdict_key_in: incorrect inputs");
return nullptr;
}
const auto &attr_map = self->attr_map;
const auto result = attr_map.contains(name);
if (result) {
// Properly handle returning Py_True
Py_RETURN_TRUE;
}
// Properly handle returning Py_False
Py_RETURN_FALSE;
}
// Calls a function 'func' with inputs 'self' and 'args'.
static PyObject *CallFunc(FastModuleObject *self, PyObject *args,
PyObject *func) {
if (func == nullptr) {
PyErr_SetString(PyExc_NameError,
"Attempting to call a callback that was not defined");
return nullptr;
}
PyObject *name;
if (!PyArg_ParseTuple(args, "O", &name)) {
PyErr_SetString(PyExc_TypeError, "CallFunc: incorrect inputs");
return nullptr;
}
PyObject *arglist = Py_BuildValue("(OO)", self, name);
auto result = PyObject_CallObject(func, arglist);
Py_DECREF(arglist);
return result;
}
static PyMethodDef FastModule_methods[] = {
{"_fastdict_insert", reinterpret_cast<PyCFunction>(FastDictInsert),
METH_VARARGS, "Registers a method to the fast lookup table."},
{"_fastdict_get", reinterpret_cast<PyCFunction>(FastDictGet), METH_VARARGS,
"Gets a method from the fast lookup table."},
{"_fastdict_key_in", reinterpret_cast<PyCFunction>(FastDictContains),
METH_VARARGS, "Checks if a method exists in the fast lookup table."},
{"set_getattribute_callback", SetGetattributeCallback, METH_VARARGS,
"Defines the callback function to replace __getattribute__"},
{"set_getattr_callback", SetGetattrCallback, METH_VARARGS,
"Defines the callback function to replace __getattr__"},
{nullptr, nullptr, 0, nullptr},
};
// Attempts to get the attribute based on 'name' as the key in cache 'attr_map'
// of the FastModuleObject object corresponding to 'module'.
// If the lookup fails in the cache, either uses
// a user-defined callback 'cb_getattribute'
// or the default 'tp_getattro' function to look for the attribute.
static PyObject *FastTpGetattro(PyObject *module, PyObject *name) {
FastModuleObject *fast_module = FastModuleObject::UncheckedCast(module);
auto &attr_map = fast_module->attr_map;
auto it = attr_map.find(name);
// If the attribute lookup is successful in the cache, directly return it.
if (it != attr_map.end()) {
PyObject *value = it->second;
Py_INCREF(value);
return value;
}
PyObject *arglist = Py_BuildValue("(O)", name);
PyObject *result;
// Prefer the customized callback function over the default function.
if (fast_module->cb_getattribute != nullptr) {
result = CallFunc(fast_module, arglist, fast_module->cb_getattribute);
} else {
result = PyModule_Type.tp_getattro(module, name);
}
// Return result if it's found
if (result != nullptr) {
return result;
}
// If the default lookup fails and an AttributeError is raised,
// clear the error status before using the __getattr__ callback function.
auto is_error = PyErr_Occurred();
if (is_error && PyErr_ExceptionMatches(PyExc_AttributeError) &&
fast_module->cb_getattr != nullptr) {
PyErr_Clear();
return CallFunc(fast_module, arglist, fast_module->cb_getattr);
}
// If all options were used up
return result;
}
// Customized destructor for FastModuleType.tp_dealloc
// In addition to default behavior it also clears up the contents in attr_map.
static void FastModuleObjectDealloc(PyObject *module) {
auto &attr_map = FastModuleObject::UncheckedCast(module)->attr_map;
for (auto &it : attr_map) {
Py_DECREF(it.first);
Py_DECREF(it.second);
}
attr_map.~flat_hash_map<PyObject *, PyObject *>();
Py_TYPE(module)->tp_free(module);
}
static PyTypeObject FastModuleType = []() {
PyTypeObject obj = {PyVarObject_HEAD_INIT(&PyType_Type, 0)};
obj.tp_name = "fast_module_type.FastModuleType";
obj.tp_basicsize = sizeof(FastModuleObject);
obj.tp_itemsize = 0;
obj.tp_dealloc = FastModuleObjectDealloc;
obj.tp_getattro = FastTpGetattro;
obj.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
obj.tp_doc = "FastModuleType objects";
obj.tp_methods = FastModule_methods;
obj.tp_init = reinterpret_cast<initproc>(FastModule_init);
return obj;
}();
// Returns true if the type of 'obj' or any of its parent class
// is equal to 'target'. Otherwise returns false.
bool IsAnyBaseSameType(const PyObject *obj, const PyTypeObject *target) {
auto *tp = Py_TYPE(obj);
while (true) {
if (tp == target) return true;
// If the default type is found, there is no need to search further
if (tp == &PyBaseObject_Type) break;
tp = tp->tp_base;
}
return false;
}
// Casts 'obj' to 'FastModuleObject *'.
// Conducts a check only in non-optimized builds.
FastModuleObject *FastModuleObject::UncheckedCast(PyObject *obj) {
DCHECK(IsAnyBaseSameType(obj, &FastModuleType));
return reinterpret_cast<FastModuleObject *>(obj);
}
PYBIND11_MODULE(fast_module_type, m) {
FastModuleType.tp_base = &PyModule_Type;
FastModuleType.tp_setattro = [](PyObject *module, PyObject *name,
PyObject *value) -> int {
auto &attr_map = FastModuleObject::UncheckedCast(module)->attr_map;
if (attr_map.find(name) != attr_map.end()) {
Py_DECREF(name);
Py_DECREF(value);
}
attr_map.insert_or_assign(name, value);
// Increment the reference count
Py_INCREF(name);
Py_INCREF(value);
PyObject_GenericSetAttr(module, name, value);
return 0;
};
m.doc() = R"pbdoc(
fast_module_type
-----
)pbdoc";
// Use getter function to hold attributes rather than pybind11's m.attr due to
// b/145559202.
m.def(
"get_fast_module_type_class",
[]() {
return py::cast<py::object>(
reinterpret_cast<PyObject *>(&FastModuleType));
},
py::return_value_policy::reference);
}

View File

@ -0,0 +1,71 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.util.fast_module_type."""
from tensorflow.python.platform import test
from tensorflow.python.util import fast_module_type
FastModuleType = fast_module_type.get_fast_module_type_class()
class ChildFastModule(FastModuleType):
def _getattribute1(self, name): # pylint: disable=unused-argument
return 2
def _getattribute2(self, name): # pylint: disable=unused-argument
raise AttributeError("Pass to getattr")
def _getattr(self, name): # pylint: disable=unused-argument
return 3
class FastModuleTypeTest(test.TestCase):
def testBaseGetattribute(self):
# Tests that the default attribute lookup works.
module = ChildFastModule("test")
module.foo = 1
self.assertEqual(1, module.foo)
def testGetattributeCallback(self):
# Tests that functionality of __getattribute__ can be set as a callback.
module = ChildFastModule("test")
FastModuleType.set_getattribute_callback(module,
ChildFastModule._getattribute1)
self.assertEqual(2, module.foo)
def testGetattrCallback(self):
# Tests that functionality of __getattr__ can be set as a callback.
module = ChildFastModule("test")
FastModuleType.set_getattribute_callback(module,
ChildFastModule._getattribute2)
FastModuleType.set_getattr_callback(module, ChildFastModule._getattr)
self.assertEqual(3, module.foo)
def testFastdictApis(self):
module = ChildFastModule("test")
# At first "bar" does not exist in the module's attributes
self.assertFalse(module._fastdict_key_in("bar"))
with self.assertRaisesRegex(KeyError, "module has no attribute 'bar'"):
module._fastdict_get("bar")
module._fastdict_insert("bar", 1)
# After _fastdict_insert() the attribute is added.
self.assertTrue(module._fastdict_key_in("bar"))
self.assertEqual(1, module.bar)
if __name__ == "__main__":
test.main()

View File

@ -1,4 +1,4 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,14 +20,14 @@ from __future__ import print_function
import importlib
import inspect
import types
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import fast_module_type
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.tools.compatibility import all_renames_v2
FastModuleType = fast_module_type.get_fast_module_type_class()
_PER_MODULE_WARNING_LIMIT = 1
@ -50,8 +50,7 @@ def _call_location():
def contains_deprecation_decorator(decorators):
return any(
d.decorator_name == 'deprecated' for d in decorators)
return any(d.decorator_name == 'deprecated' for d in decorators)
def has_deprecation_decorator(symbol):
@ -79,7 +78,7 @@ def has_deprecation_decorator(symbol):
return contains_deprecation_decorator(init_decorators)
class TFModuleWrapper(types.ModuleType):
class TFModuleWrapper(FastModuleType):
"""Wrapper for TF modules to support deprecation messages and lazyloading."""
def __init__( # pylint: disable=super-on-old-class
@ -90,8 +89,9 @@ class TFModuleWrapper(types.ModuleType):
deprecation=True,
has_lite=False): # pylint: enable=super-on-old-class
super(TFModuleWrapper, self).__init__(wrapped.__name__)
# A cache for all members which do not print deprecations (any more).
self._tfmw_attr_map = {}
FastModuleType.set_getattr_callback(self, TFModuleWrapper._getattr)
FastModuleType.set_getattribute_callback(self,
TFModuleWrapper._getattribute)
self.__dict__.update(wrapped.__dict__)
# Prefix all local attributes with _tfmw_ so that we can
# handle them differently in attribute access methods.
@ -142,6 +142,7 @@ class TFModuleWrapper(types.ModuleType):
return False
def _tfmw_import_module(self, name):
"""Lazily loading the modules."""
symbol_loc_info = self._tfmw_public_apis[name]
if symbol_loc_info[0]:
module = importlib.import_module(symbol_loc_info[0])
@ -150,51 +151,67 @@ class TFModuleWrapper(types.ModuleType):
attr = importlib.import_module(symbol_loc_info[1])
setattr(self._tfmw_wrapped_module, name, attr)
self.__dict__[name] = attr
# Cache the pair
self._fastdict_insert(name, attr)
return attr
def __getattribute__(self, name): # pylint: disable=super-on-old-class
# Handle edge case where we unpickle and the object is not initialized yet
# and does not have _tfmw_attr_map attribute. Otherwise, calling
# __getattribute__ on __setstate__ will result in infinite recursion where
# we keep trying to get _tfmw_wrapped_module in __getattr__.
try:
attr_map = object.__getattribute__(self, '_tfmw_attr_map')
except AttributeError:
self._tfmw_attr_map = attr_map = {}
def _getattribute(self, name):
# pylint: disable=g-doc-return-or-yield,g-doc-args
"""Imports and caches pre-defined API.
try:
# Use cached attrs if available
return attr_map[name]
except KeyError:
# Make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
attr_map[name] = attr
return attr
Warns if necessary.
# Placeholder for Google-internal contrib error
This method is a replacement for __getattribute__(). It will be added into
the extended python module as a callback to reduce API overhead.
"""
# Avoid infinite recursions
func__fastdict_insert = object.__getattribute__(self, '_fastdict_insert')
attr = super(TFModuleWrapper, self).__getattribute__(name)
# Return and cache dunders and our own members.
if name.startswith('__') or name.startswith('_tfmw_'):
attr_map[name] = attr
# Make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
func__fastdict_insert(name, attr)
return attr
# Placeholder for Google-internal contrib error
# Print deprecations, only cache functions after deprecation warnings have
# stopped.
if not (self._tfmw_print_deprecation_warnings and
self._tfmw_add_deprecation_warning(name, attr)):
attr_map[name] = attr
attr = object.__getattribute__(self, name)
# Return and cache dunders and our own members.
# This is necessary to guarantee successful construction.
# In addition, all the accessed attributes used during the construction must
# begin with "__" or "_tfmw" or "_fastdict_".
if name.startswith('__') or name.startswith('_tfmw_') or name.startswith(
'_fastdict_'):
func__fastdict_insert(name, attr)
return attr
def __getattr__(self, name):
# Print deprecations, only cache functions after deprecation warnings have
# stopped.
if not (self._tfmw_print_deprecation_warnings and
self._tfmw_add_deprecation_warning(name, attr)):
func__fastdict_insert(name, attr)
return attr
def _getattr(self, name):
# pylint: disable=g-doc-return-or-yield,g-doc-args
"""Imports and caches pre-defined API.
Warns if necessary.
This method is a replacement for __getattr__(). It will be added into the
extended python module as a callback to reduce API overhead. Instead of
relying on implicit AttributeError handling, this added callback function
will
be called explicitly from the extended C API if the default attribute lookup
fails.
"""
try:
attr = getattr(self._tfmw_wrapped_module, name)
except AttributeError:
# Placeholder for Google-internal contrib error
# Placeholder for Google-internal contrib error
if not self._tfmw_public_apis:
raise
@ -212,8 +229,9 @@ class TFModuleWrapper(types.ModuleType):
self.__dict__[arg] = val
if arg not in self.__all__ and arg != '__all__':
self.__all__.append(arg)
if arg in self._tfmw_attr_map:
self._tfmw_attr_map[arg] = val
# Update the cache
if self._fastdict_key_in(arg):
self._fastdict_insert(arg, val)
super(TFModuleWrapper, self).__setattr__(arg, val)
def __dir__(self):

View File

@ -1,4 +1,4 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -77,8 +77,13 @@ class LazyLoadingWrapperTest(test.TestCase):
module, 'test', public_apis=apis, deprecation=False)
import cmd as _cmd # pylint: disable=g-import-not-at-top
from abc import ABCMeta as _ABCMeta # pylint: disable=g-import-not-at-top, g-importing-member
self.assertFalse(wrapped_module._fastdict_key_in('cmd'))
self.assertEqual(wrapped_module.cmd, _cmd)
# Verify that the APIs are added to the cache of FastModuleType object
self.assertTrue(wrapped_module._fastdict_key_in('cmd'))
self.assertFalse(wrapped_module._fastdict_key_in('ABCMeta'))
self.assertEqual(wrapped_module.ABCMeta, _ABCMeta)
self.assertTrue(wrapped_module._fastdict_key_in('ABCMeta'))
def testLazyLoadLocalOverride(self):
# Test that we can override and add fields to the wrapped module.
@ -91,7 +96,11 @@ class LazyLoadingWrapperTest(test.TestCase):
setattr(wrapped_module, 'cmd', 1)
setattr(wrapped_module, 'cgi', 2)
self.assertEqual(wrapped_module.cmd, 1) # override
# Verify that the values are also updated in the cache
# of the FastModuleType object
self.assertEqual(wrapped_module._fastdict_get('cmd'), 1)
self.assertEqual(wrapped_module.cgi, 2) # add
self.assertEqual(wrapped_module._fastdict_get('cgi'), 2)
def testLazyLoadDict(self):
# Test that we can override and add fields to the wrapped module.
@ -131,6 +140,13 @@ class LazyLoadingWrapperTest(test.TestCase):
module, 'test', public_apis=apis, deprecation=False, has_lite=True)
self.assertEqual(wrapped_module.lite, _cmd)
def testInitCachesAttributes(self):
module = MockModule('test')
wrapped_module = module_wrapper.TFModuleWrapper(module, 'test')
self.assertTrue(wrapped_module._fastdict_key_in('_fastdict_key_in'))
self.assertTrue(wrapped_module._fastdict_key_in('_tfmw_module_name'))
self.assertTrue(wrapped_module._fastdict_key_in('__all__'))
class PickleTest(test.TestCase):