Don't do fused average updates inside XLA context as it may create extra tf.cond which causes OOM on TPUs.
PiperOrigin-RevId: 327294174 Change-Id: I7caa62d77e5c86a6afe7aaca22c7231d8f2304b6
This commit is contained in:
parent
f3c5b4ea29
commit
62cdfd542c
@ -30,12 +30,12 @@ from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.keras.engine.input_spec import InputSpec
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import device_context
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@ -514,7 +514,7 @@ class BatchNormalizationBase(Layer):
|
||||
use_fused_avg_updates = (
|
||||
ops.executing_eagerly_outside_functions() and
|
||||
isinstance(self.momentum, (float, int)) and
|
||||
device_context.enclosing_tpu_context() is None)
|
||||
enclosing_xla_context() is None)
|
||||
if use_fused_avg_updates:
|
||||
exponential_avg_factor = 1.0 - self.momentum
|
||||
else:
|
||||
@ -930,6 +930,23 @@ def replace_in_base_docstring(replacements):
|
||||
return string
|
||||
|
||||
|
||||
def enclosing_xla_context():
|
||||
"""Recursively find and return the XLAControlFlowContext."""
|
||||
graph = ops.get_default_graph()
|
||||
while graph is not None:
|
||||
# pylint: disable=protected-access
|
||||
context_ = graph._get_control_flow_context()
|
||||
# pylint: enable=protected-access
|
||||
while context_ is not None:
|
||||
if isinstance(context_, control_flow_ops.XLAControlFlowContext):
|
||||
return context_
|
||||
context_ = context_.outer_context
|
||||
# This may be a FuncGraph due to defuns or v2 control flow. We need to
|
||||
# find the original graph with the XLAControlFlowContext.
|
||||
graph = getattr(graph, 'outer_graph', None)
|
||||
return None
|
||||
|
||||
|
||||
@keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring
|
||||
class BatchNormalization(BatchNormalizationBase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user