diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f83f65152c1..8fee3057b8d 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -5142,9 +5142,8 @@ class Graph(object):
   def _auto_cast_variable_read_dtype(self, dtype):
     self._thread_local._auto_cast_variable_read_dtype = dtype  # pylint: disable=protected-access
 
-  @tf_contextlib.contextmanager
   def _enable_auto_casting_variables(self, dtype):
-    """Context manager to automatically cast AutoCastVariables.
+    """Returns a context manager to automatically cast AutoCastVariables.
 
     If an AutoCastVariable `var` is used under this context manager, it will be
     casted to `dtype` before being used.
@@ -5154,15 +5153,10 @@ class Graph(object):
     Args:
       dtype: The dtype that AutoCastVariables should be casted to.
 
-    Yields:
-      Nothing.
+    Returns:
+      Context manager.
     """
-    prev_read_dtype = self._auto_cast_variable_read_dtype
-    try:
-      self._auto_cast_variable_read_dtype = dtype
-      yield
-    finally:
-      self._auto_cast_variable_read_dtype = prev_read_dtype
+    return enable_auto_cast_variables(dtype, graph=self)
 
   def _mutation_lock(self):
     """Returns a lock to guard code that creates & mutates ops.
@@ -5179,6 +5173,36 @@ class Graph(object):
     return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
 
 
+class enable_auto_cast_variables(object):
+  """Enables the autocasting of `AutoCastVariable`s.
+
+  Under this context manager, `AutoCastVariable`s will be cast to `dtype` if
+  `dtype` is floating-point. Otherwise, `AutoCastVariable`s will not be cast.
+  """
+
+  def __init__(self, dtype, graph=None):
+    if dtype and not dtype.is_floating:
+      self._dtype = None
+    else:
+      self._dtype = dtype
+    if graph is None:
+      self._graph = get_default_graph()
+    else:
+      self._graph = graph
+
+  def __enter__(self):
+    # For performance, access `_thread_local` attr directly rather than
+    # @property wrappers.
+    graph_thread_local = self._graph._thread_local
+    self._prev_read_dtype = getattr(graph_thread_local,
+                                    "_auto_cast_variable_read_dtype", None)
+    graph_thread_local._auto_cast_variable_read_dtype = self._dtype
+
+  def __exit__(self, type_arg, value_arg, traceback_arg):
+    self._graph._thread_local._auto_cast_variable_read_dtype = (
+        self._prev_read_dtype)
+
+
 # TODO(agarwal): currently device directives in an outer eager scope will not
 # apply to inner graph mode code. Fix that.
 
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index c7d25f31d73..3e03b2d9ddb 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -937,8 +937,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
           call_fn = self.call
 
       try:
-        with base_layer_utils.autocast_context_manager(
-            self._compute_dtype_object):
+        with ops.enable_auto_cast_variables(self._compute_dtype_object):
           # Add auto_control_deps in V2 when they are not already added by
           # a `tf.function`.
           if (ops.executing_eagerly_outside_functions() and
@@ -999,8 +998,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       with ops.name_scope_v2(self.name):
         self._maybe_build(inputs)
         cast_inputs = self._maybe_cast_inputs(inputs, input_list)
-        with base_layer_utils.autocast_context_manager(
-            self._compute_dtype_object):
+        with ops.enable_auto_cast_variables(self._compute_dtype_object):
           outputs = self.call(cast_inputs, *args, **kwargs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks, build_graph)
@@ -1288,7 +1286,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       if callable(loss):
         # We run the loss without autocasting, as regularizers are often
         # numerically unstable in float16.
-        with base_layer_utils.autocast_context_manager(None):
+        with ops.enable_auto_cast_variables(None):
           loss = loss()
       if loss is None:
         return None  # Will be filtered out when computing the .losses property
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index d7bd3d5d372..fa56807f4a5 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -481,23 +481,6 @@ def training_arg_passed_to_call(argspec, args, kwargs):
   return 'training' in full_args and full_args['training'] is not None
 
 
-def autocast_context_manager(dtype):
-  """Returns a context manager to autocast AutoCastVariables.
-
-  Under this context manager, AutoCastVariables will be casted to `dtype` if
-  `dtype` is floating-point. Otherwise, AutoCastVariables will not be casted.
-
-  Args:
-    dtype: The dtype to cast AutoCastVariables to, or None.
-
-  Returns:
-    A context manager to automatically cast AutoCastVariables.
-  """
-  if dtype and not dtype.is_floating:
-    dtype = None
-  return ops.get_default_graph()._enable_auto_casting_variables(dtype)  # pylint: disable=protected-access
-
-
 def is_subclassed(layer):
   """Returns True if the object is a subclassed layer or subclassed model."""
   return (layer.__module__.find('keras.engine') == -1 and
diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py
index 725334f8535..66eb93b64ea 100644
--- a/tensorflow/python/keras/engine/base_layer_v1.py
+++ b/tensorflow/python/keras/engine/base_layer_v1.py
@@ -767,8 +767,7 @@ class Layer(base_layer.Layer):
 
       if not self.dynamic:
         try:
-          with base_layer_utils.autocast_context_manager(
-              self._compute_dtype_object):
+          with ops.enable_auto_cast_variables(self._compute_dtype_object):
            outputs = call_fn(cast_inputs, *args, **kwargs)
 
         except errors.OperatorNotAllowedInGraphError as e:
@@ -812,8 +811,7 @@ class Layer(base_layer.Layer):
       with backend.name_scope(self._name_scope()):
         self._maybe_build(inputs)
         cast_inputs = self._maybe_cast_inputs(inputs)
-        with base_layer_utils.autocast_context_manager(
-            self._compute_dtype_object):
+        with ops.enable_auto_cast_variables(self._compute_dtype_object):
           outputs = self.call(cast_inputs, *args, **kwargs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, input_masks)
@@ -1023,7 +1021,7 @@ class Layer(base_layer.Layer):
       if callable(loss):
         # We run the loss without autocasting, as regularizers are often
         # numerically unstable in float16.
-        with base_layer_utils.autocast_context_manager(None):
+        with ops.enable_auto_cast_variables(None):
           loss = loss()
       if loss is None:
         return None  # Will be filtered out when computing the .losses property
diff --git a/tensorflow/python/keras/saving/saved_model/save_impl.py b/tensorflow/python/keras/saving/saved_model/save_impl.py
index 7802470c523..f2e6c967b14 100644
--- a/tensorflow/python/keras/saving/saved_model/save_impl.py
+++ b/tensorflow/python/keras/saving/saved_model/save_impl.py
@@ -26,6 +26,7 @@ import weakref
 
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
@@ -521,8 +522,7 @@ def layer_call_wrapper(call_collection, method):
     with base_layer_utils.call_context().enter(
         layer, inputs=inputs, build_graph=False, training=training,
         saving=True):
-      with base_layer_utils.autocast_context_manager(
-          layer._compute_dtype_object):  # pylint: disable=protected-access
+      with ops.enable_auto_cast_variables(layer._compute_dtype_object):  # pylint: disable=protected-access
        ret = method(*args, **kwargs)
     _restore_layer_losses(original_losses)
     return ret
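---

A minimal usage sketch of the new `ops.enable_auto_cast_variables` context
manager (illustrative only, not part of the patch; it exercises internal
TensorFlow APIs as they stand after this change):

    from tensorflow.python.framework import dtypes, ops

    graph = ops.get_default_graph()

    with ops.enable_auto_cast_variables(dtypes.float16):
      # AutoCastVariables read inside this scope are cast to float16.
      assert graph._auto_cast_variable_read_dtype == dtypes.float16

    with ops.enable_auto_cast_variables(dtypes.int32):
      # Non-floating dtypes are nulled out in __init__, so no casting occurs.
      assert graph._auto_cast_variable_read_dtype is None

    # __exit__ restores the previous read dtype.
    assert graph._auto_cast_variable_read_dtype is None

Replacing the generator-based `@tf_contextlib.contextmanager` with a plain
class appears to be a performance move: `Layer.__call__` enters this scope on
every call, and the new `__enter__`/`__exit__` touch `_thread_local` directly
rather than going through `@property` wrappers and generator machinery.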