Cache results of attr lookup in TFModuleWrapper. This should avoid lots of roundtrips through __getattribute__ and speed up (especially) eager execution.
PiperOrigin-RevId: 266943412
This commit is contained in:
parent
c2dbce7f9e
commit
2ccea0c7f8
@ -90,6 +90,8 @@ class TFModuleWrapper(types.ModuleType):
|
||||
deprecation=True,
|
||||
has_lite=False): # pylint: enable=super-on-old-class
|
||||
super(TFModuleWrapper, self).__init__(wrapped.__name__)
|
||||
# A cache for all members which do not print deprecations (any more).
|
||||
self._tfmw_attr_map = {}
|
||||
self.__dict__.update(wrapped.__dict__)
|
||||
# Prefix all local attributes with _tfmw_ so that we can
|
||||
# handle them differently in attribute access methods.
|
||||
@ -136,6 +138,8 @@ class TFModuleWrapper(types.ModuleType):
|
||||
'From %s: The name %s is deprecated. Please use %s instead.\n',
|
||||
_call_location(), full_name, rename)
|
||||
self._tfmw_warning_count += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
def _tfmw_import_module(self, name):
|
||||
symbol_loc_info = self._tfmw_public_apis[name]
|
||||
@ -149,25 +153,37 @@ class TFModuleWrapper(types.ModuleType):
|
||||
return attr
|
||||
|
||||
def __getattribute__(self, name): # pylint: disable=super-on-old-class
|
||||
# Workaround to make sure we do not import from tensorflow/lite/__init__.py
|
||||
if name == 'lite':
|
||||
if self._tfmw_has_lite:
|
||||
attr = self._tfmw_import_module(name)
|
||||
setattr(self._tfmw_wrapped_module, 'lite', attr)
|
||||
attr_map = object.__getattribute__(self, '_tfmw_attr_map')
|
||||
try:
|
||||
# Use cached attrs if available
|
||||
return attr_map[name]
|
||||
except KeyError:
|
||||
# Make sure we do not import from tensorflow/lite/__init__.py
|
||||
if name == 'lite':
|
||||
if self._tfmw_has_lite:
|
||||
attr = self._tfmw_import_module(name)
|
||||
setattr(self._tfmw_wrapped_module, 'lite', attr)
|
||||
attr_map[name] = attr
|
||||
return attr
|
||||
|
||||
attr = super(TFModuleWrapper, self).__getattribute__(name)
|
||||
|
||||
# Return and cache dunders and our own members.
|
||||
if name.startswith('__') or name.startswith('_tfmw_'):
|
||||
attr_map[name] = attr
|
||||
return attr
|
||||
|
||||
attr = super(TFModuleWrapper, self).__getattribute__(name)
|
||||
if name.startswith('__') or name.startswith('_tfmw_'):
|
||||
# Print deprecations, only cache functions after deprecation warnings have
|
||||
# stopped.
|
||||
if not (self._tfmw_print_deprecation_warnings and
|
||||
self._tfmw_add_deprecation_warning(name, attr)):
|
||||
attr_map[name] = attr
|
||||
return attr
|
||||
|
||||
if self._tfmw_print_deprecation_warnings:
|
||||
self._tfmw_add_deprecation_warning(name, attr)
|
||||
return attr
|
||||
|
||||
def __getattr__(self, name):
|
||||
try:
|
||||
attr = getattr(self._tfmw_wrapped_module, name)
|
||||
except AttributeError as e:
|
||||
except AttributeError:
|
||||
if not self._tfmw_public_apis:
|
||||
raise
|
||||
if name not in self._tfmw_public_apis:
|
||||
@ -184,6 +200,8 @@ class TFModuleWrapper(types.ModuleType):
|
||||
self.__dict__[arg] = val
|
||||
if arg not in self.__all__ and arg != '__all__':
|
||||
self.__all__.append(arg)
|
||||
if arg in self._tfmw_attr_map:
|
||||
self._tfmw_attr_map[arg] = val
|
||||
super(TFModuleWrapper, self).__setattr__(arg, val)
|
||||
|
||||
def __dir__(self):
|
||||
|
Loading…
Reference in New Issue
Block a user