From bcfb60d0a138d215980b0881e4619a2d9b20e489 Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Tue, 28 Jul 2020 13:23:04 -0700
Subject: [PATCH] [TF2XLA] [NFC] Break apart the [TF2XLA/MLIR] -> xla_compiler
 dependency edge

This is needed for invoking the MLIR tf2xla bridge from xla_compiler.

This CL breaks items out of xla_compiler into individual build targets,
which the MLIR TF bridge then depends on.

PiperOrigin-RevId: 323640340
Change-Id: I78b972503db9e7b5254014ca7e889005490d8339
---
 tensorflow/compiler/aot/BUILD                 |   2 +
 tensorflow/compiler/jit/BUILD                 |   8 +
 tensorflow/compiler/jit/kernels/BUILD         |   1 +
 tensorflow/compiler/mlir/BUILD                |   1 +
 tensorflow/compiler/mlir/tensorflow/BUILD     |   7 +-
 .../tensorflow/utils/compile_mlir_util.cc     |  33 ++-
 .../mlir/tensorflow/utils/compile_mlir_util.h |  17 +-
 tensorflow/compiler/mlir/xla/BUILD            |  10 +-
 .../compiler/mlir/xla/mlir_hlo_to_hlo.cc      |   7 +-
 .../compiler/mlir/xla/mlir_hlo_to_hlo.h       |   5 +-
 .../xla/transforms/legalize_tf_with_tf2xla.cc |   7 +-
 .../xla/transforms/mhlo_to_lhlo_with_xla.cc   |   2 +
 tensorflow/compiler/tf2xla/BUILD              | 189 +++++++++++++++++-
 tensorflow/compiler/tf2xla/kernels/BUILD      |  14 ++
 tensorflow/compiler/tf2xla/lib/BUILD          |   1 +
 tensorflow/compiler/tf2xla/xla_argument.cc    |  53 +++++
 tensorflow/compiler/tf2xla/xla_argument.h     | 121 +++++++++++
 tensorflow/compiler/tf2xla/xla_compiler.cc    | 122 -----------
 tensorflow/compiler/tf2xla/xla_compiler.h     | 180 +----------------
 tensorflow/compiler/tf2xla/xla_context.cc     |   1 -
 tensorflow/compiler/tf2xla/xla_context.h      |   2 +-
 tensorflow/compiler/tf2xla/xla_expression.cc  |  19 ++
 tensorflow/compiler/tf2xla/xla_expression.h   |   7 +
 tensorflow/compiler/tf2xla/xla_helpers.cc     |  91 ++++++++-
 tensorflow/compiler/tf2xla/xla_helpers.h      |  95 ++++++++-
 tensorflow/compiler/tf2xla/xla_op_kernel.cc   |  45 ++---
 tensorflow/compiler/tf2xla/xla_op_kernel.h    |  10 +-
 tensorflow/compiler/tf2xla/xla_resource.cc    |   1 -
 tensorflow/core/tpu/BUILD                     |   3 +-
 tensorflow/core/tpu/kernels/BUILD             |   2 +
 30 files changed, 668 insertions(+), 388 deletions(-)
 create mode 100644 tensorflow/compiler/tf2xla/xla_argument.cc
 create mode 100644 tensorflow/compiler/tf2xla/xla_argument.h
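
Illustration (commentary only, not applied with the patch): the hunks below
replace the nested XlaCompiler::Argument / OutputDescription / ResourceUpdate /
CompilationResult / ShapeRepresentationFn spellings with the standalone
XlaArgument, XlaOutputDescription, XlaResourceUpdate, XlaCompilationResult and
XlaHelpers::ShapeRepresentationFn types, so the MLIR bridge entry points in
compile_mlir_util.h no longer pull in :xla_compiler. Below is a minimal sketch
of what a caller of CompileGraphToXlaHlo could look like against the new
spellings. The wrapper name CompileExampleGraph, the [2, 2] float parameter and
the "XLA_CPU_JIT" device type are assumptions made for illustration, and the
shape_fn lambda simply inlines the identity shape representation that this
patch removes from xla_compiler.cc.

#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical caller: compiles `graph` to HLO through the MLIR bridge using
// only the standalone types introduced by this patch.
Status CompileExampleGraph(const Graph& graph,
                           const FunctionLibraryDefinition& flib_def,
                           const GraphDebugInfo& debug_info,
                           XlaCompilationResult* result) {
  // A single run-time parameter of shape [2, 2]; XlaArgument replaces the old
  // nested XlaCompiler::Argument.
  std::vector<XlaArgument> args(1);
  args[0].kind = XlaArgument::kParameter;
  args[0].type = DT_FLOAT;
  args[0].shape = TensorShape({2, 2});

  // Identity shape representation, written against the relocated
  // XlaHelpers::ShapeRepresentationFn typedef (same body as the helper this
  // patch deletes from xla_compiler.cc).
  XlaHelpers::ShapeRepresentationFn shape_fn =
      [](const TensorShape& shape, DataType dtype,
         bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
    return xla_shape;
  };

  // CompileGraphToXlaHlo now takes XlaArgument, XlaHelpers and
  // XlaCompilationResult directly; custom_legalization_passes keeps its
  // default empty value.
  return CompileGraphToXlaHlo(
      graph, llvm::ArrayRef<const XlaArgument>(args.data(), args.size()),
      /*device_type=*/"XLA_CPU_JIT", /*use_tuple_args=*/false, flib_def,
      debug_info, shape_fn, result);
}

}  // namespace tensorflow

A translation unit like this only needs the new fine-grained targets
(xla_argument, xla_helpers and the compile_mlir_util library) rather than the
monolithic :xla_compiler target, which is the dependency edge this CL breaks.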

diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index d091146c75a..ff255dd9cc1 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -308,6 +308,8 @@ cc_library(
     ],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_context",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/core:framework",
     ],
     alwayslink = 1,
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index b52a350dc48..ecbb1a5d200 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -95,6 +95,7 @@ cc_library(
         ":xla_kernel_creator",  # buildcleaner: keep
         "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:cpu_plugin",  # buildcleaner: keep
         "//tensorflow/core:core_cpu_internal",
@@ -115,6 +116,7 @@ cc_library(
         ":xla_kernel_creator",  # buildcleaner: keep
         "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:gpu_plugin",  # buildcleaner: keep
         "//tensorflow/core:core_cpu_internal",
@@ -172,6 +174,7 @@ XLA_DEVICE_DEPS = [
     "//tensorflow/compiler/tf2xla:common",
     "//tensorflow/compiler/tf2xla:tf2xla_util",
     "//tensorflow/compiler/tf2xla:xla_compiler",
+    "//tensorflow/compiler/tf2xla:xla_op_registry",
     "//tensorflow/compiler/tf2xla/kernels:xla_ops",
     "//tensorflow/compiler/xla:util",
     "//tensorflow/compiler/xla/client:client_library",
@@ -343,6 +346,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_context",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/client:client_library",
@@ -406,6 +410,7 @@ cc_library(
         ":compilation_passes",
         "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -641,6 +646,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
         "//tensorflow/compiler/tf2xla/cc:xla_ops",
         "//tensorflow/compiler/xla:status_macros",
@@ -700,6 +706,7 @@ cc_library(
     hdrs = ["device_util.h"],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:framework",
@@ -914,6 +921,7 @@ cc_library(
         "//tensorflow/compiler/jit/graphcycles",
         "//tensorflow/compiler/tf2xla:resource_operation_table",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 347bae087df..eb9ad8a2e85 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -21,6 +21,7 @@ XLA_OPS_DEPS = [
     "//tensorflow/compiler/tf2xla:common",
     "//tensorflow/compiler/tf2xla:tf2xla_util",
     "//tensorflow/compiler/tf2xla:xla_compiler",
+    "//tensorflow/compiler/tf2xla:xla_op_registry",
     "//tensorflow/compiler/xla:executable_run_options",
     "//tensorflow/compiler/xla:status_macros",
     "//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 57f923caa91..01c187790b7 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -150,6 +150,7 @@ tf_cc_binary(
         "//tensorflow/compiler/mlir/tensorflow:translate_registration",
         "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
         "//tensorflow/compiler/mlir/xla:xla_mlir_translate",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:tensorflow",
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 2a800cfc8c4..fe1f47d8d69 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1477,10 +1477,13 @@ COMPILE_MLIR_UTIL_DEPS = [
     "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
     "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
     "//tensorflow/compiler/tf2xla:common",
-    "//tensorflow/compiler/tf2xla:xla_compiler",
+    "//tensorflow/compiler/tf2xla:xla_helpers",
+    "//tensorflow/compiler/tf2xla:xla_argument",
+    "//tensorflow/compiler/xla/client:xla_computation",
+    "//tensorflow/core/common_runtime:core_cpu_internal",
+    "//tensorflow/core/platform:logging",
     "//tensorflow/core:framework",
     "//tensorflow/core:protos_all_cc",
-    "//tensorflow/core/platform:logging",
     "//tensorflow/stream_executor/lib",
     "//tensorflow/compiler/xla:xla_data_proto_cc",
     "//tensorflow/compiler/xla/service:hlo",
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 5e548da55f1..16bc851d3a6 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -83,7 +83,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
 Status GetXlaInputShapes(
     mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
     bool use_tuple_args,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     std::vector<xla::Shape>* xla_input_shapes) {
   xla_input_shapes->clear();
 
@@ -135,9 +135,8 @@ Status GetXlaInputShapes(
 // output based on static shapes in MLIR module
 Status GetOutputInfo(
     mlir::ModuleOp module,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    xla::Shape* xla_output_shape,
-    std::vector<XlaCompiler::OutputDescription>* outputs) {
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs) {
   auto shape_representation_fn_no_fast_memory =
       [shape_representation_fn](const TensorShape& shape, DataType dtype) {
         return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
@@ -161,7 +160,7 @@ Status GetOutputInfo(
 
     // Construct OutputDescription for result.
     outputs->emplace_back();
-    XlaCompiler::OutputDescription& out_desc = outputs->back();
+    XlaOutputDescription& out_desc = outputs->back();
     TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type));
     // TODO(ycao): Support constant output.
     out_desc.is_constant = false;
@@ -185,7 +184,7 @@ Status GetOutputInfo(
 // TODO(ycao): Implement logic to compute resource updates when we need to
 // support graphs with resource updates in MLIR-based TF compiler bridge.
 void GetResourceUpdatesForMlir(
-    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
+    std::vector<XlaResourceUpdate>* resource_updates) {
   resource_updates->clear();
 }
 
@@ -265,7 +264,7 @@ Status ConvertMLIRToXlaComputation(
     mlir::ModuleOp module_op, llvm::StringRef device_type,
     xla::XlaComputation* xla_computation, bool use_tuple_args,
     bool return_tuple,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   mlir::PassManager tf2xla(module_op.getContext());
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
@@ -341,8 +340,8 @@ Status ConvertMLIRToXlaComputation(
 static Status CompileMlirToXlaHlo(
     mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   if (VLOG_IS_ON(1))
     tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
@@ -391,8 +390,8 @@ static Status CompileMlirToXlaHlo(
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   RegisterDialects();
   mlir::MLIRContext mlir_context;
@@ -411,16 +410,16 @@ Status CompileSerializedMlirToXlaHlo(
 // removed from the signature.
 // Returns the original indices for the other arguments on success.
 static StatusOr<std::vector<int>> RewriteWithArgs(
-    mlir::ModuleOp module, llvm::ArrayRef<const XlaCompiler::Argument> args) {
+    mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) {
   mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
   std::vector<int> params;
 
   auto builder = mlir::OpBuilder(main_fn.getBody());
   std::vector<int> args_to_erase;
   for (int idx = 0; idx < args.size(); idx++) {
-    const XlaCompiler::Argument& xla_arg = args[idx];
+    const XlaArgument& xla_arg = args[idx];
     mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
-    if (xla_arg.kind != XlaCompiler::Argument::kConstant) {
+    if (xla_arg.kind != XlaArgument::kConstant) {
       params.push_back(idx);
       continue;
     }
@@ -439,11 +438,11 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
 }
 
 Status CompileGraphToXlaHlo(
-    const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
+    const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
     llvm::StringRef device_type, bool use_tuple_args,
     const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   RegisterDialects();
 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
index 24b60dcb346..719a96f52d4 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
@@ -20,7 +20,10 @@ limitations under the License.
 #include "llvm/ADT/StringRef.h"
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_argument.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
@@ -57,7 +60,7 @@ Status ConvertMLIRToXlaComputation(
     mlir::ModuleOp module_op, llvm::StringRef device_type,
     xla::XlaComputation* xla_computation, bool use_tuple_args,
     bool return_tuple,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
 
 // Compiles a serialized MLIR module into XLA HLO, generates all accompanying
@@ -65,17 +68,17 @@ Status ConvertMLIRToXlaComputation(
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
 
 // Same as the above but takes input as TensorFlow Graph.
 Status CompileGraphToXlaHlo(
-    const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
+    const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
     llvm::StringRef device_type, bool use_tuple_args,
     const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 838b060079c..55daec0395e 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -92,7 +92,11 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
         "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
         "//tensorflow/compiler/mlir/tensorflow:translate_utils",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_compilation_device",
+        "//tensorflow/compiler/tf2xla:xla_context",
+        "//tensorflow/compiler/tf2xla:xla_expression",
+        "//tensorflow/compiler/tf2xla:xla_helpers",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:framework",
@@ -125,8 +129,10 @@ cc_library(
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
         "//tensorflow/compiler/mlir/hlo:lhlo",
+        "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:backend",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
@@ -228,7 +234,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:convert_type",
         "//tensorflow/compiler/mlir/tensorflow:error_util",
         "//tensorflow/compiler/tf2xla:common",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_helpers",
         "//tensorflow/compiler/xla:comparison_util",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index a4c3c43cfbf..e45cf1b56ee 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -43,7 +43,6 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/lib/quantize.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -463,7 +462,7 @@ class ConvertToHloModule {
   // single value.
   explicit ConvertToHloModule(
       mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
-      tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
       : module_(module),
         module_builder_("main"),
         use_tuple_args_(use_tuple_args),
@@ -545,7 +544,7 @@ class ConvertToHloModule {
 
   // Shape representation function to determine entry function argument and
   // result shapes.
-  tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+  tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;
 
   // Unique suffix to give to the name of the next lowered region.
   size_t region_id_ = 0;
@@ -1500,7 +1499,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
 
 Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
                            bool use_tuple_args, bool return_tuple,
-                           const tensorflow::XlaCompiler::ShapeRepresentationFn
+                           const tensorflow::XlaHelpers::ShapeRepresentationFn
                                shape_representation_fn) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
   ConvertToHloModule converter(module, use_tuple_args, return_tuple,
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
index 8bfe4c76b04..d84aa92d3e2 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
@@ -18,9 +18,10 @@ limitations under the License.
 
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 
 namespace mlir {
 
@@ -33,7 +34,7 @@ namespace mlir {
 // single value.
 Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                            bool use_tuple_args, bool return_tuple,
-                           const tensorflow::XlaCompiler::ShapeRepresentationFn
+                           const tensorflow::XlaHelpers::ShapeRepresentationFn
                                shape_representation_fn = nullptr);
 
 // Creates XlaOp equivalent of a given MLIR operation using the operand info
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 34e12d3300e..1743ae7be17 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -48,7 +48,8 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_expression.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -410,7 +411,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
         device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
         shape_or.ValueOrDie());
     tensorflow::Tensor& tensor = tensors.back();
-    tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor);
+    tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
     inputs.emplace_back(&tensor);
   }
 
@@ -438,7 +439,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
   for (int i = 0, e = op_->getNumResults(); i < e; i++) {
     tensorflow::Tensor* output = op_context.mutable_output(i);
     const tensorflow::XlaExpression* expr =
-        tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output);
+        tensorflow::XlaExpression::CastExpressionFromTensor(*output);
     if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
       return op_->emitError(
           "expects XlaExpression of kind kXlaOp in compiled output");
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 519068893e7..d45f1ba8ec6 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -37,6 +37,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
+#include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 663e34c2b8e..1e57c11b2cf 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -50,6 +50,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":xla_compiler",
+        ":xla_op_registry",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -145,6 +146,7 @@ cc_library(
         ":tf2xla_proto_cc",
         ":tf2xla_util",
         ":xla_compiler",
+        ":xla_op_registry",
         "//tensorflow/compiler/aot:aot_only_var_handle_op",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/client",
@@ -316,14 +318,8 @@ cc_library(
     srcs = [
         "const_analysis.cc",
         "graph_compiler.cc",
-        "xla_compilation_device.cc",
         "xla_compiler.cc",
-        "xla_context.cc",
-        "xla_expression.cc",
-        "xla_helpers.cc",
         "xla_op_kernel.cc",
-        "xla_op_registry.cc",
-        "xla_resource.cc",
         "xla_cpu_backend.cc",
     ] + if_cuda_is_configured([
         "xla_gpu_backend.cc",
@@ -333,14 +329,10 @@ cc_library(
     hdrs = [
         "const_analysis.h",
         "graph_compiler.h",
-        "xla_compilation_device.h",
         "xla_compiler.h",
-        "xla_context.h",
-        "xla_expression.h",
         "xla_helpers.h",
         "xla_op_kernel.h",
         "xla_op_registry.h",
-        "xla_resource.h",
     ],
     visibility = [":friends"],
     deps = [
@@ -351,10 +343,18 @@ cc_library(
         ":sharding_util",
         ":side_effect_util",
         ":tf2xla_util",
+        ":xla_argument",
+        ":xla_compilation_device",
+        ":xla_context",
+        ":xla_expression",
+        ":xla_helpers",
+        ":xla_op_registry",
+        ":xla_resource",
         "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/jit:xla_cluster_util",
+        "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
         "//tensorflow/compiler/tf2xla/lib:util",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
@@ -370,6 +370,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
@@ -388,6 +389,172 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "xla_compilation_device",
+    srcs = [
+        "xla_compilation_device.cc",
+    ],
+    hdrs = [
+        "xla_compilation_device.h",
+    ],
+    deps = [
+        ":common",
+        ":frontend_attributes_util",
+        ":sharding_util",
+        ":xla_context",
+        ":xla_helpers",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:session_options",
+        "//tensorflow/core/common_runtime:core_cpu_internal",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_context",
+    srcs = [
+        "xla_context.cc",
+    ],
+    hdrs = [
+        "xla_context.h",
+    ],
+    deps = [
+        ":common",
+        ":xla_expression",
+        ":xla_helpers",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/common_runtime:core_cpu_internal",
+        "@com_google_absl//absl/types:span",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_op_registry",
+    srcs = [
+        "xla_op_registry.cc",
+    ],
+    hdrs = [
+        "xla_op_registry.h",
+    ],
+    visibility = [":friends"],
+    deps = [
+        ":common",
+        ":xla_context",
+        "//tensorflow/compiler/jit:flags",
+        "//tensorflow/compiler/jit:xla_cluster_util",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:session_options",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/common_runtime:core_cpu_internal",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_expression",
+    srcs = [
+        "xla_expression.cc",
+    ],
+    hdrs = [
+        "xla_expression.h",
+    ],
+    deps = [
+        ":common",
+        ":xla_resource",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/types:optional",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_resource",
+    srcs = [
+        "xla_resource.cc",
+    ],
+    hdrs = [
+        "xla_resource.h",
+    ],
+    deps = [
+        ":common",
+        ":sharding_util",
+        ":xla_helpers",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_helpers",
+    srcs = [
+        "xla_helpers.cc",
+    ],
+    hdrs = [
+        "xla_helpers.h",
+    ],
+    visibility = [":friends"],
+    deps = [
+        ":common",
+        ":host_compute_metadata_proto_cc",
+        "//tensorflow/compiler/tf2xla/lib:util",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/types:span",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_argument",
+    srcs = [
+        "xla_argument.cc",
+    ],
+    hdrs = [
+        "xla_argument.h",
+    ],
+    deps = [
+        ":host_compute_metadata_proto_cc",
+        ":xla_resource",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/core:framework",
+        "@com_google_absl//absl/types:span",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "common",
     srcs = [
@@ -564,6 +731,8 @@ tf_cc_test(
         ":common",
         ":side_effect_util",
         ":xla_compiler",
+        ":xla_expression",
+        ":xla_resource",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
         "//tensorflow/cc:functional_ops",
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index ec0cb9c0b66..26051c98cb7 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -145,7 +145,12 @@ tf_kernel_library(
         "//tensorflow/compiler/jit:xla_activity_listener",
         "//tensorflow/compiler/jit:xla_activity_proto_cc",
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compilation_device",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_context",
+        "//tensorflow/compiler/tf2xla:xla_helpers",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
+        "//tensorflow/compiler/tf2xla:xla_resource",
         "//tensorflow/compiler/tf2xla/lib:broadcast",
         "//tensorflow/compiler/tf2xla/lib:data_format",
         "//tensorflow/compiler/tf2xla/lib:random",
@@ -223,6 +228,8 @@ cc_library(
     deps = [
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_helpers",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
@@ -276,6 +283,8 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_helpers",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:status_macros",
@@ -296,6 +305,8 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_context",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla/client:xla_builder",
@@ -314,6 +325,8 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_context",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla/client:xla_builder",
@@ -333,6 +346,7 @@ tf_kernel_library(
     ],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/core:array_ops_op_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index f0bd97c85eb..531679d3905 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -38,6 +38,7 @@ cc_library(
     hdrs = ["random.h"],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_helpers",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:xla_builder",
diff --git a/tensorflow/compiler/tf2xla/xla_argument.cc b/tensorflow/compiler/tf2xla/xla_argument.cc
new file mode 100644
index 00000000000..fe31025386e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_argument.cc
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_argument.h"
+
+namespace tensorflow {
+
+bool XlaArgument::operator==(const XlaArgument& other) const {
+  if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
+               tensor_array_gradients) !=
+      std::tie(other.kind, other.resource_kind, other.type, other.name,
+               other.initialized, other.max_array_size,
+               other.tensor_array_gradients)) {
+    return false;
+  }
+  if (absl::holds_alternative<xla::Shape>(shape)) {
+    if (!absl::holds_alternative<xla::Shape>(other.shape)) {
+      return false;
+    }
+    if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape),
+                             absl::get<xla::Shape>(other.shape))) {
+      return false;
+    }
+  } else {
+    if (!absl::holds_alternative<TensorShape>(other.shape)) {
+      return false;
+    }
+    if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) {
+      return false;
+    }
+  }
+  if (constant_value.shape() != other.constant_value.shape()) {
+    return false;
+  }
+  if (is_same_data_across_replicas != other.is_same_data_across_replicas) {
+    return false;
+  }
+  return constant_value.tensor_data() == other.constant_value.tensor_data();
+}
+
+}  // end namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h
new file mode 100644
index 00000000000..e2cd634e1d5
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_argument.h
@@ -0,0 +1,121 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
+#include "tensorflow/compiler/tf2xla/xla_resource.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+// Describes how to derive the value of each _Arg node in the graph/function
+// being compiled. There must be one Argument for each _Arg index.
+struct XlaArgument {
+  enum Kind {
+    // Default value; not a valid kind.
+    kInvalid,
+
+    // Argument is a compile-time constant. No associated runtime parameter.
+    kConstant,
+
+    // Argument is a Variable, TensorArray, or Stack resource. Has an
+    // associated runtime parameter iff `initialized` is true.
+    kResource,
+
+    // Argument is a run-time parameter.
+    kParameter,
+
+    // Argument is an XLA token.
+    kToken,
+
+    // Argument is a TensorList.
+    kTensorList,
+  };
+
+  Kind kind = kInvalid;
+
+  // The type of the argument. If the argument is a resource, this
+  // is the type of the variable's value, not DT_RESOURCE.
+  DataType type = DT_INVALID;
+
+  // The shape of the argument. For:
+  // * a parameter: the shape of the parameter. We allow setting the xla shape
+  //   if known. This helps avoid conversions to and from TensorShape.
+  // * a constant: ignored; the shape given by constant_value is used
+  //     instead.
+  // * an uninitialized resource: ignored. We don't yet know the shape of an
+  //     uninitialized resource (otherwise we would have initialized it!)
+  // * an initialized variable: the shape of the variable's value.
+  // * an initialized TensorArray or Stack resource: the shape of an entry in
+  //   the TensorArray/Stack. Note this is the size of a single entry, not the
+  //   XLA data structure that represents the complete stack/array.
+  absl::variant<TensorShape, xla::Shape> shape;
+
+  // The value of the argument, if it is a compile-time constant. Must be a
+  // host-memory tensor.
+  Tensor constant_value;
+
+  // The name of this argument, used for debugging.
+  string name;
+
+  // The name of TensorFlow _Arg node, used for debugging.
+  string node_name;
+
+  // For a kResource, what kind of resource is it?
+  XlaResource::Kind resource_kind = XlaResource::kInvalid;
+
+  // For a kResource, has this resource been initialized?
+  bool initialized = false;
+
+  // For a kResource, is this resource on Fast Memory.
+  bool fast_mem = false;
+
+  // For a TensorArray or Stack resource, what is the array's declared size?
+  // (Used for lazy initialization.)
+  int64 max_array_size = -1;
+
+  // TensorArray resource parameters are passed as (array, gradient array 0,
+  // ..., gradient array k), where the gradient arrays are in the same order
+  // as `tensor_array_gradients`.
+  std::set<string> tensor_array_gradients;
+
+  // dynamic dims to arg number map. Empty if no dynamic shapes.
+  std::map<int32, int32> dynamic_dim_to_arg_num_map;
+  bool is_pad_arg = false;
+
+  // Whether this argument will receive the same data across all replicas.
+  bool is_same_data_across_replicas = false;
+
+  bool operator==(const XlaArgument& other) const;
+
+  // Returns a human-readable summary of the argument.
+  string HumanString() const;
+
+  // Returns the dimension sizes for either TensorShape or xla::Shape.
+  std::vector<int64> DimensionSizes() const;
+  absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const;
+
+  // Returns the human-readable string for either TensorShape or xla::Shape.
+  string ShapeHumanString() const;
+};
+
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 0722c30787f..db54f2f6563 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -422,39 +422,6 @@ Status BuildComputation(
 
 }  // namespace
 
-bool XlaCompiler::Argument::operator==(
-    const XlaCompiler::Argument& other) const {
-  if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
-               tensor_array_gradients) !=
-      std::tie(other.kind, other.resource_kind, other.type, other.name,
-               other.initialized, other.max_array_size,
-               other.tensor_array_gradients)) {
-    return false;
-  }
-  if (absl::holds_alternative<xla::Shape>(shape)) {
-    if (!absl::holds_alternative<xla::Shape>(other.shape)) {
-      return false;
-    }
-    if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape),
-                             absl::get<xla::Shape>(other.shape))) {
-      return false;
-    }
-  } else {
-    if (!absl::holds_alternative<TensorShape>(other.shape)) {
-      return false;
-    }
-    if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) {
-      return false;
-    }
-  }
-  if (constant_value.shape() != other.constant_value.shape()) {
-    return false;
-  }
-  if (is_same_data_across_replicas != other.is_same_data_across_replicas) {
-    return false;
-  }
-  return constant_value.tensor_data() == other.constant_value.tensor_data();
-}
 
 string XlaCompiler::Argument::HumanString() const {
   string common;
@@ -1494,93 +1461,4 @@ xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
   return iter->second;
 }
 
-XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() {
-  return [](const TensorShape& shape, DataType dtype,
-            bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
-    xla::Shape xla_shape;
-    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
-    return xla_shape;
-  };
-}
-
-// Rewrites the layout of xla_shape if there is tiled sharding.
-Status RewriteLayoutWithShardedShape(
-    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    xla::Shape* xla_shape) {
-  if (sharding && !sharding->IsTileMaximal()) {
-    // After sharding, per core shape might have different layout. For example,
-    // before sharding, a shape [128, 128] will be assigned default
-    // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
-    // the sharded shapes will have minor-to-major {0, 1}.
-    //
-    // As a result, for sharded shapes, we set their layout to per core shape's
-    // layout.
-    //
-    // TODO(endlessroad): for variable input & update, we might have
-    // different layouts which will prevent input output aliasing and
-    // increase memory usage. Investigate such cases.
-    int64 device = *sharding->tile_assignment().begin();
-    std::vector<int64> offset =
-        sharding->TileOffsetForDevice(*xla_shape, device);
-    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
-    std::vector<int64> dimensions(xla_shape->rank());
-    for (int64 i = 0; i < xla_shape->rank(); ++i) {
-      dimensions[i] = limit[i] - offset[i];
-    }
-    xla::Shape per_device_xla_shape =
-        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
-    TensorShape per_device_tensor_shape;
-    TF_RETURN_IF_ERROR(
-        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
-    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
-                                            xla_shape->element_type()));
-    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
-                        shape_representation_fn(per_device_tensor_shape, dtype,
-                                                use_fast_memory));
-    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
-  }
-  return Status::OK();
-}
-
-// There is a shape_representation_fn or sharding for an output, this function
-// uses a reshape to fix the layout.
-xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
-    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
-  if (original_shape.IsTuple()) {
-    std::vector<xla::XlaOp> elements;
-    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
-      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
-      TF_ASSIGN_OR_RETURN(auto element,
-                          ReshapeWithCorrectRepresentationAndSharding(
-                              builder, xla::GetTupleElement(original, i),
-                              original_shape.tuple_shapes(i),
-                              shape_representation_fn, subsharding, fast_mem));
-      elements.push_back(element);
-    }
-    return xla::Tuple(builder, elements);
-  }
-  if (!original_shape.IsArray()) return original;
-  TensorShape shape;
-  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
-  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
-                                          original_shape.element_type()));
-  TF_ASSIGN_OR_RETURN(auto to_shape,
-                      shape_representation_fn(shape, dtype, fast_mem));
-  if (sharding) {
-    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
-                        xla::HloSharding::FromProto(*sharding));
-    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
-        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
-  }
-  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
-    for (int64 i = 0; i < original_shape.rank(); ++i) {
-      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
-    }
-  }
-  return xla::Reshape(to_shape, original);
-}
-
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index b95d250636a..b0d93cde846 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -21,8 +21,10 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "absl/types/variant.h"
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
+#include "tensorflow/compiler/tf2xla/xla_argument.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_expression.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -97,96 +99,7 @@ class XlaContext;
 // `tensor_array_gradients` ordered set.
 class XlaCompiler {
  public:
-  // Describes how to derive the value of each _Arg node in the graph/function
-  // being compiled. There must be one Argument for each _Arg index.
-  struct Argument {
-    enum Kind {
-      // Default value; not a valid kind.
-      kInvalid,
-
-      // Argument is a compile-time constant. No associated runtime parameter.
-      kConstant,
-
-      // Argument is a Variable, TensorArray, or Stack resource. Has an
-      // associated runtime parameter iff `initialized` is true.
-      kResource,
-
-      // Argument is a run-time parameter.
-      kParameter,
-
-      // Argument is an XLA token.
-      kToken,
-
-      // Argument is a TensorList.
-      kTensorList,
-    };
-
-    Kind kind = kInvalid;
-
-    // The type of the argument. If the argument is a resource, this
-    // is the type of the variable's value, not DT_RESOURCE.
-    DataType type = DT_INVALID;
-
-    // The shape of the argument. For:
-    // * a parameter: the shape of the parameter. We allow setting the xla shape
-    //   if known. This helps avoid conversions to and from TensorShape.
-    // * a constant: ignored; the shape given by constant_value is used
-    //     instead.
-    // * an uninitialized resource: ignored. We don't yet know the shape of an
-    //     uninitialized resource (otherwise we would have initialized it!)
-    // * an initialized variable: the shape of the variable's value.
-    // * an initialized TensorArray or Stack resource: the shape of an entry in
-    //   the TensorArray/Stack. Note this is the size of a single entry, not the
-    //   XLA data structure that represents the complete stack/array.
-    absl::variant<TensorShape, xla::Shape> shape;
-
-    // The value of the argument, if it is a compile-time constant. Must be a
-    // host-memory tensor.
-    Tensor constant_value;
-
-    // The name of this argument, used for debugging.
-    string name;
-
-    // The name of TensorFlow _Arg node, used for debugging.
-    string node_name;
-
-    // For a kResource, what kind of resource is it?
-    XlaResource::Kind resource_kind = XlaResource::kInvalid;
-
-    // For a kResource, has this resource been initialized?
-    bool initialized = false;
-
-    // For a kResource, is this resource on Fast Memory.
-    bool fast_mem = false;
-
-    // For a TensorArray or Stack resource, what is the array's declared size?
-    // (Used for lazy initialization.)
-    int64 max_array_size = -1;
-
-    // TensorArray resource parameters are passed as (array, gradient array 0,
-    // ..., gradient array k), where the gradient arrays are in the same order
-    // as `tensor_array_gradients`.
-    std::set<string> tensor_array_gradients;
-
-    // dynamic dims to arg number map. Empty if no dynamic shapes.
-    std::map<int32, int32> dynamic_dim_to_arg_num_map;
-    bool is_pad_arg = false;
-
-    // Whether this argument will receive the same data across all replicas.
-    bool is_same_data_across_replicas = false;
-
-    bool operator==(const Argument& other) const;
-
-    // Returns a human-readable summary of the argument.
-    string HumanString() const;
-
-    // Returns the dimension sizes for either TensorShape or xla::Shape.
-    std::vector<int64> DimensionSizes() const;
-    absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const;
-
-    // Returns the human-readable string for either TensorShape or xla::Shape.
-    string ShapeHumanString() const;
-  };
+  using Argument = ::tensorflow::XlaArgument;
 
   // Options pertaining to an individual call to CompileGraph() or
   // CompileFunction().
@@ -221,77 +134,11 @@ class XlaCompiler {
     bool alias_resource_update = false;
   };
 
-  struct OutputDescription {
-    // Type and shape of the output. The shape is the unflattened shape.
-    // When `type` is DT_RESOURCE, `shape` is the shape of the resource
-    // variable's value.
-    DataType type;
-    TensorShape shape;
+  using OutputDescription = ::tensorflow::XlaOutputDescription;
 
-    // Constant output value, if known to be constant at JIT compilation time.
-    // 'Tensor' is in host memory.
-    bool is_constant = false;
-    Tensor constant_value;
+  using ResourceUpdate = ::tensorflow::XlaResourceUpdate;
 
-    // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
-    // the index of the input that contains the resource.
-    int input_index;
-
-    // Whether this output is a TensorList.
-    bool is_tensor_list = false;
-  };
-
-  // Describes a variable write side effect of the computation.
-  struct ResourceUpdate {
-    // Index of the input that contains the variable resource to write to.
-    int input_index;
-
-    // Type and shape of the tensor to be written back.
-    // The `shape` field has the same meaning as the Argument::shape field.
-    DataType type;
-    TensorShape shape;
-
-    // Was the value of the variable modified by the computation?
-    // (Always true, unless `return_updated_values_for_all_resources` is true.)
-    bool modified;
-
-    // If the resource is a TensorArray, the set of gradients read or written.
-    std::set<string> tensor_array_gradients_accessed;
-  };
-
-  struct CompilationResult {
-    // Vector that maps from the parameters of the XLA computation to their
-    // original argument positions. To handle compile-time constant inputs, the
-    // parameters to the XLA computation may be a subset of the original
-    // arguments. The relative ordering of parameters are maintained.
-    std::vector<int> input_mapping;
-
-    // Input shapes of the computation. If we are flattening inputs, these are
-    // the flattened shapes.
-    std::vector<xla::Shape> xla_input_shapes;
-
-    // Output shape in XLA format. The output shape is always a tuple. If we
-    // are flattening outputs, these are the flattened shapes.
-    xla::Shape xla_output_shape;
-
-    // TensorFlow shapes of outputs, together with the values of any
-    // constant arguments. Vector indexed by Tensorflow _Retval number,
-    // containing both constant and non-constant results.
-    std::vector<OutputDescription> outputs;
-
-    // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
-    // matching RecvAtHost/SendFromHost Ops in the outer graph.
-    tf2xla::HostComputeMetadata host_compute_metadata;
-
-    // Resources whose values were updated by the computation, ordered
-    // by return value position (which is the same as the order the resources
-    // were passed as arguments). Resource updates follow the non-constant
-    // results in the outputs of XLA computation.
-    std::vector<ResourceUpdate> resource_updates;
-
-    // The XLA computation built from the tensorflow subgraph.
-    std::shared_ptr<xla::XlaComputation> computation;
-  };
+  using CompilationResult = ::tensorflow::XlaCompilationResult;
 
   typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
                                                   bool)>
@@ -518,21 +365,6 @@ class XlaCompiler {
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
 };
 
-// Creates an identity shape representation function.
-XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn();
-
-// Rewrites the layout of xla_shape if there is tiled sharding.
-Status RewriteLayoutWithShardedShape(
-    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    xla::Shape* xla_shape);
-
-// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
-// sharding is present.
-xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
-    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    absl::optional<xla::OpSharding> sharding, bool fast_mem);
 
 }  // namespace tensorflow
 
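For context, a minimal sketch (not part of the patch) of why the type aliases above keep existing call sites source-compatible. `InspectResult` is a hypothetical function; the nested names resolve to the standalone structs introduced in xla_helpers.h, and the usual framework headers are assumed to be available transitively.

// Hypothetical caller, assuming the aliases added to xla_compiler.h above.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"

namespace tensorflow {

void InspectResult(const XlaCompiler::CompilationResult& result) {
  // XlaCompiler::CompilationResult is now an alias for XlaCompilationResult,
  // so both spellings name the same type.
  const XlaCompilationResult& same = result;
  for (const XlaOutputDescription& out : same.outputs) {
    VLOG(1) << DataTypeString(out.type) << " " << out.shape.DebugString();
  }
}

}  // namespace tensorflow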
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index c94c4805d53..cb5bf34208f 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -24,7 +24,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index eb4ad3fe6a1..e44ac05b702 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -20,7 +20,6 @@ limitations under the License.
 
 #include <vector>
 
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_expression.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -33,6 +32,7 @@ limitations under the License.
 namespace tensorflow {
 
 class XlaOpKernelContext;
+class XlaCompiler;
 
 // The XlaContext is the data structure that holds the state of an XLA
 // compilation, that is accessible from OpKernelContexts when compiling a
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index 49f108ed6c8..34e108bb6bf 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -163,4 +163,23 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
   }
 }
 
+const XlaExpression* XlaExpression::CastExpressionFromTensor(
+    const Tensor& tensor) {
+  const XlaExpression* expression =
+      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
+  CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
+      << expression->HumanString();
+  return expression;
+}
+
+// Assigns an XlaExpression to a tensor on an XLA compilation device.
+void XlaExpression::AssignExpressionToTensor(const XlaExpression& value,
+                                             Tensor* tensor) {
+  const XlaExpression* expression =
+      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
+  CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
+      << expression->HumanString();
+  *const_cast<XlaExpression*>(expression) = value;
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
index 5d0bb35b182..3010964c5b7 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.h
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -104,6 +104,13 @@ class XlaExpression {
   // not the shape of the resource's value.
   xla::StatusOr<TensorShape> GetShape() const;
 
+  // Retrieves an XlaExpression that was allocated by a previous Op.
+  static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
+
+  // Assigns an XlaExpression to a tensor on an XLA compilation device.
+  static void AssignExpressionToTensor(const XlaExpression& value,
+                                       Tensor* tensor);
+
  private:
   Kind kind_ = Kind::kInvalid;
 
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 74247bbaec7..8c4b55aec8a 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -22,8 +22,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/tf2xla/xla_context.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -128,4 +126,93 @@ xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
   return xla::ConvertElementType(operand, convert_to);
 }
 
+XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
+  return [](const TensorShape& shape, DataType dtype,
+            bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
+    xla::Shape xla_shape;
+    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
+    return xla_shape;
+  };
+}
+
+// Rewrites the layout of xla_shape if there is tiled sharding.
+Status RewriteLayoutWithShardedShape(
+    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_shape) {
+  if (sharding && !sharding->IsTileMaximal()) {
+    // After sharding, the per-core shape might have a different layout. For
+    // example, before sharding, a shape [128, 128] is assigned the default
+    // minor-to-major {1, 0}; but after sharding it into two [128, 64] pieces,
+    // the sharded shapes have minor-to-major {0, 1}.
+    //
+    // As a result, for sharded shapes, we set their layout to the per-core
+    // shape's layout.
+    //
+    // TODO(endlessroad): for variable input & update, we might have
+    // different layouts which will prevent input output aliasing and
+    // increase memory usage. Investigate such cases.
+    int64 device = *sharding->tile_assignment().begin();
+    std::vector<int64> offset =
+        sharding->TileOffsetForDevice(*xla_shape, device);
+    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
+    std::vector<int64> dimensions(xla_shape->rank());
+    for (int64 i = 0; i < xla_shape->rank(); ++i) {
+      dimensions[i] = limit[i] - offset[i];
+    }
+    xla::Shape per_device_xla_shape =
+        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
+    TensorShape per_device_tensor_shape;
+    TF_RETURN_IF_ERROR(
+        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
+    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                            xla_shape->element_type()));
+    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
+                        shape_representation_fn(per_device_tensor_shape, dtype,
+                                                use_fast_memory));
+    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
+  }
+  return Status::OK();
+}
+
+// If there is a shape_representation_fn or sharding for an output, this
+// function uses a reshape to fix the layout.
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
+  if (original_shape.IsTuple()) {
+    std::vector<xla::XlaOp> elements;
+    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
+      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
+      TF_ASSIGN_OR_RETURN(auto element,
+                          ReshapeWithCorrectRepresentationAndSharding(
+                              builder, xla::GetTupleElement(original, i),
+                              original_shape.tuple_shapes(i),
+                              shape_representation_fn, subsharding, fast_mem));
+      elements.push_back(element);
+    }
+    return xla::Tuple(builder, elements);
+  }
+  if (!original_shape.IsArray()) return original;
+  TensorShape shape;
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
+  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                          original_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(auto to_shape,
+                      shape_representation_fn(shape, dtype, fast_mem));
+  if (sharding) {
+    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+                        xla::HloSharding::FromProto(*sharding));
+    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
+        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
+  }
+  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
+    for (int64 i = 0; i < original_shape.rank(); ++i) {
+      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
+    }
+  }
+  return xla::Reshape(to_shape, original);
+}
+
 }  // end namespace tensorflow
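A small usage sketch for the relocated helpers (`DefaultShapeFor` is hypothetical and shown only to illustrate the ShapeRepresentationFn signature):

#include "tensorflow/compiler/tf2xla/xla_helpers.h"

namespace tensorflow {

// Maps a TensorShape/DataType pair to the default XLA shape by invoking the
// identity shape representation function with fast memory disabled.
xla::StatusOr<xla::Shape> DefaultShapeFor(const TensorShape& shape,
                                          DataType dtype) {
  XlaHelpers::ShapeRepresentationFn fn = IdentityShapeRepresentationFn();
  return fn(shape, dtype, /*use_fast_memory=*/false);
}

}  // namespace tensorflow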
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index 490923526bd..3a9375ec1f4 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -19,8 +19,9 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
 
 #include "absl/types/span.h"
-#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
 #include "tensorflow/core/framework/tensor.h"
 
 namespace tensorflow {
@@ -72,6 +73,98 @@ class XlaHelpers {
   // than the xla::PrimitiveType.
   static xla::XlaOp ConvertElementType(const xla::XlaOp& operand,
                                        const DataType new_element_type);
+
+  typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
+                                                  bool)>
+      ShapeRepresentationFn;
+};
+
+// Creates an identity shape representation function.
+XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn();
+
+// Rewrites the layout of xla_shape if there is tiled sharding.
+Status RewriteLayoutWithShardedShape(
+    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_shape);
+
+// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
+// sharding is present.
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    absl::optional<xla::OpSharding> sharding, bool fast_mem);
+
+struct XlaOutputDescription {
+  // Type and shape of the output. The shape is the unflattened shape.
+  // When `type` is DT_RESOURCE, `shape` is the shape of the resource
+  // variable's value.
+  DataType type;
+  TensorShape shape;
+
+  // Constant output value, if known to be constant at JIT compilation time.
+  // 'Tensor' is in host memory.
+  bool is_constant = false;
+  Tensor constant_value;
+
+  // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
+  // the index of the input that contains the resource.
+  int input_index;
+
+  // Whether this output is a TensorList.
+  bool is_tensor_list = false;
+};
+
+// Describes a variable write side effect of the computation.
+struct XlaResourceUpdate {
+  // Index of the input that contains the variable resource to write to.
+  int input_index;
+
+  // Type and shape of the tensor to be written back.
+  // The `shape` field has the same meaning as the XlaArgument::shape field.
+  DataType type;
+  TensorShape shape;
+
+  // Was the value of the variable modified by the computation?
+  // (Always true, unless `return_updated_values_for_all_resources` is true.)
+  bool modified;
+
+  // If the resource is a TensorArray, the set of gradients read or written.
+  std::set<string> tensor_array_gradients_accessed;
+};
+
+struct XlaCompilationResult {
+  // Vector that maps from the parameters of the XLA computation to their
+  // original argument positions. To handle compile-time constant inputs, the
+  // parameters to the XLA computation may be a subset of the original
+  // arguments. The relative ordering of parameters is maintained.
+  std::vector<int> input_mapping;
+
+  // Input shapes of the computation. If we are flattening inputs, these are
+  // the flattened shapes.
+  std::vector<xla::Shape> xla_input_shapes;
+
+  // Output shape in XLA format. The output shape is always a tuple. If we
+  // are flattening outputs, these are the flattened shapes.
+  xla::Shape xla_output_shape;
+
+  // TensorFlow shapes of outputs, together with the values of any
+  // constant arguments. Vector indexed by TensorFlow _Retval number,
+  // containing both constant and non-constant results.
+  std::vector<XlaOutputDescription> outputs;
+
+  // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
+  // matching RecvAtHost/SendFromHost Ops in the outer graph.
+  tf2xla::HostComputeMetadata host_compute_metadata;
+
+  // Resources whose values were updated by the computation, ordered
+  // by return value position (which is the same as the order the resources
+  // were passed as arguments). Resource updates follow the non-constant
+  // results in the outputs of the XLA computation.
+  std::vector<XlaResourceUpdate> resource_updates;
+
+  // The XLA computation built from the tensorflow subgraph.
+  std::shared_ptr<xla::XlaComputation> computation;
 };
 
 }  // end namespace tensorflow
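To show how the free functions declared above compose (sketch only; `FixOutputLayout` is hypothetical and assumes `original` was built on `builder`):

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"

namespace tensorflow {

// Reshapes one output to the representation chosen by the identity
// ShapeRepresentationFn, with no sharding and regular memory.
xla::StatusOr<xla::XlaOp> FixOutputLayout(xla::XlaBuilder* builder,
                                          xla::XlaOp original,
                                          const xla::Shape& original_shape) {
  return ReshapeWithCorrectRepresentationAndSharding(
      builder, original, original_shape, IdentityShapeRepresentationFn(),
      /*sharding=*/absl::nullopt, /*fast_mem=*/false);
}

}  // namespace tensorflow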
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 27766408716..735a6c7291e 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -49,33 +49,13 @@ XlaCompiler* XlaOpKernelContext::compiler() const {
   return xla_context()->compiler();
 }
 
-// Retrieves an XlaExpression that was allocated by a previous Op.
-const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
-    const Tensor& tensor) {
-  const XlaExpression* expression =
-      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
-  CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
-      << expression->HumanString();
-  return expression;
-}
-
-// Assigns an XlaExpression to a tensor on an XLA compilation device.
-void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value,
-                                                  Tensor* tensor) {
-  const XlaExpression* expression =
-      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
-  CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
-      << expression->HumanString();
-  *const_cast<XlaExpression*>(expression) = value;
-}
-
 const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
-  return *CastExpressionFromTensor(context_->input(index));
+  return *XlaExpression::CastExpressionFromTensor(context_->input(index));
 }
 
 const XlaExpression& XlaOpKernelContext::InputExpression(
     absl::string_view name) {
-  return *CastExpressionFromTensor(GetInputTensorByName(name));
+  return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
 }
 
 xla::XlaOp XlaOpKernelContext::Input(int index) {
@@ -108,7 +88,8 @@ DataType XlaOpKernelContext::input_type(int index) const {
   if (type == DT_UINT8) {
     // Masqueraded XlaExpression could have different type. See
     // XlaOpKernelContext::SetOutputExpression for details.
-    auto expression = CastExpressionFromTensor(context_->input(index));
+    auto expression =
+        XlaExpression::CastExpressionFromTensor(context_->input(index));
     type = expression->dtype();
   }
   return type;
@@ -120,7 +101,7 @@ DataType XlaOpKernelContext::InputType(absl::string_view name) {
   if (type == DT_UINT8) {
     // Masqueraded XlaExpression could have different type. See
     // XlaOpKernelContext::SetOutputExpression for details.
-    auto expression = CastExpressionFromTensor(tensor);
+    auto expression = XlaExpression::CastExpressionFromTensor(tensor);
     type = expression->dtype();
   }
   return type;
@@ -385,7 +366,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
   handles->clear();
   shapes->clear();
   for (const Tensor& input : inputs) {
-    handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
+    handles->push_back(
+        XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
     shapes->push_back(input.shape());
   }
   return Status::OK();
@@ -408,7 +390,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                                const XlaOpKernelContext* ctx,
                                TensorShape* shape, xla::XlaOp* value) {
   const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+      XlaExpression::CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
@@ -461,7 +443,8 @@ Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                    TensorShape* shape) const {
   const Tensor& tensor = context_->input(index);
-  const XlaExpression* expression = CastExpressionFromTensor(tensor);
+  const XlaExpression* expression =
+      XlaExpression::CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
@@ -502,8 +485,8 @@ void XlaOpKernelContext::SetOutputExpression(int index,
       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
     }
-    XlaOpKernelContext::AssignExpressionToTensor(
-        expression, context_->mutable_output(index));
+    XlaExpression::AssignExpressionToTensor(expression,
+                                            context_->mutable_output(index));
     return Status::OK();
   }();
   if (!status.ok()) {
@@ -542,7 +525,7 @@ void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
 
 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
   const XlaExpression* expression =
-      CastExpressionFromTensor(context_->input(index));
+      XlaExpression::CastExpressionFromTensor(context_->input(index));
   TF_RET_CHECK(expression->resource() != nullptr);
   *resource = expression->resource();
   return Status::OK();
@@ -554,7 +537,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
                             xla::XlaBuilder* builder) {
   const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+      XlaExpression::CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
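The call-site pattern after this change, distilled (sketch only; `GetResourceFromTensor` is a hypothetical free function mirroring XlaOpKernelContext::GetResourceInput above):

#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace tensorflow {

// Resolves the expression stored in `tensor` and returns its resource,
// failing if the tensor does not carry a resource expression.
Status GetResourceFromTensor(const Tensor& tensor, XlaResource** resource) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  TF_RET_CHECK(expression->resource() != nullptr);
  *resource = expression->resource();
  return Status::OK();
}

}  // namespace tensorflow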
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 6987b6fbb98..3cf51e6ec6f 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -17,6 +17,9 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
 
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_expression.h"
+#include "tensorflow/compiler/tf2xla/xla_resource.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -284,13 +287,6 @@ class XlaOpKernelContext {
   // separate specialization of the computation for each DataType.
   const xla::XlaComputation* GetOrCreateMul(const DataType type);
 
-  // Assigns an XlaExpression to a tensor on an XLA compilation device.
-  static void AssignExpressionToTensor(const XlaExpression& value,
-                                       Tensor* tensor);
-
-  // Retrieves an XlaExpression that was assigned to the specified tensor.
-  static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
-
  private:
   // Returns the tensor of input `name`.
   const Tensor& GetInputTensorByName(absl::string_view name);
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index 32d42cb8a42..bec0b46611d 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -21,7 +21,6 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
-#include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 30a90c1da6c..0a17ba3d408 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -57,6 +57,7 @@ cc_library(
         ":tpu_defs",
         ":tpu_node_device_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
     ],
     alwayslink = 1,
 )
@@ -180,8 +181,8 @@ cc_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 7f64758d238..e5f49158231 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -656,6 +656,7 @@ cc_library(
     deps = [
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -673,6 +674,7 @@ cc_library(
     srcs = ["topk_ops.cc"],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/core/tpu:tpu_defs",