From 62cdfd542ca8bfba5aa30d06e4b34444146bf40e Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Tue, 18 Aug 2020 13:24:04 -0700 Subject: [PATCH] Don't do fused average updates inside XLA context as it may create extra tf.cond which causes OOM on TPUs. PiperOrigin-RevId: 327294174 Change-Id: I7caa62d77e5c86a6afe7aaca22c7231d8f2304b6 --- .../python/keras/layers/normalization.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index e5723a3ef98..9ab606d8038 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -30,12 +30,12 @@ from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.platform import device_context from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import keras_export @@ -514,7 +514,7 @@ class BatchNormalizationBase(Layer): use_fused_avg_updates = ( ops.executing_eagerly_outside_functions() and isinstance(self.momentum, (float, int)) and - device_context.enclosing_tpu_context() is None) + enclosing_xla_context() is None) if use_fused_avg_updates: exponential_avg_factor = 1.0 - self.momentum else: @@ -930,6 +930,23 @@ def replace_in_base_docstring(replacements): return string +def enclosing_xla_context(): + """Recursively find and return the XLAControlFlowContext.""" + graph = ops.get_default_graph() + while graph is not None: + # pylint: disable=protected-access + context_ = graph._get_control_flow_context() + # pylint: enable=protected-access + while context_ is not None: + if isinstance(context_, control_flow_ops.XLAControlFlowContext): + return context_ + context_ = context_.outer_context + # This may be a FuncGraph due to defuns or v2 control flow. We need to + # find the original graph with the XLAControlFlowContext. + graph = getattr(graph, 'outer_graph', None) + return None + + @keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring class BatchNormalization(BatchNormalizationBase):