Reduce Layer.__call__ overhead by ~15%.
Performance improvements to how Variable autocasting is enabled: - Uses a class-based context manager to enable autocasting rather than contextlib.contextmanager wrapper (class-based is faster). - Skips property accesses - Fewer reads of expensive thread local objects PiperOrigin-RevId: 314377779 Change-Id: Ia1cca0b994d3d3745407e1b568cc46e4b6cafd9e
This commit is contained in:
		
							parent
							
								
									895709e1a8
								
							
						
					
					
						commit
						d7132b0c79
					
				@ -5142,9 +5142,8 @@ class Graph(object):
 | 
				
			|||||||
  def _auto_cast_variable_read_dtype(self, dtype):
 | 
					  def _auto_cast_variable_read_dtype(self, dtype):
 | 
				
			||||||
    self._thread_local._auto_cast_variable_read_dtype = dtype  # pylint: disable=protected-access
 | 
					    self._thread_local._auto_cast_variable_read_dtype = dtype  # pylint: disable=protected-access
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @tf_contextlib.contextmanager
 | 
					 | 
				
			||||||
  def _enable_auto_casting_variables(self, dtype):
 | 
					  def _enable_auto_casting_variables(self, dtype):
 | 
				
			||||||
    """Context manager to automatically cast AutoCastVariables.
 | 
					    """Returns a context manager to automatically cast AutoCastVariables.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    If an AutoCastVariable `var` is used under this context manager, it will be
 | 
					    If an AutoCastVariable `var` is used under this context manager, it will be
 | 
				
			||||||
    casted to `dtype` before being used.
 | 
					    casted to `dtype` before being used.
 | 
				
			||||||
@ -5154,15 +5153,10 @@ class Graph(object):
 | 
				
			|||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
      dtype: The dtype that AutoCastVariables should be casted to.
 | 
					      dtype: The dtype that AutoCastVariables should be casted to.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Yields:
 | 
					    Returns:
 | 
				
			||||||
      Nothing.
 | 
					      Context manager.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    prev_read_dtype = self._auto_cast_variable_read_dtype
 | 
					    return enable_auto_cast_variables(dtype, graph=self)
 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
      self._auto_cast_variable_read_dtype = dtype
 | 
					 | 
				
			||||||
      yield
 | 
					 | 
				
			||||||
    finally:
 | 
					 | 
				
			||||||
      self._auto_cast_variable_read_dtype = prev_read_dtype
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def _mutation_lock(self):
 | 
					  def _mutation_lock(self):
 | 
				
			||||||
    """Returns a lock to guard code that creates & mutates ops.
 | 
					    """Returns a lock to guard code that creates & mutates ops.
 | 
				
			||||||
@ -5179,6 +5173,36 @@ class Graph(object):
 | 
				
			|||||||
    return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
 | 
					    return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class enable_auto_cast_variables(object):
 | 
				
			||||||
 | 
					  """Enables the autocasting of `AutoCastVariable`s.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
 | 
				
			||||||
 | 
					  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __init__(self, dtype, graph=None):
 | 
				
			||||||
 | 
					    if dtype and not dtype.is_floating:
 | 
				
			||||||
 | 
					      self._dtype = None
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      self._dtype = dtype
 | 
				
			||||||
 | 
					    if graph is None:
 | 
				
			||||||
 | 
					      self._graph = get_default_graph()
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					      self._graph = graph
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __enter__(self):
 | 
				
			||||||
 | 
					    # For performance, access `_thread_local` attr directly rather than
 | 
				
			||||||
 | 
					    # @property wrappers.
 | 
				
			||||||
 | 
					    graph_thread_local = self._graph._thread_local
 | 
				
			||||||
 | 
					    self._prev_read_dtype = getattr(graph_thread_local,
 | 
				
			||||||
 | 
					                                    "_auto_cast_variable_read_dtype", None)
 | 
				
			||||||
 | 
					    graph_thread_local._auto_cast_variable_read_dtype = self._dtype
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def __exit__(self, type_arg, value_arg, traceback_arg):
 | 
				
			||||||
 | 
					    self._graph._thread_local._auto_cast_variable_read_dtype = (
 | 
				
			||||||
 | 
					        self._prev_read_dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO(agarwal): currently device directives in an outer eager scope will not
 | 
					# TODO(agarwal): currently device directives in an outer eager scope will not
 | 
				
			||||||
# apply to inner graph mode code. Fix that.
 | 
					# apply to inner graph mode code. Fix that.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -937,8 +937,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
 | 
				
			|||||||
              call_fn = self.call
 | 
					              call_fn = self.call
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
              with base_layer_utils.autocast_context_manager(
 | 
					              with ops.enable_auto_cast_variables(self._compute_dtype_object):
 | 
				
			||||||
                  self._compute_dtype_object):
 | 
					 | 
				
			||||||
                # Add auto_control_deps in V2 when they are not already added by
 | 
					                # Add auto_control_deps in V2 when they are not already added by
 | 
				
			||||||
                # a `tf.function`.
 | 
					                # a `tf.function`.
 | 
				
			||||||
                if (ops.executing_eagerly_outside_functions() and
 | 
					                if (ops.executing_eagerly_outside_functions() and
 | 
				
			||||||
@ -999,8 +998,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
 | 
				
			|||||||
        with ops.name_scope_v2(self.name):
 | 
					        with ops.name_scope_v2(self.name):
 | 
				
			||||||
          self._maybe_build(inputs)
 | 
					          self._maybe_build(inputs)
 | 
				
			||||||
          cast_inputs = self._maybe_cast_inputs(inputs, input_list)
 | 
					          cast_inputs = self._maybe_cast_inputs(inputs, input_list)
 | 
				
			||||||
          with base_layer_utils.autocast_context_manager(
 | 
					          with ops.enable_auto_cast_variables(self._compute_dtype_object):
 | 
				
			||||||
              self._compute_dtype_object):
 | 
					 | 
				
			||||||
            outputs = self.call(cast_inputs, *args, **kwargs)
 | 
					            outputs = self.call(cast_inputs, *args, **kwargs)
 | 
				
			||||||
          self._handle_activity_regularization(inputs, outputs)
 | 
					          self._handle_activity_regularization(inputs, outputs)
 | 
				
			||||||
          self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
 | 
					          self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
 | 
				
			||||||
@ -1288,7 +1286,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
 | 
				
			|||||||
      if callable(loss):
 | 
					      if callable(loss):
 | 
				
			||||||
        # We run the loss without autocasting, as regularizers are often
 | 
					        # We run the loss without autocasting, as regularizers are often
 | 
				
			||||||
        # numerically unstable in float16.
 | 
					        # numerically unstable in float16.
 | 
				
			||||||
        with base_layer_utils.autocast_context_manager(None):
 | 
					        with ops.enable_auto_cast_variables(None):
 | 
				
			||||||
          loss = loss()
 | 
					          loss = loss()
 | 
				
			||||||
      if loss is None:
 | 
					      if loss is None:
 | 
				
			||||||
        return None  # Will be filtered out when computing the .losses property
 | 
					        return None  # Will be filtered out when computing the .losses property
 | 
				
			||||||
 | 
				
			|||||||
@ -481,23 +481,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
 | 
				
			|||||||
  return 'training' in full_args and full_args['training'] is not None
 | 
					  return 'training' in full_args and full_args['training'] is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def autocast_context_manager(dtype):
 | 
					 | 
				
			||||||
  """Returns a context manager to autocast AutoCastVariables.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  Under this context manager, AutoCastVariables will be casted to `dtype` if
 | 
					 | 
				
			||||||
  `dtype` is floating-point. Otherwise, AutoCastVariables will not be casted.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  Args:
 | 
					 | 
				
			||||||
    dtype: The dtype to cast AutoCastVariables to, or None.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  Returns:
 | 
					 | 
				
			||||||
    A context manager to automatically cast AutoCastVariables.
 | 
					 | 
				
			||||||
  """
 | 
					 | 
				
			||||||
  if dtype and not dtype.is_floating:
 | 
					 | 
				
			||||||
    dtype = None
 | 
					 | 
				
			||||||
  return ops.get_default_graph()._enable_auto_casting_variables(dtype)  # pylint: disable=protected-access
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def is_subclassed(layer):
 | 
					def is_subclassed(layer):
 | 
				
			||||||
  """Returns True if the object is a subclassed layer or subclassed model."""
 | 
					  """Returns True if the object is a subclassed layer or subclassed model."""
 | 
				
			||||||
  return (layer.__module__.find('keras.engine') == -1 and
 | 
					  return (layer.__module__.find('keras.engine') == -1 and
 | 
				
			||||||
 | 
				
			|||||||
@ -767,8 +767,7 @@ class Layer(base_layer.Layer):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
          if not self.dynamic:
 | 
					          if not self.dynamic:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
              with base_layer_utils.autocast_context_manager(
 | 
					              with ops.enable_auto_cast_variables(self._compute_dtype_object):
 | 
				
			||||||
                  self._compute_dtype_object):
 | 
					 | 
				
			||||||
                outputs = call_fn(cast_inputs, *args, **kwargs)
 | 
					                outputs = call_fn(cast_inputs, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            except errors.OperatorNotAllowedInGraphError as e:
 | 
					            except errors.OperatorNotAllowedInGraphError as e:
 | 
				
			||||||
@ -812,8 +811,7 @@ class Layer(base_layer.Layer):
 | 
				
			|||||||
        with backend.name_scope(self._name_scope()):
 | 
					        with backend.name_scope(self._name_scope()):
 | 
				
			||||||
          self._maybe_build(inputs)
 | 
					          self._maybe_build(inputs)
 | 
				
			||||||
          cast_inputs = self._maybe_cast_inputs(inputs)
 | 
					          cast_inputs = self._maybe_cast_inputs(inputs)
 | 
				
			||||||
          with base_layer_utils.autocast_context_manager(
 | 
					          with ops.enable_auto_cast_variables(self._compute_dtype_object):
 | 
				
			||||||
              self._compute_dtype_object):
 | 
					 | 
				
			||||||
            outputs = self.call(cast_inputs, *args, **kwargs)
 | 
					            outputs = self.call(cast_inputs, *args, **kwargs)
 | 
				
			||||||
          self._handle_activity_regularization(inputs, outputs)
 | 
					          self._handle_activity_regularization(inputs, outputs)
 | 
				
			||||||
          self._set_mask_metadata(inputs, outputs, input_masks)
 | 
					          self._set_mask_metadata(inputs, outputs, input_masks)
 | 
				
			||||||
@ -1023,7 +1021,7 @@ class Layer(base_layer.Layer):
 | 
				
			|||||||
      if callable(loss):
 | 
					      if callable(loss):
 | 
				
			||||||
        # We run the loss without autocasting, as regularizers are often
 | 
					        # We run the loss without autocasting, as regularizers are often
 | 
				
			||||||
        # numerically unstable in float16.
 | 
					        # numerically unstable in float16.
 | 
				
			||||||
        with base_layer_utils.autocast_context_manager(None):
 | 
					        with ops.enable_auto_cast_variables(None):
 | 
				
			||||||
          loss = loss()
 | 
					          loss = loss()
 | 
				
			||||||
      if loss is None:
 | 
					      if loss is None:
 | 
				
			||||||
        return None  # Will be filtered out when computing the .losses property
 | 
					        return None  # Will be filtered out when computing the .losses property
 | 
				
			||||||
 | 
				
			|||||||
@ -26,6 +26,7 @@ import weakref
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from tensorflow.python.eager import def_function
 | 
					from tensorflow.python.eager import def_function
 | 
				
			||||||
from tensorflow.python.eager import function
 | 
					from tensorflow.python.eager import function
 | 
				
			||||||
 | 
					from tensorflow.python.framework import ops
 | 
				
			||||||
from tensorflow.python.framework import tensor_shape
 | 
					from tensorflow.python.framework import tensor_shape
 | 
				
			||||||
from tensorflow.python.framework import tensor_spec
 | 
					from tensorflow.python.framework import tensor_spec
 | 
				
			||||||
from tensorflow.python.keras import backend as K
 | 
					from tensorflow.python.keras import backend as K
 | 
				
			||||||
@ -521,8 +522,7 @@ def layer_call_wrapper(call_collection, method):
 | 
				
			|||||||
    with base_layer_utils.call_context().enter(
 | 
					    with base_layer_utils.call_context().enter(
 | 
				
			||||||
        layer, inputs=inputs, build_graph=False, training=training,
 | 
					        layer, inputs=inputs, build_graph=False, training=training,
 | 
				
			||||||
        saving=True):
 | 
					        saving=True):
 | 
				
			||||||
      with base_layer_utils.autocast_context_manager(
 | 
					      with ops.enable_auto_cast_variables(layer._compute_dtype_object):  # pylint: disable=protected-access
 | 
				
			||||||
          layer._compute_dtype_object):  # pylint: disable=protected-access
 | 
					 | 
				
			||||||
        ret = method(*args, **kwargs)
 | 
					        ret = method(*args, **kwargs)
 | 
				
			||||||
    _restore_layer_losses(original_losses)
 | 
					    _restore_layer_losses(original_losses)
 | 
				
			||||||
    return ret
 | 
					    return ret
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user