Reduce Layer.__call__ overhead by ~15%.
Performance improvements to how Variable autocasting is enabled:
- Uses a class-based context manager to enable autocasting rather than a
  contextlib.contextmanager wrapper (class-based is faster).
- Skips property accesses.
- Fewer reads of expensive thread-local objects.

PiperOrigin-RevId: 314377779
Change-Id: Ia1cca0b994d3d3745407e1b568cc46e4b6cafd9e
Parent: 895709e1a8
Commit: d7132b0c79
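The heart of the change is swapping a generator-based context manager for a plain class that implements __enter__ and __exit__. A minimal sketch of the two patterns, assuming a hypothetical thread-local flag (`_state`, `set_mode`, and `set_mode_generator` are illustrative names, not part of this commit):

import contextlib
import threading

_state = threading.local()  # hypothetical thread-local holding the current mode


@contextlib.contextmanager
def set_mode_generator(value):
  # Generator-based form: each `with` builds a generator plus a
  # _GeneratorContextManager wrapper and resumes the generator twice.
  prev = getattr(_state, "mode", None)
  _state.mode = value
  try:
    yield
  finally:
    _state.mode = prev


class set_mode(object):
  # Class-based form: one small object and two ordinary method calls,
  # which is the pattern `enable_auto_cast_variables` below adopts.

  def __init__(self, value):
    self._value = value

  def __enter__(self):
    self._prev = getattr(_state, "mode", None)
    _state.mode = self._value

  def __exit__(self, exc_type, exc_value, traceback):
    _state.mode = self._prev

Timing both forms in a tight loop (for example with timeit) typically shows the class-based manager entering and exiting with noticeably less overhead, which is what this commit exploits on the hot Layer.__call__ path.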
@@ -5142,9 +5142,8 @@ class Graph(object):
   def _auto_cast_variable_read_dtype(self, dtype):
     self._thread_local._auto_cast_variable_read_dtype = dtype  # pylint: disable=protected-access
 
-  @tf_contextlib.contextmanager
   def _enable_auto_casting_variables(self, dtype):
-    """Context manager to automatically cast AutoCastVariables.
+    """Returns a context manager to automatically cast AutoCastVariables.
 
     If an AutoCastVariable `var` is used under this context manager, it will be
     casted to `dtype` before being used.
@@ -5154,15 +5153,10 @@ class Graph(object):
     Args:
       dtype: The dtype that AutoCastVariables should be casted to.
 
-    Yields:
-      Nothing.
+    Returns:
+      Context manager.
     """
-    prev_read_dtype = self._auto_cast_variable_read_dtype
-    try:
-      self._auto_cast_variable_read_dtype = dtype
-      yield
-    finally:
-      self._auto_cast_variable_read_dtype = prev_read_dtype
+    return enable_auto_cast_variables(dtype, graph=self)
 
   def _mutation_lock(self):
     """Returns a lock to guard code that creates & mutates ops.
@@ -5179,6 +5173,36 @@ class Graph(object):
     return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
 
 
+class enable_auto_cast_variables(object):
+  """Enables the autocasting of `AutoCastVariable`s.
+
+  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
+  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
+  """
+
+  def __init__(self, dtype, graph=None):
+    if dtype and not dtype.is_floating:
+      self._dtype = None
+    else:
+      self._dtype = dtype
+    if graph is None:
+      self._graph = get_default_graph()
+    else:
+      self._graph = graph
+
+  def __enter__(self):
+    # For performance, access `_thread_local` attr directly rather than
+    # @property wrappers.
+    graph_thread_local = self._graph._thread_local
+    self._prev_read_dtype = getattr(graph_thread_local,
+                                    "_auto_cast_variable_read_dtype", None)
+    graph_thread_local._auto_cast_variable_read_dtype = self._dtype
+
+  def __exit__(self, type_arg, value_arg, traceback_arg):
+    self._graph._thread_local._auto_cast_variable_read_dtype = (
+        self._prev_read_dtype)
+
+
 # TODO(agarwal): currently device directives in an outer eager scope will not
 # apply to inner graph mode code. Fix that.
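For context, a sketch of the call-site pattern the Keras hunks below adopt; `wrapped_call`, `layer`, and `inputs` are illustrative placeholders, while `ops.enable_auto_cast_variables` and `layer._compute_dtype_object` come from this diff:

from tensorflow.python.framework import ops

def wrapped_call(layer, inputs, *args, **kwargs):
  # Reads of AutoCastVariables inside `layer.call` are cast to the layer's
  # compute dtype; __exit__ restores the previous read dtype even on error.
  with ops.enable_auto_cast_variables(layer._compute_dtype_object):  # pylint: disable=protected-access
    return layer.call(inputs, *args, **kwargs)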
@@ -937,8 +937,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
             call_fn = self.call
 
           try:
-            with base_layer_utils.autocast_context_manager(
-                self._compute_dtype_object):
+            with ops.enable_auto_cast_variables(self._compute_dtype_object):
               # Add auto_control_deps in V2 when they are not already added by
               # a `tf.function`.
               if (ops.executing_eagerly_outside_functions() and
@@ -999,8 +998,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       with ops.name_scope_v2(self.name):
         self._maybe_build(inputs)
         cast_inputs = self._maybe_cast_inputs(inputs, input_list)
-        with base_layer_utils.autocast_context_manager(
-            self._compute_dtype_object):
+        with ops.enable_auto_cast_variables(self._compute_dtype_object):
           outputs = self.call(cast_inputs, *args, **kwargs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
@@ -1288,7 +1286,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       if callable(loss):
         # We run the loss without autocasting, as regularizers are often
         # numerically unstable in float16.
-        with base_layer_utils.autocast_context_manager(None):
+        with ops.enable_auto_cast_variables(None):
           loss = loss()
       if loss is None:
         return None  # Will be filtered out when computing the .losses property
@@ -481,23 +481,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
   return 'training' in full_args and full_args['training'] is not None
 
 
-def autocast_context_manager(dtype):
-  """Returns a context manager to autocast AutoCastVariables.
-
-  Under this context manager, AutoCastVariables will be casted to `dtype` if
-  `dtype` is floating-point. Otherwise, AutoCastVariables will not be casted.
-
-  Args:
-    dtype: The dtype to cast AutoCastVariables to, or None.
-
-  Returns:
-    A context manager to automatically cast AutoCastVariables.
-  """
-  if dtype and not dtype.is_floating:
-    dtype = None
-  return ops.get_default_graph()._enable_auto_casting_variables(dtype)  # pylint: disable=protected-access
-
-
 def is_subclassed(layer):
   """Returns True if the object is a subclassed layer or subclassed model."""
   return (layer.__module__.find('keras.engine') == -1 and
@@ -767,8 +767,7 @@ class Layer(base_layer.Layer):
 
       if not self.dynamic:
         try:
-          with base_layer_utils.autocast_context_manager(
-              self._compute_dtype_object):
+          with ops.enable_auto_cast_variables(self._compute_dtype_object):
            outputs = call_fn(cast_inputs, *args, **kwargs)
 
        except errors.OperatorNotAllowedInGraphError as e:
@@ -812,8 +811,7 @@ class Layer(base_layer.Layer):
       with backend.name_scope(self._name_scope()):
         self._maybe_build(inputs)
         cast_inputs = self._maybe_cast_inputs(inputs)
-        with base_layer_utils.autocast_context_manager(
-            self._compute_dtype_object):
+        with ops.enable_auto_cast_variables(self._compute_dtype_object):
           outputs = self.call(cast_inputs, *args, **kwargs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks)
@@ -1023,7 +1021,7 @@ class Layer(base_layer.Layer):
       if callable(loss):
         # We run the loss without autocasting, as regularizers are often
         # numerically unstable in float16.
-        with base_layer_utils.autocast_context_manager(None):
+        with ops.enable_auto_cast_variables(None):
           loss = loss()
       if loss is None:
         return None  # Will be filtered out when computing the .losses property
@@ -26,6 +26,7 @@ import weakref
 
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
@@ -521,8 +522,7 @@ def layer_call_wrapper(call_collection, method):
     with base_layer_utils.call_context().enter(
         layer, inputs=inputs, build_graph=False, training=training,
         saving=True):
-      with base_layer_utils.autocast_context_manager(
-          layer._compute_dtype_object):  # pylint: disable=protected-access
+      with ops.enable_auto_cast_variables(layer._compute_dtype_object):  # pylint: disable=protected-access
         ret = method(*args, **kwargs)
     _restore_layer_losses(original_losses)
     return ret