diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD
index db27dd79062..ff82714ce3e 100644
--- a/tensorflow/python/tpu/BUILD
+++ b/tensorflow/python/tpu/BUILD
@@ -137,7 +137,6 @@ py_library(
         "tpu_strategy_util.py",
         "tpu_system_metadata.py",
         "training_loop.py",
-        "xla.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -164,6 +163,7 @@ py_library(
         "//tensorflow/python:training",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
+        "//tensorflow/python/compiler/xla",
         "//tensorflow/python/ops/losses",
         "//tensorflow/python/tpu/profiler",
     ],
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index 55273a5203e..d5d7ea266da 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -25,6 +25,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
 from tensorflow.python.compat import compat as api_compat
+from tensorflow.python.compiler.xla import xla
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import tpu_function
-from tensorflow.python.tpu import xla
 from tensorflow.python.tpu.ops import tpu_ops
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
diff --git a/tensorflow/python/tpu/training_loop.py b/tensorflow/python/tpu/training_loop.py
index cffeb7e915a..06c84e56416 100644
--- a/tensorflow/python/tpu/training_loop.py
+++ b/tensorflow/python/tpu/training_loop.py
@@ -19,12 +19,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.compiler.xla import xla
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.tpu import tensor_tracer
 from tensorflow.python.tpu import tpu_function
-from tensorflow.python.tpu import xla
 
 
 def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
diff --git a/tensorflow/python/tpu/xla.py b/tensorflow/python/tpu/xla.py
deleted file mode 100644
index 58476fae3d1..00000000000
--- a/tensorflow/python/tpu/xla.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-"""XLA utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-
-from tensorflow.python.util import tf_inspect
-
-
-def is_flat(outputs):
-  """Checks if outputs is a flat structure.
-
-    Following structures and values are considered flat:
-    1) None
-    2) A single object
-    3) A list or tuple of Tensors/Operations
-
-    The only structures that this function understands are sequences and
-    dictionaries.  E.g. this means that if outputs contains a single
-    user-defined Object, it is considered to be flat. Errors are raised later on
-    if that Object cannot be converted to a Tensor.
-
-  Args:
-    outputs: Output from `computation` inside `xla.compile`.
-
-  Returns:
-    A boolean indicates whether outputs is flat.
-  """
-  # If outputs is a list or tuple, check if it has any nested structure. If
-  # there is, then outputs is non-flat.
-  if isinstance(outputs, collections.Sequence):
-    for o in outputs:
-      if isinstance(o, collections.Sequence) or isinstance(o, dict):
-        return False
-
-  # If outputs is a dict, it is non-flat.
-  if isinstance(outputs, dict):
-    return False
-
-  # Getting here means either outputs itself is a single non-structured value
-  # or it is a flat list of single non-structured values.
-  return True
-
-
-def check_function_argument_count(func, input_arity, infeed_queue):
-  """Validate the number of input arguments to an XLA function.
-
-  Args:
-    func: the Python function that will be called to generate the body of an XLA
-      computation graph.
-    input_arity: the number of explicit arguments supplied by the caller.
-    infeed_queue: if not None, the infeed queue that will supply
-      additional arguments to the function.
-
-  Returns:
-    None if function can be called with the supplied number of
-      arguments, or an error string if it cannot.
-  """
-  def format_error(complaint, quantity):
-    return '%s %d argument%s' % (complaint, quantity, ''
-                                 if quantity == 1 else 's')
-
-  num_args_supplied = input_arity
-  if infeed_queue is not None:
-    num_args_supplied += infeed_queue.number_of_tuple_elements
-  arg_spec = tf_inspect.getargspec(func)
-  num_func_args = len(arg_spec.args)
-  if arg_spec.defaults is None:
-    num_func_defaults = 0
-  else:
-    num_func_defaults = len(arg_spec.defaults)
-  min_func_args = num_func_args - num_func_defaults
-  if num_args_supplied < min_func_args:
-    # The required number of arguments is not enough to call the function.
-    if num_func_defaults == 0 and arg_spec.varargs is None:
-      return format_error('exactly', num_func_args)
-    else:
-      return format_error('at least', min_func_args)
-  if arg_spec.varargs is None and num_args_supplied > num_func_args:
-    # The required number of arguments is too many to call the function.
-    if num_func_defaults == 0:
-      return format_error('exactly', num_func_args)
-    else:
-      return format_error('at most', num_func_args)
-  # Reaching here means either
-  # 1) There are varargs, func can accept any number of arguments greater than
-  # the minimum.
-  # 2) Number of supplied arguments falls in range of acceptable argument count
-  # of func.
-  return None