Cache results of attr lookup in TFModuleWrapper. This should avoid lots of roundtrips through __getattribute__ and speed up (especially) eager execution.

PiperOrigin-RevId: 266943412
This commit is contained in:
Martin Wicke 2019-09-03 09:12:58 -07:00 committed by TensorFlower Gardener
parent c2dbce7f9e
commit 2ccea0c7f8

View File

@ -90,6 +90,8 @@ class TFModuleWrapper(types.ModuleType):
deprecation=True,
has_lite=False): # pylint: enable=super-on-old-class
super(TFModuleWrapper, self).__init__(wrapped.__name__)
# A cache for all members which do not print deprecations (any more).
self._tfmw_attr_map = {}
self.__dict__.update(wrapped.__dict__)
# Prefix all local attributes with _tfmw_ so that we can
# handle them differently in attribute access methods.
@ -136,6 +138,8 @@ class TFModuleWrapper(types.ModuleType):
'From %s: The name %s is deprecated. Please use %s instead.\n',
_call_location(), full_name, rename)
self._tfmw_warning_count += 1
return True
return False
def _tfmw_import_module(self, name):
symbol_loc_info = self._tfmw_public_apis[name]
@ -149,25 +153,37 @@ class TFModuleWrapper(types.ModuleType):
return attr
def __getattribute__(self, name): # pylint: disable=super-on-old-class
# Workaround to make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
attr_map = object.__getattribute__(self, '_tfmw_attr_map')
try:
# Use cached attrs if available
return attr_map[name]
except KeyError:
# Make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
attr_map[name] = attr
return attr
attr = super(TFModuleWrapper, self).__getattribute__(name)
# Return and cache dunders and our own members.
if name.startswith('__') or name.startswith('_tfmw_'):
attr_map[name] = attr
return attr
attr = super(TFModuleWrapper, self).__getattribute__(name)
if name.startswith('__') or name.startswith('_tfmw_'):
# Print deprecations, only cache functions after deprecation warnings have
# stopped.
if not (self._tfmw_print_deprecation_warnings and
self._tfmw_add_deprecation_warning(name, attr)):
attr_map[name] = attr
return attr
if self._tfmw_print_deprecation_warnings:
self._tfmw_add_deprecation_warning(name, attr)
return attr
def __getattr__(self, name):
try:
attr = getattr(self._tfmw_wrapped_module, name)
except AttributeError as e:
except AttributeError:
if not self._tfmw_public_apis:
raise
if name not in self._tfmw_public_apis:
@ -184,6 +200,8 @@ class TFModuleWrapper(types.ModuleType):
self.__dict__[arg] = val
if arg not in self.__all__ and arg != '__all__':
self.__all__.append(arg)
if arg in self._tfmw_attr_map:
self._tfmw_attr_map[arg] = val
super(TFModuleWrapper, self).__setattr__(arg, val)
def __dir__(self):