diff --git a/tensorflow/python/util/module_wrapper.py b/tensorflow/python/util/module_wrapper.py index 6207a393d60..da23e33d0ed 100644 --- a/tensorflow/python/util/module_wrapper.py +++ b/tensorflow/python/util/module_wrapper.py @@ -90,6 +90,8 @@ class TFModuleWrapper(types.ModuleType): deprecation=True, has_lite=False): # pylint: enable=super-on-old-class super(TFModuleWrapper, self).__init__(wrapped.__name__) + # A cache for all members which do not print deprecations (any more). + self._tfmw_attr_map = {} self.__dict__.update(wrapped.__dict__) # Prefix all local attributes with _tfmw_ so that we can # handle them differently in attribute access methods. @@ -136,6 +138,8 @@ class TFModuleWrapper(types.ModuleType): 'From %s: The name %s is deprecated. Please use %s instead.\n', _call_location(), full_name, rename) self._tfmw_warning_count += 1 + return True + return False def _tfmw_import_module(self, name): symbol_loc_info = self._tfmw_public_apis[name] @@ -149,25 +153,37 @@ class TFModuleWrapper(types.ModuleType): return attr def __getattribute__(self, name): # pylint: disable=super-on-old-class - # Workaround to make sure we do not import from tensorflow/lite/__init__.py - if name == 'lite': - if self._tfmw_has_lite: - attr = self._tfmw_import_module(name) - setattr(self._tfmw_wrapped_module, 'lite', attr) + attr_map = object.__getattribute__(self, '_tfmw_attr_map') + try: + # Use cached attrs if available + return attr_map[name] + except KeyError: + # Make sure we do not import from tensorflow/lite/__init__.py + if name == 'lite': + if self._tfmw_has_lite: + attr = self._tfmw_import_module(name) + setattr(self._tfmw_wrapped_module, 'lite', attr) + attr_map[name] = attr + return attr + + attr = super(TFModuleWrapper, self).__getattribute__(name) + + # Return and cache dunders and our own members. + if name.startswith('__') or name.startswith('_tfmw_'): + attr_map[name] = attr return attr - attr = super(TFModuleWrapper, self).__getattribute__(name) - if name.startswith('__') or name.startswith('_tfmw_'): + # Print deprecations, only cache functions after deprecation warnings have + # stopped. + if not (self._tfmw_print_deprecation_warnings and + self._tfmw_add_deprecation_warning(name, attr)): + attr_map[name] = attr return attr - if self._tfmw_print_deprecation_warnings: - self._tfmw_add_deprecation_warning(name, attr) - return attr - def __getattr__(self, name): try: attr = getattr(self._tfmw_wrapped_module, name) - except AttributeError as e: + except AttributeError: if not self._tfmw_public_apis: raise if name not in self._tfmw_public_apis: @@ -184,6 +200,8 @@ class TFModuleWrapper(types.ModuleType): self.__dict__[arg] = val if arg not in self.__all__ and arg != '__all__': self.__all__.append(arg) + if arg in self._tfmw_attr_map: + self._tfmw_attr_map[arg] = val super(TFModuleWrapper, self).__setattr__(arg, val) def __dir__(self):