Reduce Layer.__call__ overhead by ~15%.
Performance improvements to how Variable autocasting is enabled:
- Uses a class-based context manager to enable autocasting rather than a
  contextlib.contextmanager wrapper (class-based is faster).
- Skips property accesses.
- Fewer reads of expensive thread-local objects.

PiperOrigin-RevId: 314377779
Change-Id: Ia1cca0b994d3d3745407e1b568cc46e4b6cafd9e
parent 895709e1a8
commit d7132b0c79
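The speed claim in the first bullet is easy to check in isolation: a context manager written as a plain class avoids the generator machinery that contextlib.contextmanager sets up on every `with` statement. A stand-alone micro-benchmark sketch (not part of the commit; all names are illustrative):

import contextlib
import timeit


class ClassBasedCM(object):
  """Plain class-based context manager: enter/exit are direct method calls."""

  def __enter__(self):
    return None

  def __exit__(self, exc_type, exc_value, traceback):
    return None


@contextlib.contextmanager
def generator_based_cm():
  # contextlib allocates a wrapper object and drives a generator on every
  # use, which costs more than two plain method calls.
  yield


def use_class():
  with ClassBasedCM():
    pass


def use_generator():
  with generator_based_cm():
    pass


print("class-based:    ", timeit.timeit(use_class, number=100000))
print("generator-based:", timeit.timeit(use_generator, number=100000))

On CPython the class-based variant typically comes out well ahead, which matters here because the context manager sits on the Layer.__call__ hot path.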
@@ -5142,9 +5142,8 @@ class Graph(object):
   def _auto_cast_variable_read_dtype(self, dtype):
     self._thread_local._auto_cast_variable_read_dtype = dtype  # pylint: disable=protected-access
 
-  @tf_contextlib.contextmanager
   def _enable_auto_casting_variables(self, dtype):
-    """Context manager to automatically cast AutoCastVariables.
+    """Returns a context manager to automatically cast AutoCastVariables.
 
     If an AutoCastVariable `var` is used under this context manager, it will be
     casted to `dtype` before being used.
@@ -5154,15 +5153,10 @@ class Graph(object):
     Args:
       dtype: The dtype that AutoCastVariables should be casted to.
 
-    Yields:
-      Nothing.
+    Returns:
+      Context manager.
     """
-    prev_read_dtype = self._auto_cast_variable_read_dtype
-    try:
-      self._auto_cast_variable_read_dtype = dtype
-      yield
-    finally:
-      self._auto_cast_variable_read_dtype = prev_read_dtype
+    return enable_auto_cast_variables(dtype, graph=self)
 
   def _mutation_lock(self):
     """Returns a lock to guard code that creates & mutates ops.
@@ -5179,6 +5173,36 @@
     return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
 
 
+class enable_auto_cast_variables(object):
+  """Enables the autocasting of `AutoCastVariable`s.
+
+  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
+  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
+  """
+
+  def __init__(self, dtype, graph=None):
+    if dtype and not dtype.is_floating:
+      self._dtype = None
+    else:
+      self._dtype = dtype
+    if graph is None:
+      self._graph = get_default_graph()
+    else:
+      self._graph = graph
+
+  def __enter__(self):
+    # For performance, access `_thread_local` attr directly rather than
+    # @property wrappers.
+    graph_thread_local = self._graph._thread_local
+    self._prev_read_dtype = getattr(graph_thread_local,
+                                    "_auto_cast_variable_read_dtype", None)
+    graph_thread_local._auto_cast_variable_read_dtype = self._dtype
+
+  def __exit__(self, type_arg, value_arg, traceback_arg):
+    self._graph._thread_local._auto_cast_variable_read_dtype = (
+        self._prev_read_dtype)
+
+
 # TODO(agarwal): currently device directives in an outer eager scope will not
 # apply to inner graph mode code. Fix that.
 
|  | ||||
@@ -937,8 +937,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
               call_fn = self.call
 
             try:
-              with base_layer_utils.autocast_context_manager(
-                  self._compute_dtype_object):
+              with ops.enable_auto_cast_variables(self._compute_dtype_object):
                 # Add auto_control_deps in V2 when they are not already added by
                 # a `tf.function`.
                 if (ops.executing_eagerly_outside_functions() and
@@ -999,8 +998,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         with ops.name_scope_v2(self.name):
           self._maybe_build(inputs)
           cast_inputs = self._maybe_cast_inputs(inputs, input_list)
-          with base_layer_utils.autocast_context_manager(
-              self._compute_dtype_object):
+          with ops.enable_auto_cast_variables(self._compute_dtype_object):
             outputs = self.call(cast_inputs, *args, **kwargs)
           self._handle_activity_regularization(inputs, outputs)
           self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
@@ -1288,7 +1286,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       if callable(loss):
         # We run the loss without autocasting, as regularizers are often
         # numerically unstable in float16.
-        with base_layer_utils.autocast_context_manager(None):
+        with ops.enable_auto_cast_variables(None):
           loss = loss()
       if loss is None:
         return None  # Will be filtered out when computing the .losses property
|  | ||||
@@ -481,23 +481,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
   return 'training' in full_args and full_args['training'] is not None
 
 
-def autocast_context_manager(dtype):
-  """Returns a context manager to autocast AutoCastVariables.
-
-  Under this context manager, AutoCastVariables will be casted to `dtype` if
-  `dtype` is floating-point. Otherwise, AutoCastVariables will not be casted.
-
-  Args:
-    dtype: The dtype to cast AutoCastVariables to, or None.
-
-  Returns:
-    A context manager to automatically cast AutoCastVariables.
-  """
-  if dtype and not dtype.is_floating:
-    dtype = None
-  return ops.get_default_graph()._enable_auto_casting_variables(dtype)  # pylint: disable=protected-access
-
-
 def is_subclassed(layer):
   """Returns True if the object is a subclassed layer or subclassed model."""
   return (layer.__module__.find('keras.engine') == -1 and
|  | ||||
@@ -767,8 +767,7 @@ class Layer(base_layer.Layer):
 
           if not self.dynamic:
             try:
-              with base_layer_utils.autocast_context_manager(
-                  self._compute_dtype_object):
+              with ops.enable_auto_cast_variables(self._compute_dtype_object):
                 outputs = call_fn(cast_inputs, *args, **kwargs)
 
             except errors.OperatorNotAllowedInGraphError as e:
@@ -812,8 +811,7 @@ class Layer(base_layer.Layer):
         with backend.name_scope(self._name_scope()):
           self._maybe_build(inputs)
           cast_inputs = self._maybe_cast_inputs(inputs)
-          with base_layer_utils.autocast_context_manager(
-              self._compute_dtype_object):
+          with ops.enable_auto_cast_variables(self._compute_dtype_object):
             outputs = self.call(cast_inputs, *args, **kwargs)
           self._handle_activity_regularization(inputs, outputs)
           self._set_mask_metadata(inputs, outputs, input_masks)
@@ -1023,7 +1021,7 @@ class Layer(base_layer.Layer):
       if callable(loss):
         # We run the loss without autocasting, as regularizers are often
         # numerically unstable in float16.
-        with base_layer_utils.autocast_context_manager(None):
+        with ops.enable_auto_cast_variables(None):
           loss = loss()
       if loss is None:
         return None  # Will be filtered out when computing the .losses property
|  | ||||
@@ -26,6 +26,7 @@ import weakref
 
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
@@ -521,8 +522,7 @@ def layer_call_wrapper(call_collection, method):
     with base_layer_utils.call_context().enter(
         layer, inputs=inputs, build_graph=False, training=training,
         saving=True):
-      with base_layer_utils.autocast_context_manager(
-          layer._compute_dtype_object):  # pylint: disable=protected-access
+      with ops.enable_auto_cast_variables(layer._compute_dtype_object):  # pylint: disable=protected-access
         ret = method(*args, **kwargs)
     _restore_layer_losses(original_losses)
     return ret
|  | ||||