Reduce Layer.__call__ overhead by ~15%.

Performance improvements to how Variable autocasting is enabled:

- Uses a class-based context manager to enable autocasting rather than a
  contextlib.contextmanager wrapper (class-based managers are faster; see
  the sketch below).
- Skips @property accesses by reading the graph's `_thread_local` attribute
  directly.
- Performs fewer reads of the expensive thread-local object.
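
As a minimal standalone sketch of the two patterns being compared (illustrative
names only, not code from this change): a contextlib.contextmanager wrapper
builds a generator object and goes through the library's wrapper machinery on
every `with` statement, while a class-based manager only performs the attribute
reads/writes in __enter__/__exit__.

  import contextlib
  import threading
  import timeit

  _state = threading.local()  # stands in for the Graph's thread-local storage

  @contextlib.contextmanager
  def enable_with_generator(dtype):
    """Old pattern: generator-based context manager."""
    prev = getattr(_state, "read_dtype", None)
    try:
      _state.read_dtype = dtype
      yield
    finally:
      _state.read_dtype = prev

  class enable_with_class(object):
    """New pattern: plain class with explicit __enter__/__exit__."""

    def __init__(self, dtype):
      self._dtype = dtype

    def __enter__(self):
      self._prev = getattr(_state, "read_dtype", None)
      _state.read_dtype = self._dtype

    def __exit__(self, exc_type, exc_value, traceback):
      _state.read_dtype = self._prev

  def _use(cm_factory):
    with cm_factory("float16"):
      pass

  print(timeit.timeit(lambda: _use(enable_with_generator), number=100000))
  print(timeit.timeit(lambda: _use(enable_with_class), number=100000))

The second timing is typically noticeably smaller, which matters because
Layer.__call__ opens this context on every layer invocation.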

PiperOrigin-RevId: 314377779
Change-Id: Ia1cca0b994d3d3745407e1b568cc46e4b6cafd9e
Thomas O'Malley 2020-06-02 12:09:03 -07:00 committed by TensorFlower Gardener
parent 895709e1a8
commit d7132b0c79
5 changed files with 42 additions and 39 deletions


@@ -5142,9 +5142,8 @@ class Graph(object):
   def _auto_cast_variable_read_dtype(self, dtype):
     self._thread_local._auto_cast_variable_read_dtype = dtype  # pylint: disable=protected-access
 
-  @tf_contextlib.contextmanager
   def _enable_auto_casting_variables(self, dtype):
-    """Context manager to automatically cast AutoCastVariables.
+    """Returns a context manager to automatically cast AutoCastVariables.
 
     If an AutoCastVariable `var` is used under this context manager, it will be
     casted to `dtype` before being used.
@@ -5154,15 +5153,10 @@ class Graph(object):
     Args:
       dtype: The dtype that AutoCastVariables should be casted to.
 
-    Yields:
-      Nothing.
+    Returns:
+      Context manager.
     """
-    prev_read_dtype = self._auto_cast_variable_read_dtype
-    try:
-      self._auto_cast_variable_read_dtype = dtype
-      yield
-    finally:
-      self._auto_cast_variable_read_dtype = prev_read_dtype
+    return enable_auto_cast_variables(dtype, graph=self)
 
   def _mutation_lock(self):
     """Returns a lock to guard code that creates & mutates ops.
@@ -5179,6 +5173,36 @@ class Graph(object):
     return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
 
 
+class enable_auto_cast_variables(object):
+  """Enables the autocasting of `AutoCastVariable`s.
+
+  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
+  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
+  """
+
+  def __init__(self, dtype, graph=None):
+    if dtype and not dtype.is_floating:
+      self._dtype = None
+    else:
+      self._dtype = dtype
+    if graph is None:
+      self._graph = get_default_graph()
+    else:
+      self._graph = graph
+
+  def __enter__(self):
+    # For performance, access `_thread_local` attr directly rather than
+    # @property wrappers.
+    graph_thread_local = self._graph._thread_local
+    self._prev_read_dtype = getattr(graph_thread_local,
+                                    "_auto_cast_variable_read_dtype", None)
+    graph_thread_local._auto_cast_variable_read_dtype = self._dtype
+
+  def __exit__(self, type_arg, value_arg, traceback_arg):
+    self._graph._thread_local._auto_cast_variable_read_dtype = (
+        self._prev_read_dtype)
+
+
 # TODO(agarwal): currently device directives in an outer eager scope will not
 # apply to inner graph mode code. Fix that.
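
(Illustrative usage, not part of the diff; enable_auto_cast_variables lives in
the internal tensorflow.python.framework.ops module and is not public API.)
A floating-point dtype enables casting, a non-floating dtype disables it, and
exiting restores the previous thread-local value:

  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import ops

  graph = ops.get_default_graph()

  with ops.enable_auto_cast_variables(dtypes.float16):
    # AutoCastVariables read in this block are cast to float16.
    assert graph._auto_cast_variable_read_dtype == dtypes.float16

  with ops.enable_auto_cast_variables(dtypes.int32):
    # Non-floating dtype: autocasting is disabled (the stored dtype is None).
    assert graph._auto_cast_variable_read_dtype is None

  # On exit the previous value (None by default) is restored.
  assert graph._auto_cast_variable_read_dtype is None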


@@ -937,8 +937,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
           call_fn = self.call
 
         try:
-          with base_layer_utils.autocast_context_manager(
-              self._compute_dtype_object):
+          with ops.enable_auto_cast_variables(self._compute_dtype_object):
             # Add auto_control_deps in V2 when they are not already added by
             # a `tf.function`.
             if (ops.executing_eagerly_outside_functions() and
@@ -999,8 +998,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       with ops.name_scope_v2(self.name):
         self._maybe_build(inputs)
         cast_inputs = self._maybe_cast_inputs(inputs, input_list)
-        with base_layer_utils.autocast_context_manager(
-            self._compute_dtype_object):
+        with ops.enable_auto_cast_variables(self._compute_dtype_object):
           outputs = self.call(cast_inputs, *args, **kwargs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
@@ -1288,7 +1286,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       if callable(loss):
         # We run the loss without autocasting, as regularizers are often
         # numerically unstable in float16.
-        with base_layer_utils.autocast_context_manager(None):
+        with ops.enable_auto_cast_variables(None):
           loss = loss()
         if loss is None:
           return None  # Will be filtered out when computing the .losses property
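
(Aside, not from this change.) The float16 instability mentioned in the comment
above is easy to reproduce with NumPy: float16 overflows near 65504, so the
squaring inside a typical L2 regularizer can return inf unless the loss runs
outside the autocast context.

  import numpy as np

  print(np.square(np.float16(300.0)))   # inf (overflows float16)
  print(np.square(np.float32(300.0)))   # 90000.0

  weights = np.full(10, 300.0)
  print(np.sum(np.square(weights.astype(np.float16))))  # inf
  print(np.sum(np.square(weights.astype(np.float32))))  # 900000.0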


@@ -481,23 +481,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
   return 'training' in full_args and full_args['training'] is not None
 
 
-def autocast_context_manager(dtype):
-  """Returns a context manager to autocast AutoCastVariables.
-
-  Under this context manager, AutoCastVariables will be casted to `dtype` if
-  `dtype` is floating-point. Otherwise, AutoCastVariables will not be casted.
-
-  Args:
-    dtype: The dtype to cast AutoCastVariables to, or None.
-
-  Returns:
-    A context manager to automatically cast AutoCastVariables.
-  """
-  if dtype and not dtype.is_floating:
-    dtype = None
-  return ops.get_default_graph()._enable_auto_casting_variables(dtype)  # pylint: disable=protected-access
-
-
 def is_subclassed(layer):
   """Returns True if the object is a subclassed layer or subclassed model."""
   return (layer.__module__.find('keras.engine') == -1 and


@@ -767,8 +767,7 @@ class Layer(base_layer.Layer):
           if not self.dynamic:
             try:
-              with base_layer_utils.autocast_context_manager(
-                  self._compute_dtype_object):
+              with ops.enable_auto_cast_variables(self._compute_dtype_object):
                 outputs = call_fn(cast_inputs, *args, **kwargs)
             except errors.OperatorNotAllowedInGraphError as e:
@@ -812,8 +811,7 @@ class Layer(base_layer.Layer):
       with backend.name_scope(self._name_scope()):
         self._maybe_build(inputs)
         cast_inputs = self._maybe_cast_inputs(inputs)
-        with base_layer_utils.autocast_context_manager(
-            self._compute_dtype_object):
+        with ops.enable_auto_cast_variables(self._compute_dtype_object):
           outputs = self.call(cast_inputs, *args, **kwargs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks)
@@ -1023,7 +1021,7 @@ class Layer(base_layer.Layer):
       if callable(loss):
         # We run the loss without autocasting, as regularizers are often
         # numerically unstable in float16.
-        with base_layer_utils.autocast_context_manager(None):
+        with ops.enable_auto_cast_variables(None):
           loss = loss()
         if loss is None:
           return None  # Will be filtered out when computing the .losses property


@@ -26,6 +26,7 @@ import weakref
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
@@ -521,8 +522,7 @@ def layer_call_wrapper(call_collection, method):
     with base_layer_utils.call_context().enter(
         layer, inputs=inputs, build_graph=False, training=training,
         saving=True):
-      with base_layer_utils.autocast_context_manager(
-          layer._compute_dtype_object):  # pylint: disable=protected-access
+      with ops.enable_auto_cast_variables(layer._compute_dtype_object):  # pylint: disable=protected-access
         ret = method(*args, **kwargs)
     _restore_layer_losses(original_losses)
     return ret