diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index e5723a3ef98..9ab606d8038 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -30,12 +30,12 @@ from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.input_spec import InputSpec
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import device_context
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import keras_export
 
@@ -514,7 +514,7 @@ class BatchNormalizationBase(Layer):
     use_fused_avg_updates = (
         ops.executing_eagerly_outside_functions() and
         isinstance(self.momentum, (float, int)) and
-        device_context.enclosing_tpu_context() is None)
+        enclosing_xla_context() is None)
     if use_fused_avg_updates:
       exponential_avg_factor = 1.0 - self.momentum
     else:
@@ -930,6 +930,23 @@ def replace_in_base_docstring(replacements):
   return string
 
 
+def enclosing_xla_context():
+  """Return the nearest enclosing XLAControlFlowContext, or None if absent."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
+        return context_
+      context_ = context_.outer_context
+    # Not found here. This may be a FuncGraph due to defuns or v2 control
+    # flow, so keep searching in its outer graph (None ends the search).
+    graph = getattr(graph, 'outer_graph', None)
+  return None
+
+
 @keras_export(v1=['keras.layers.BatchNormalization'])  # pylint: disable=missing-docstring
 class BatchNormalization(BatchNormalizationBase):