Reduce Layer.__call__ overhead by ~15%.
Performance improvements to how Variable autocasting is enabled: - Uses a class-based context manager to enable autocasting rather than contextlib.contextmanager wrapper (class-based is faster). - Skips property accesses - Fewer reads of expensive thread local objects PiperOrigin-RevId: 314377779 Change-Id: Ia1cca0b994d3d3745407e1b568cc46e4b6cafd9e
This commit is contained in:
parent
895709e1a8
commit
d7132b0c79
@ -5142,9 +5142,8 @@ class Graph(object):
|
||||
def _auto_cast_variable_read_dtype(self, dtype):
|
||||
self._thread_local._auto_cast_variable_read_dtype = dtype # pylint: disable=protected-access
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def _enable_auto_casting_variables(self, dtype):
|
||||
"""Context manager to automatically cast AutoCastVariables.
|
||||
"""Returns a context manager to automatically cast AutoCastVariables.
|
||||
|
||||
If an AutoCastVariable `var` is used under this context manager, it will be
|
||||
casted to `dtype` before being used.
|
||||
@ -5154,15 +5153,10 @@ class Graph(object):
|
||||
Args:
|
||||
dtype: The dtype that AutoCastVariables should be casted to.
|
||||
|
||||
Yields:
|
||||
Nothing.
|
||||
Returns:
|
||||
Context manager.
|
||||
"""
|
||||
prev_read_dtype = self._auto_cast_variable_read_dtype
|
||||
try:
|
||||
self._auto_cast_variable_read_dtype = dtype
|
||||
yield
|
||||
finally:
|
||||
self._auto_cast_variable_read_dtype = prev_read_dtype
|
||||
return enable_auto_cast_variables(dtype, graph=self)
|
||||
|
||||
def _mutation_lock(self):
|
||||
"""Returns a lock to guard code that creates & mutates ops.
|
||||
@ -5179,6 +5173,36 @@ class Graph(object):
|
||||
return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
|
||||
|
||||
|
||||
class enable_auto_cast_variables(object):
|
||||
"""Enables the autocasting of `AutoCastVariable`s.
|
||||
|
||||
Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
|
||||
`dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
|
||||
"""
|
||||
|
||||
def __init__(self, dtype, graph=None):
|
||||
if dtype and not dtype.is_floating:
|
||||
self._dtype = None
|
||||
else:
|
||||
self._dtype = dtype
|
||||
if graph is None:
|
||||
self._graph = get_default_graph()
|
||||
else:
|
||||
self._graph = graph
|
||||
|
||||
def __enter__(self):
|
||||
# For performance, access `_thread_local` attr directly rather than
|
||||
# @property wrappers.
|
||||
graph_thread_local = self._graph._thread_local
|
||||
self._prev_read_dtype = getattr(graph_thread_local,
|
||||
"_auto_cast_variable_read_dtype", None)
|
||||
graph_thread_local._auto_cast_variable_read_dtype = self._dtype
|
||||
|
||||
def __exit__(self, type_arg, value_arg, traceback_arg):
|
||||
self._graph._thread_local._auto_cast_variable_read_dtype = (
|
||||
self._prev_read_dtype)
|
||||
|
||||
|
||||
# TODO(agarwal): currently device directives in an outer eager scope will not
|
||||
# apply to inner graph mode code. Fix that.
|
||||
|
||||
|
@ -937,8 +937,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
call_fn = self.call
|
||||
|
||||
try:
|
||||
with base_layer_utils.autocast_context_manager(
|
||||
self._compute_dtype_object):
|
||||
with ops.enable_auto_cast_variables(self._compute_dtype_object):
|
||||
# Add auto_control_deps in V2 when they are not already added by
|
||||
# a `tf.function`.
|
||||
if (ops.executing_eagerly_outside_functions() and
|
||||
@ -999,8 +998,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
with ops.name_scope_v2(self.name):
|
||||
self._maybe_build(inputs)
|
||||
cast_inputs = self._maybe_cast_inputs(inputs, input_list)
|
||||
with base_layer_utils.autocast_context_manager(
|
||||
self._compute_dtype_object):
|
||||
with ops.enable_auto_cast_variables(self._compute_dtype_object):
|
||||
outputs = self.call(cast_inputs, *args, **kwargs)
|
||||
self._handle_activity_regularization(inputs, outputs)
|
||||
self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
|
||||
@ -1288,7 +1286,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
if callable(loss):
|
||||
# We run the loss without autocasting, as regularizers are often
|
||||
# numerically unstable in float16.
|
||||
with base_layer_utils.autocast_context_manager(None):
|
||||
with ops.enable_auto_cast_variables(None):
|
||||
loss = loss()
|
||||
if loss is None:
|
||||
return None # Will be filtered out when computing the .losses property
|
||||
|
@ -481,23 +481,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
|
||||
return 'training' in full_args and full_args['training'] is not None
|
||||
|
||||
|
||||
def autocast_context_manager(dtype):
|
||||
"""Returns a context manager to autocast AutoCastVariables.
|
||||
|
||||
Under this context manager, AutoCastVariables will be casted to `dtype` if
|
||||
`dtype` is floating-point. Otherwise, AutoCastVariables will not be casted.
|
||||
|
||||
Args:
|
||||
dtype: The dtype to cast AutoCastVariables to, or None.
|
||||
|
||||
Returns:
|
||||
A context manager to automatically cast AutoCastVariables.
|
||||
"""
|
||||
if dtype and not dtype.is_floating:
|
||||
dtype = None
|
||||
return ops.get_default_graph()._enable_auto_casting_variables(dtype) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def is_subclassed(layer):
|
||||
"""Returns True if the object is a subclassed layer or subclassed model."""
|
||||
return (layer.__module__.find('keras.engine') == -1 and
|
||||
|
@ -767,8 +767,7 @@ class Layer(base_layer.Layer):
|
||||
|
||||
if not self.dynamic:
|
||||
try:
|
||||
with base_layer_utils.autocast_context_manager(
|
||||
self._compute_dtype_object):
|
||||
with ops.enable_auto_cast_variables(self._compute_dtype_object):
|
||||
outputs = call_fn(cast_inputs, *args, **kwargs)
|
||||
|
||||
except errors.OperatorNotAllowedInGraphError as e:
|
||||
@ -812,8 +811,7 @@ class Layer(base_layer.Layer):
|
||||
with backend.name_scope(self._name_scope()):
|
||||
self._maybe_build(inputs)
|
||||
cast_inputs = self._maybe_cast_inputs(inputs)
|
||||
with base_layer_utils.autocast_context_manager(
|
||||
self._compute_dtype_object):
|
||||
with ops.enable_auto_cast_variables(self._compute_dtype_object):
|
||||
outputs = self.call(cast_inputs, *args, **kwargs)
|
||||
self._handle_activity_regularization(inputs, outputs)
|
||||
self._set_mask_metadata(inputs, outputs, input_masks)
|
||||
@ -1023,7 +1021,7 @@ class Layer(base_layer.Layer):
|
||||
if callable(loss):
|
||||
# We run the loss without autocasting, as regularizers are often
|
||||
# numerically unstable in float16.
|
||||
with base_layer_utils.autocast_context_manager(None):
|
||||
with ops.enable_auto_cast_variables(None):
|
||||
loss = loss()
|
||||
if loss is None:
|
||||
return None # Will be filtered out when computing the .losses property
|
||||
|
@ -26,6 +26,7 @@ import weakref
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
@ -521,8 +522,7 @@ def layer_call_wrapper(call_collection, method):
|
||||
with base_layer_utils.call_context().enter(
|
||||
layer, inputs=inputs, build_graph=False, training=training,
|
||||
saving=True):
|
||||
with base_layer_utils.autocast_context_manager(
|
||||
layer._compute_dtype_object): # pylint: disable=protected-access
|
||||
with ops.enable_auto_cast_variables(layer._compute_dtype_object): # pylint: disable=protected-access
|
||||
ret = method(*args, **kwargs)
|
||||
_restore_layer_losses(original_losses)
|
||||
return ret
|
||||
|
Loading…
x
Reference in New Issue
Block a user