Merge pull request #45718 from geetachavan1/cherrypicks_CI7J6

[Cherrypick:r2.3] Don't do fused average updates inside an XLA context, as doing so may create an extra tf.cond, which causes OOM on TPUs.
This commit is contained in:
Mihai Maruseac 2020-12-17 09:24:54 -08:00 committed by GitHub
commit 804dab6a4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -30,12 +30,12 @@ from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import device_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
@ -514,7 +514,7 @@ class BatchNormalizationBase(Layer):
use_fused_avg_updates = (
ops.executing_eagerly_outside_functions() and
isinstance(self.momentum, (float, int)) and
device_context.enclosing_tpu_context() is None)
enclosing_xla_context() is None)
if use_fused_avg_updates:
exponential_avg_factor = 1.0 - self.momentum
else:
@ -930,6 +930,23 @@ def replace_in_base_docstring(replacements):
return string
def enclosing_xla_context():
  """Locate the nearest enclosing XLAControlFlowContext, if there is one.

  The search begins at the default graph. For each graph visited, its control
  flow context chain is followed through `outer_context`; when no context in
  that chain is an `XLAControlFlowContext`, the search continues with the
  graph's `outer_graph` attribute (if it has one).

  Returns:
    The first `control_flow_ops.XLAControlFlowContext` encountered, or `None`
    when every graph and context in the chain has been examined without a
    match.
  """
  graph_cursor = ops.get_default_graph()
  while True:
    if graph_cursor is None:
      # Ran out of graphs without seeing an XLA context.
      return None

    ctx = graph_cursor._get_control_flow_context()  # pylint: disable=protected-access
    # Advance through the enclosing contexts until we either hit an XLA
    # context or fall off the end of the chain.
    while ctx is not None and not isinstance(
        ctx, control_flow_ops.XLAControlFlowContext):
      ctx = ctx.outer_context
    if ctx is not None:
      return ctx

    # The current graph may be a FuncGraph (created by defuns or v2 control
    # flow), in which case the XLAControlFlowContext lives on an outer graph.
    # Graphs without an `outer_graph` attribute terminate the search.
    graph_cursor = getattr(graph_cursor, 'outer_graph', None)
@keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring
class BatchNormalization(BatchNormalizationBase):