From 3ba0deba916b52d1a8ee0b13b94352c60203072d Mon Sep 17 00:00:00 2001
From: Frank Chen <frankchn@google.com>
Date: Fri, 7 Aug 2020 17:51:06 -0700
Subject: [PATCH] Introduce additional TPU infeed and outfeed ops

PiperOrigin-RevId: 325542225
Change-Id: Ie972e60d6c5639b71719837c500ecc716eda2ebd
---
 tensorflow/core/tpu/BUILD                     |   8 +-
 tensorflow/core/tpu/kernels/BUILD             | 107 ++++
 .../core/tpu/kernels/image_resize_ops.cc      | 155 +++++
 tensorflow/core/tpu/kernels/infeed_ops.cc     | 529 ++++++++++++++++++
 tensorflow/core/tpu/kernels/infeed_ops.h      |  69 +++
 tensorflow/core/tpu/kernels/outfeed_ops.cc    | 116 ++++
 tensorflow/core/tpu/kernels/outfeed_ops.h     |  69 +++
 .../core/tpu/kernels/replication_ops.cc       |  27 +
 .../core/tpu/kernels/tpu_handle_to_key_op.cc  |  62 ++
 tensorflow/core/tpu/kernels/transfer_ops.cc   |  98 ++++
 tensorflow/core/tpu/kernels/transfer_ops.h    |  56 ++
 tensorflow/core/tpu/tpu_defs.cc               |  18 +
 tensorflow/core/tpu/tpu_defs.h                |   6 +
 tensorflow/core/tpu/tpu_library_init_fns.inc  |   1 +
 tensorflow/stream_executor/tpu/BUILD          |   2 +
 .../stream_executor/tpu/tpu_executor_c_api.h  |   3 +
 .../tpu/tpu_transfer_manager.h                |   6 +
 .../tpu/tpu_transfer_manager_interface.cc     |  40 ++
 .../tpu/tpu_transfer_manager_interface.h      |   7 +
 19 files changed, 1378 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/core/tpu/kernels/image_resize_ops.cc
 create mode 100644 tensorflow/core/tpu/kernels/infeed_ops.cc
 create mode 100644 tensorflow/core/tpu/kernels/infeed_ops.h
 create mode 100644 tensorflow/core/tpu/kernels/outfeed_ops.cc
 create mode 100644 tensorflow/core/tpu/kernels/outfeed_ops.h
 create mode 100644 tensorflow/core/tpu/kernels/replication_ops.cc
 create mode 100644 tensorflow/core/tpu/kernels/tpu_handle_to_key_op.cc
 create mode 100644 tensorflow/core/tpu/kernels/transfer_ops.cc
 create mode 100644 tensorflow/core/tpu/kernels/transfer_ops.h
 create mode 100644 tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.cc

diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 0a17ba3d408..15b2b93e46f 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -88,7 +88,13 @@ cc_library(
     name = "tpu_defs",
     srcs = ["tpu_defs.cc"],
     hdrs = ["tpu_defs.h"],
-    deps = ["//tensorflow/core:protos_all_cc"],
+    deps = [
+        ":tpu_api",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/stream_executor/tpu:c_api_conversions",
+        "//tensorflow/stream_executor/tpu:c_api_decl",
+    ],
 )
 
 cc_library(
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 1336f52ed34..157aeb3df58 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -28,10 +28,16 @@ tf_kernel_library(
     deps = [
         ":cross_replica_ops",
         ":host_compute_ops",
+        ":image_resize_ops",
+        ":infeed_ops",
+        ":outfeed_ops",
+        ":replication_ops",
         ":topk_ops",
         ":tpu_compile_op",
         ":tpu_configuration_ops",
         ":tpu_execute_op",
+        ":tpu_handle_to_key_op",
+        ":transfer_ops",
     ],
 )
 
@@ -684,3 +690,104 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "infeed_ops",
+    srcs = ["infeed_ops.cc"],
+    hdrs = ["infeed_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":transfer_ops",
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/common_runtime:dma_helper",
+        "//tensorflow/core/framework:protos_all_cc",
+        "//tensorflow/core/kernels:transpose_functor",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/stream_executor:multi_platform_manager",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_base",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_interface",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "transfer_ops",
+    srcs = ["transfer_ops.cc"],
+    hdrs = ["transfer_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels:ops_util",
+        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/stream_executor:multi_platform_manager",
+        "//tensorflow/stream_executor/tpu:tpu_node_context",
+        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager_interface",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "outfeed_ops",
+    srcs = ["outfeed_ops.cc"],
+    hdrs = ["outfeed_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":transfer_ops",
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/framework:protos_all_cc",
+        "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/stream_executor:multi_platform_manager",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "image_resize_ops",
+    srcs = ["image_resize_ops.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_defs",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "replication_ops",
+    srcs = ["replication_ops.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_defs",
+    ],
+    alwayslink = True,
+)
+
+cc_library(
+    name = "tpu_handle_to_key_op",
+    srcs = ["tpu_handle_to_key_op.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tpu_compilation_cache_interface",
+        ":tpu_op_consts",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu:tpu_configuration",
+    ],
+    alwayslink = True,
+)
diff --git a/tensorflow/core/tpu/kernels/image_resize_ops.cc b/tensorflow/core/tpu/kernels/image_resize_ops.cc
new file mode 100644
index 00000000000..fd0f5e4c7a6
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/image_resize_ops.cc
@@ -0,0 +1,155 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+
+class TpuCustomResizeOp : public XlaOpKernel {
+ public:
+  explicit TpuCustomResizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
+    OP_REQUIRES_OK(ctx,
+                   ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
+  }
+
+  xla::Shape GetOutputShape(XlaOpKernelContext* ctx) const {
+    std::vector<int64> out_size;
+    auto status = ctx->ConstantInputAsIntVector(1, &out_size);
+    CHECK_EQ(out_size.size(), 2) << status.ToString();
+    xla::Shape output_shape =
+        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
+    output_shape.mutable_dimensions()[1] = out_size[0];
+    output_shape.mutable_dimensions()[2] = out_size[1];
+    return output_shape;
+  }
+
+  string OpaqueField() const {
+    return absl::StrCat("\"", align_corners_, half_pixel_centers_, "\"");
+  }
+
+  void CompileGrad(XlaOpKernelContext* ctx, const char* target,
+                   const xla::Shape& output_shape) {
+    auto input_shape =
+        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
+    if (ctx->InputShape(1).dim_sizes() == ctx->InputShape(0).dim_sizes()) {
+      ctx->SetOutput(
+          0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
+      return;
+    }
+    // The gradient should be done in two phases for large resizes.
+    auto input = ctx->Input(0);
+    if (input_shape.dimensions(1) / output_shape.dimensions(1) > 3 &&
+        input_shape.dimensions(2) / output_shape.dimensions(2) > 3) {
+      auto intermediate_shape = output_shape;
+      intermediate_shape.mutable_dimensions()[1] = input_shape.dimensions(1);
+      input = xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
+                              intermediate_shape, OpaqueField());
+    }
+    ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {input},
+                                      output_shape, OpaqueField()));
+  }
+
+  void CompileForward(XlaOpKernelContext* ctx, const char* target) {
+    auto output_shape = GetOutputShape(ctx);
+    if (ctx->InputShape(0).dim_size(1) == output_shape.dimensions(1) &&
+        ctx->InputShape(0).dim_size(2) == output_shape.dimensions(2)) {
+      ctx->SetOutput(
+          0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
+      return;
+    }
+    if (ctx->InputShape(0).dim_size(1) == 1 &&
+        ctx->InputShape(0).dim_size(2) == 1) {
+      ctx->SetOutput(0,
+                     ctx->Input(0) + xla::Zeros(ctx->builder(), output_shape));
+      return;
+    }
+    ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
+                                      output_shape, OpaqueField()));
+  }
+
+ private:
+  bool align_corners_;
+  bool half_pixel_centers_;
+};
+
+class TpuResizeNearestNeighborOp : public TpuCustomResizeOp {
+ public:
+  explicit TpuResizeNearestNeighborOp(OpKernelConstruction* ctx)
+      : TpuCustomResizeOp(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    CompileForward(ctx, "ResizeNearest");
+  }
+};
+
+class TpuResizeBilinearOp : public TpuCustomResizeOp {
+ public:
+  explicit TpuResizeBilinearOp(OpKernelConstruction* ctx)
+      : TpuCustomResizeOp(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    CompileForward(ctx, "ResizeBilinear");
+  }
+};
+
+class TpuResizeNearestNeighborGradOp : public TpuCustomResizeOp {
+ public:
+  explicit TpuResizeNearestNeighborGradOp(OpKernelConstruction* ctx)
+      : TpuCustomResizeOp(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    CompileGrad(ctx, "ResizeNearestGrad", GetOutputShape(ctx));
+  }
+};
+
+class TpuResizeBilinearGradOp : public TpuCustomResizeOp {
+ public:
+  explicit TpuResizeBilinearGradOp(OpKernelConstruction* ctx)
+      : TpuCustomResizeOp(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    auto output_shape =
+        TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(1));
+    CompileGrad(ctx, "ResizeBilinearGrad", output_shape);
+  }
+};
+
+REGISTER_XLA_OP(Name("ResizeNearestNeighbor")
+                    .CompileTimeConstantInput("size")
+                    .Device(DEVICE_TPU_XLA_JIT),
+                TpuResizeNearestNeighborOp);
+
+REGISTER_XLA_OP(Name("ResizeNearestNeighborGrad")
+                    .CompileTimeConstantInput("size")
+                    .Device(DEVICE_TPU_XLA_JIT),
+                TpuResizeNearestNeighborGradOp);
+
+REGISTER_XLA_OP(Name("ResizeBilinear")
+                    .CompileTimeConstantInput("size")
+                    .Device(DEVICE_TPU_XLA_JIT),
+                TpuResizeBilinearOp);
+
+REGISTER_XLA_OP(Name("ResizeBilinearGrad").Device(DEVICE_TPU_XLA_JIT),
+                TpuResizeBilinearGradOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/infeed_ops.cc b/tensorflow/core/tpu/kernels/infeed_ops.cc
new file mode 100644
index 00000000000..f3fbd16b6cc
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/infeed_ops.cc
@@ -0,0 +1,529 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/kernels/infeed_ops.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_handle_cache.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/kernels/transpose_functor.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/tpu/kernels/transfer_ops.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
+
+namespace tensorflow {
+namespace {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef tensorflow::tpu::NoncopyableBuffer LinearizerBuffer;
+typedef std::deque<LinearizerBuffer> LinearizerBufferList;
+
+// Transposes the given tensor using the tensorflow C++ transpose implementation
+// to obtain a XLA literal for the host tensor laid out as the given layout. The
+// returned tensor is normalized to the dim0major layout -- F32[10,20,30]{2,0,1}
+// is returned as F32[20,10,30]{2,1,0}.
+xla::StatusOr<Tensor> TransposeTensor(OpKernelContext* ctx,
+                                      const Tensor& input_tensor,
+                                      const xla::Shape& xla_shape) {
+  profiler::TraceMe trace_me("TransposeTensor", /*level=*/2);
+  const int64 rank = xla_shape.rank();
+  std::vector<int32> permutation(rank);
+  std::vector<int64> transposed_shapes(rank);
+  for (int64 i = 0; i < rank; ++i) {
+    permutation[i] = xla_shape.layout().minor_to_major(rank - 1 - i);
+    transposed_shapes[i] = xla_shape.dimensions(permutation[i]);
+  }
+
+  Tensor transposed_tensor;
+
+  // If this is a trivial transpose (i.e., bitcast), just create an aliased
+  // tensor with the transposed shape.
+  if (xla::LayoutUtil::IsMonotonicWithDim0Major(
+          xla::ShapeUtil::DropDegenerateDimensions(xla_shape).layout())) {
+    TensorShape shape;
+    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(transposed_shapes, &shape));
+    TF_RETURN_IF_ERROR(transposed_tensor.BitcastFrom(
+        input_tensor, input_tensor.dtype(), shape));
+    return transposed_tensor;
+  }
+
+  AllocatorAttributes alloc_attr;
+  alloc_attr.set_on_host(true);
+  TF_RETURN_IF_ERROR(ctx->allocate_temp(input_tensor.dtype(),
+                                        TensorShape(transposed_shapes),
+                                        &transposed_tensor, alloc_attr));
+  // Eigen Transpose fails with SIGFPE if there is a dimension of size 0.
+  if (input_tensor.NumElements() > 0) {
+    TF_RETURN_IF_ERROR(DoTranspose<CPUDevice>(ctx->eigen_device<CPUDevice>(),
+                                              input_tensor, permutation,
+                                              &transposed_tensor));
+  }
+  return transposed_tensor;
+}
+
+xla::StatusOr<bool> GetLayoutOverride(OpKernelConstruction* ctx,
+                                      const char* attrn_name,
+                                      std::vector<int64>* minor_to_major) {
+  if (!ctx->HasAttr(attrn_name)) {
+    return false;
+  }
+  TF_RETURN_IF_ERROR(ctx->GetAttr(attrn_name, minor_to_major));
+  return !minor_to_major->empty();
+}
+
+Status GetInfeedShapeWithLayout(OpKernelConstruction* ctx,
+                                const char* attrn_name,
+                                const xla::Shape& input_shape,
+                                xla::Shape* output_shape) {
+  std::vector<int64> minor_to_major;
+  TF_ASSIGN_OR_RETURN(bool has_override,
+                      GetLayoutOverride(ctx, attrn_name, &minor_to_major));
+  if (!has_override) {
+    *output_shape = input_shape;
+    if (output_shape->IsTuple()) {
+      int64 tuple_elements = xla::ShapeUtil::TupleElementCount(*output_shape);
+      for (int64 i = 0; i < tuple_elements; ++i) {
+        xla::Shape* sub_shape =
+            xla::ShapeUtil::GetMutableSubshape(output_shape, {i});
+        *sub_shape->mutable_layout() = GetTPUInfeedLayout(*sub_shape).layout();
+      }
+    } else {
+      *output_shape->mutable_layout() =
+          GetTPUInfeedLayout(*output_shape).layout();
+    }
+    return Status::OK();
+  }
+
+  auto layout_func = [](const xla::Shape& shape) -> xla::Layout {
+    return GetTPUInfeedLayout(shape).layout();
+  };
+  return GetShapeWithLayout(input_shape, minor_to_major, layout_func,
+                            output_shape);
+}
+
+// LinearizedBuffersWrapper is an opaque C++ data structure for the outputs of
+// PrelinearizeOp and PrelinearizeTupleOp. It holds the resultant linearized
+// buffers and references to input tensors whose underlying storage are shared
+// with linearized buffers.
+// NOTE: This is not a feature-complete implementation of the DT_VARIANT
+// specification. In particular, we cannot currently serialize an arbitrary
+// `LinearizerBufferList` (aka `std::deque<LinearizerBuffer>`)
+// object, so the `Encode()` and `Decode()` methods are not implemented.
+struct LinearizedBuffersWrapper {
+  explicit LinearizedBuffersWrapper() {}
+  explicit LinearizedBuffersWrapper(LinearizerBufferList bufs,
+                                    std::vector<tensorflow::Tensor> ts)
+      : buffers(std::move(bufs)), tensors(std::move(ts)) {}
+  LinearizedBuffersWrapper(const LinearizedBuffersWrapper& wrapper) {
+    // tensorflow::Variant requires this copy constructor to compile.
+    LOG(FATAL) << "LinearizedBuffersWrapper should not copy.";
+  }
+  LinearizedBuffersWrapper& operator=(const LinearizedBuffersWrapper& wrapper) =
+      delete;
+  LinearizedBuffersWrapper(LinearizedBuffersWrapper&&) = default;
+  LinearizedBuffersWrapper& operator=(LinearizedBuffersWrapper&&) = default;
+  ~LinearizedBuffersWrapper() = default;
+
+  // These functions are tensorflow::Variant requirements.
+  string TypeName() const { return "(anonymous)::LinearizedBuffersWrapper"; }
+  void Encode(tensorflow::VariantTensorData* data) const {
+    LOG(ERROR) << "Encode() is not implemented for LinearizedBuffersWrapper "
+                  "objects.";
+  }
+  bool Decode(const tensorflow::VariantTensorData& data) {
+    LOG(ERROR) << "Decode() is not implemented for LinearizedBuffersWrapper "
+                  "objects.";
+    return false;
+  }
+
+  LinearizerBufferList buffers;
+  // Save references on tensors whose underlying storage are shared with
+  // LiteralLinearizer::Buffer in `buffers`.
+  std::vector<tensorflow::Tensor> tensors;
+};
+
+Status AutoTransposeAndLinearize(OpKernelContext* ctx,
+                                 const Tensor& input_tensor,
+                                 const xla::Shape& shape,
+                                 LinearizerBufferList* linearized_buffers,
+                                 std::vector<Tensor>* saved_input_tensors) {
+  const Tensor* tensor = &input_tensor;
+  // If the given layout is not in dim0major layout, tranposes the tensor.
+  bool has_transposed = false;
+  Tensor transposed_tensor;
+  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
+    // If the given layout is not in dim0major layout, transpose the tensor.
+    TF_ASSIGN_OR_RETURN(transposed_tensor,
+                        TransposeTensor(ctx, input_tensor, shape));
+    tensor = &transposed_tensor;
+    has_transposed = true;
+  }
+
+  xla::BorrowingLiteral literal;
+  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));
+
+  TF_RETURN_IF_ERROR(
+      xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager()
+          ->LinearizeToBuffers(literal, linearized_buffers));
+
+  // The input tensor is ref-counted. Save a handle on the input tensor if
+  // its underlying storage is shared with linearized buffers to prevent
+  // input tensor from getting freed.
+  for (const auto& buffer : *linearized_buffers) {
+    if (!buffer.owns_data() && !has_transposed) {
+      // `buffer` is created from zero-copy fast path from the un-transposed
+      // input tensor so its underlying data is shared with input tensor.
+      // Save a handle to input tensor to increment its ref-count and avoid
+      // it getting deallocated after PrelinearizeTupleOp completes.
+      saved_input_tensors->push_back(*tensor);
+      // A literal can be linearized to zero to two buffers. If any of the
+      // linearized buffer shares storage with input tensor. We save exactly
+      // one handle on the input tensor.
+      break;
+    }
+  }
+  return Status::OK();
+}
+
+// PrelinearizeOp is used to linearize one tensor to the device format.
+class PrelinearizeOp : public OpKernel {
+ public:
+  explicit PrelinearizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+    xla::Shape shape;
+    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
+    OP_REQUIRES_OK(ctx,
+                   GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& input_tensor = ctx->input(0);
+    // Validate input.
+    OP_REQUIRES(
+        ctx, input_tensor.dtype() == dtype_,
+        errors::InvalidArgument("Prelinearize dtype mismatch; expected ",
+                                DataType_Name(dtype_), ", got ",
+                                DataType_Name(input_tensor.dtype())));
+    OP_REQUIRES(
+        ctx, input_tensor.shape() == shape_,
+        errors::InvalidArgument("Prelinearize shape mismatch; expected ",
+                                shape_.DebugString(), ", got ",
+                                input_tensor.shape().DebugString()));
+
+    // Auto-transpose and prelinearize.
+    LinearizerBufferList linearized_buffers;
+    std::vector<Tensor> saved_input_tensors;
+    auto status =
+        AutoTransposeAndLinearize(ctx, input_tensor, xla_shape_,
+                                  &linearized_buffers, &saved_input_tensors);
+    OP_REQUIRES_OK(ctx, status);
+
+    // Write to output.
+    tensorflow::Tensor* output;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
+    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
+        std::move(linearized_buffers), std::move(saved_input_tensors)};
+  }
+
+  bool IsExpensive() override { return true; }
+
+ private:
+  TensorShape shape_;
+  DataType dtype_;
+  xla::Shape xla_shape_;
+
+  // PrelinearizeOp is neither copyable nor movable.
+  PrelinearizeOp(const PrelinearizeOp&) = delete;
+  PrelinearizeOp& operator=(const PrelinearizeOp&) = delete;
+};
+
+// PrelinearizeTupleOp is used to linearize multiple tensors to the device
+// format.
+class PrelinearizeTupleOp : public OpKernel {
+ public:
+  explicit PrelinearizeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+    OP_REQUIRES(
+        ctx, shapes_.size() == dtypes_.size(),
+        errors::InvalidArgument(
+            "shapes and dtypes must be the same length. shapes length = ",
+            shapes_.size(), ", dtypes length = ", dtypes_.size()));
+
+    std::vector<xla::Shape> xla_shapes;
+    for (int i = 0; i < shapes_.size(); i++) {
+      xla::Shape xla_shape;
+      OP_REQUIRES_OK(ctx,
+                     TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
+      xla_shapes.push_back(xla_shape);
+    }
+    OP_REQUIRES_OK(
+        ctx, GetInfeedShapeWithLayout(
+                 ctx, "layouts", xla::ShapeUtil::MakeTupleShape(xla_shapes),
+                 &tuple_shape_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    OpInputList values;
+    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
+    OP_REQUIRES(ctx, values.size() == shapes_.size(),
+                errors::InvalidArgument(
+                    "Wrong number of inputs to PrelinearizeTuple."));
+
+    LinearizerBufferList all_linearized_buffers;
+    std::vector<Tensor> all_saved_input_tensors;
+    for (int i = 0; i < values.size(); i++) {
+      // Validate input.
+      const Tensor& input_tensor = values[i];
+      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
+                  errors::InvalidArgument(
+                      "PrelinearizeTuple dtype mismatch at tuple element ", i,
+                      "; expected ", DataType_Name(dtypes_[i]), ", got ",
+                      DataType_Name(input_tensor.dtype())));
+      OP_REQUIRES(ctx, input_tensor.shape() == shapes_[i],
+                  errors::InvalidArgument(
+                      "PrelinearizeTuple shape mismatch at tuple element ", i,
+                      "; expected ", shapes_[i].DebugString(), ", got ",
+                      input_tensor.shape().DebugString()));
+
+      // Auto-transpose and prelinearize.
+      LinearizerBufferList linearized_buffers;
+      std::vector<Tensor> saved_input_tensors;
+      auto status = AutoTransposeAndLinearize(
+          ctx, input_tensor, tuple_shape_.tuple_shapes(i), &linearized_buffers,
+          &saved_input_tensors);
+      OP_REQUIRES_OK(ctx, status);
+      all_linearized_buffers.insert(
+          all_linearized_buffers.end(),
+          std::make_move_iterator(linearized_buffers.begin()),
+          std::make_move_iterator(linearized_buffers.end()));
+      all_saved_input_tensors.insert(
+          all_saved_input_tensors.end(),
+          std::make_move_iterator(saved_input_tensors.begin()),
+          std::make_move_iterator(saved_input_tensors.end()));
+    }
+
+    tensorflow::Tensor* output;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(0, tensorflow::TensorShape{}, &output));
+    output->scalar<tensorflow::Variant>()() = LinearizedBuffersWrapper{
+        std::move(all_linearized_buffers), std::move(all_saved_input_tensors)};
+  }
+
+  bool IsExpensive() override { return true; }
+
+ private:
+  std::vector<TensorShape> shapes_;
+  DataTypeVector dtypes_;
+  xla::Shape tuple_shape_;
+
+  // PrelinearizeTupleOp is neither copyable nor movable.
+  PrelinearizeTupleOp(const PrelinearizeTupleOp&) = delete;
+  PrelinearizeTupleOp& operator=(const PrelinearizeTupleOp&) = delete;
+};
+
+// The InfeedEnqueuePrelinearizedBufferOp op is used to transfer prelinearized
+// buffers to the device infeed queue.
+class InfeedEnqueuePrelinearizedBufferOp : public TpuTransferAsyncOpKernel {
+ public:
+  explicit InfeedEnqueuePrelinearizedBufferOp(OpKernelConstruction* ctx)
+      : TpuTransferAsyncOpKernel(ctx, "prelinearized_buffers_to_infeed", 8) {}
+
+  Status DoWork(OpKernelContext* ctx,
+                xla::TpuTransferManagerInterface* transfer_manager,
+                stream_executor::StreamExecutor* stream_executor) override {
+    const Tensor& input_tensor = ctx->input(0);
+    const LinearizedBuffersWrapper* wrapper =
+        input_tensor.scalar<tensorflow::Variant>()()
+            .get<LinearizedBuffersWrapper>();
+    TF_RETURN_IF_ERROR(transfer_manager->TransferBuffersToInfeed(
+        stream_executor, wrapper->buffers));
+
+    return Status::OK();
+  }
+
+ private:
+  // InfeedEnqueuePrelinearizedBufferOp is neither copyable nor movable.
+  InfeedEnqueuePrelinearizedBufferOp(
+      const InfeedEnqueuePrelinearizedBufferOp&) = delete;
+  InfeedEnqueuePrelinearizedBufferOp& operator=(
+      const InfeedEnqueuePrelinearizedBufferOp&) = delete;
+};
+
+}  // anonymous namespace
+
+TpuInfeedEnqueueOp::TpuInfeedEnqueueOp(OpKernelConstruction* ctx)
+    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8) {
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  xla::Shape shape;
+  OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &shape));
+  OP_REQUIRES_OK(ctx,
+                 GetInfeedShapeWithLayout(ctx, "layout", shape, &xla_shape_));
+}
+
+Status TpuInfeedEnqueueOp::DoWork(
+    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
+    stream_executor::StreamExecutor* stream_executor) {
+  const Tensor& input_tensor = ctx->input(0);
+
+  // Validate runtime shape and fail if it doesn't match the contract.
+  if (input_tensor.dtype() != dtype_) {
+    return errors::InvalidArgument("Infeed dtype mismatch.");
+  }
+  if (input_tensor.shape() != shape_) {
+    return errors::InvalidArgument("Infeed shape mismatch; expected ",
+                                   shape_.DebugString(), ", got ",
+                                   input_tensor.shape().DebugString());
+  }
+
+  const Tensor* tensor = &input_tensor;
+  Tensor transposed_tensor;
+  if (!xla::LayoutUtil::IsMonotonicWithDim0Major(xla_shape_.layout())) {
+    // If the given layout is not in dim0major layout, transpose the tensor.
+    TF_ASSIGN_OR_RETURN(transposed_tensor,
+                        TransposeTensor(ctx, input_tensor, xla_shape_));
+    tensor = &transposed_tensor;
+  }
+
+  xla::BorrowingLiteral literal;
+  TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(*tensor, &literal));
+
+  // Transfer the given literal to the Infeed interface of the device.
+  TF_RETURN_IF_ERROR(
+      transfer_manager->TransferLiteralToInfeed(stream_executor, literal));
+  return Status::OK();
+}
+
+TpuInfeedEnqueueTupleOp::TpuInfeedEnqueueTupleOp(OpKernelConstruction* ctx)
+    : TpuTransferAsyncOpKernel(ctx, "infeed_enqueue", 8) {
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+  OP_REQUIRES(
+      ctx, shapes_.size() == dtypes_.size(),
+      errors::InvalidArgument("shapes and dtypes must be the same length."));
+
+  std::vector<xla::Shape> xla_shapes;
+  for (int i = 0; i < shapes_.size(); i++) {
+    xla::Shape xla_shape;
+    OP_REQUIRES_OK(ctx,
+                   TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
+    xla_shapes.push_back(xla_shape);
+  }
+  OP_REQUIRES_OK(
+      ctx, GetInfeedShapeWithLayout(ctx, "layouts",
+                                    xla::ShapeUtil::MakeTupleShape(xla_shapes),
+                                    &tuple_shape_));
+}
+
+Status TpuInfeedEnqueueTupleOp::DoWork(
+    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
+    stream_executor::StreamExecutor* stream_executor) {
+  OpInputList values;
+  TF_RETURN_IF_ERROR(ctx->input_list("inputs", &values));
+  if (values.size() != shapes_.size()) {
+    return errors::InvalidArgument(
+        "Wrong number of inputs to InfeedEnqueueTuple.");
+  }
+
+  for (const auto& shapes : shapes_) {
+    VLOG(1) << "TransferLiteralToInfeed " << shapes.DebugString();
+  }
+
+  std::vector<Tensor> maybe_transposed_tensors;
+  maybe_transposed_tensors.reserve(values.size());
+  for (int i = 0; i < values.size(); i++) {
+    // Validate runtime shapes and fail if it doesn't match the contract.
+    const Tensor* tensor = &values[i];
+    if (tensor->shape() != shapes_[i]) {
+      return errors::InvalidArgument("Infeed shape mismatch for tuple element ",
+                                     i, "; expected ", shapes_[i].DebugString(),
+                                     ", got ", tensor->shape().DebugString());
+    }
+    if (!xla::LayoutUtil::IsMonotonicWithDim0Major(
+            tuple_shape_.tuple_shapes(i).layout())) {
+      // If the given layout is not in dim0major layout, tranposes the given
+      // tensor.
+      TF_ASSIGN_OR_RETURN(
+          Tensor transposed_tensor,
+          TransposeTensor(ctx, *tensor, tuple_shape_.tuple_shapes(i)));
+      maybe_transposed_tensors.emplace_back(transposed_tensor);
+    } else {
+      maybe_transposed_tensors.emplace_back(*tensor);
+    }
+  }
+
+  xla::BorrowingLiteral tuple;
+  TF_RETURN_IF_ERROR(
+      HostTensorsToBorrowingLiteralTuple(maybe_transposed_tensors, &tuple));
+
+  // Transfer the given literal to the Infeed interface of the device.
+  TF_RETURN_IF_ERROR(
+      transfer_manager->TransferLiteralToInfeed(stream_executor, tuple));
+
+  VLOG(1) << "TransferLiteralToInfeed complete.";
+
+  return Status::OK();
+}
+
+// These ops execute on either the TPU device or the CPU device. When running on
+// CPU they must specify a non-negative value for device_ordinal to indicate
+// which TPU to send infeed to.
+REGISTER_KERNEL_BUILDER(
+    Name("InfeedEnqueue").Device(DEVICE_TPU_NODE).HostMemory("input"),
+    TpuInfeedEnqueueOp);
+REGISTER_KERNEL_BUILDER(Name("InfeedEnqueue").Device(DEVICE_CPU),
+                        TpuInfeedEnqueueOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("InfeedEnqueueTuple").Device(DEVICE_TPU_NODE).HostMemory("inputs"),
+    TpuInfeedEnqueueTupleOp);
+REGISTER_KERNEL_BUILDER(Name("InfeedEnqueueTuple").Device(DEVICE_CPU),
+                        TpuInfeedEnqueueTupleOp);
+
+// Prelinearize ops run on CPU as part of tf.data input pipeline.
+REGISTER_KERNEL_BUILDER(Name("Prelinearize").Device(DEVICE_CPU),
+                        PrelinearizeOp);
+REGISTER_KERNEL_BUILDER(Name("PrelinearizeTuple").Device(DEVICE_CPU),
+                        PrelinearizeTupleOp);
+
+// InfeedEnqueuePrelinearizedBuffer op run on CPU and takes a device_ordinal to
+// select the right device to infeed.
+REGISTER_KERNEL_BUILDER(
+    Name("InfeedEnqueuePrelinearizedBuffer").Device(DEVICE_CPU),
+    InfeedEnqueuePrelinearizedBufferOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/infeed_ops.h b/tensorflow/core/tpu/kernels/infeed_ops.h
new file mode 100644
index 00000000000..622583b6a73
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/infeed_ops.h
@@ -0,0 +1,69 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/kernels/transfer_ops.h"
+
+namespace tensorflow {
+
+// TODO(b/65200690): Rework this when there is a callback based infeed API to
+// StreamExecutor.
+
+// The InfeedEnqueue op is used to deliver data to the device infeed queue.
+class TpuInfeedEnqueueOp : public TpuTransferAsyncOpKernel {
+ public:
+  explicit TpuInfeedEnqueueOp(OpKernelConstruction* ctx);
+  Status DoWork(OpKernelContext* ctx,
+                xla::TpuTransferManagerInterface* transfer_manager,
+                stream_executor::StreamExecutor* stream_executor) override;
+
+ private:
+  TensorShape shape_;
+  DataType dtype_;
+  xla::Shape xla_shape_;
+
+  // TpuInfeedEnqueueOp is neither copyable nor movable.
+  TpuInfeedEnqueueOp(const TpuInfeedEnqueueOp&) = delete;
+  TpuInfeedEnqueueOp& operator=(const TpuInfeedEnqueueOp&) = delete;
+};
+
+// The InfeedEnqueueTuple op is used on the host to deliver multiple tensors to
+// the device infeed queue as an XLA tuple.
+class TpuInfeedEnqueueTupleOp : public TpuTransferAsyncOpKernel {
+ public:
+  explicit TpuInfeedEnqueueTupleOp(OpKernelConstruction* ctx);
+  Status DoWork(OpKernelContext* ctx,
+                xla::TpuTransferManagerInterface* transfer_manager,
+                stream_executor::StreamExecutor* stream_executor) override;
+
+ private:
+  std::vector<TensorShape> shapes_;
+  DataTypeVector dtypes_;
+  xla::Shape tuple_shape_;
+
+  // TpuInfeedEnqueueTupleOp is neither copyable nor movable.
+  TpuInfeedEnqueueTupleOp(const TpuInfeedEnqueueTupleOp&) = delete;
+  TpuInfeedEnqueueTupleOp& operator=(const TpuInfeedEnqueueTupleOp&) = delete;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_
diff --git a/tensorflow/core/tpu/kernels/outfeed_ops.cc b/tensorflow/core/tpu/kernels/outfeed_ops.cc
new file mode 100644
index 00000000000..51a3a71a297
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/outfeed_ops.cc
@@ -0,0 +1,116 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/kernels/outfeed_ops.h"
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/tpu/kernels/transfer_ops.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+namespace tensorflow {
+
+TpuOutfeedDequeueOp::TpuOutfeedDequeueOp(OpKernelConstruction* ctx)
+    : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_));
+}
+
+Status TpuOutfeedDequeueOp::DoWork(
+    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
+    stream_executor::StreamExecutor* stream_executor) {
+  Tensor* output;
+  TF_RETURN_IF_ERROR(ctx->allocate_output(0, shape_, &output));
+
+  // Transfer from the outfeed interface of the device.
+  xla::MutableBorrowingLiteral literal;
+  TF_RETURN_IF_ERROR(
+      HostTensorToMutableBorrowingLiteral(xla_shape_, output, &literal));
+
+  VLOG(1) << "TransferLiteralFromOutfeed "
+          << xla::ShapeUtil::HumanStringWithLayout(xla_shape_);
+
+  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
+      stream_executor, xla_shape_, literal));
+
+  VLOG(1) << "TransferLiteralFromOutfeed complete.";
+
+  return Status::OK();
+}
+
+// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
+// device outfeed queue.
+TpuOutfeedDequeueTupleOp::TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx)
+    : TpuTransferAsyncOpKernel(ctx, "outfeed_dequeue", 1) {
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
+  OP_REQUIRES(
+      ctx, shapes_.size() == dtypes_.size(),
+      errors::InvalidArgument("shapes and dtypes must be the same length."));
+  // The `dtypes` list is inferred from the supplied inputs, so it
+  // is always the correct length.
+  for (int i = 0; i < shapes_.size(); i++) {
+    xla::Shape xla_shape;
+    OP_REQUIRES_OK(ctx,
+                   TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape));
+    xla_shapes_.push_back(xla_shape);
+  }
+  tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_);
+}
+
+Status TpuOutfeedDequeueTupleOp::DoWork(
+    OpKernelContext* ctx, xla::TpuTransferManagerInterface* transfer_manager,
+    stream_executor::StreamExecutor* stream_executor) {
+  VLOG(1) << "TransferLiteralFromOutfeed "
+          << xla::ShapeUtil::HumanStringWithLayout(tuple_shape_);
+
+  for (int i = 0; i < shapes_.size(); ++i) {
+    Tensor* output;
+    TF_RETURN_IF_ERROR(ctx->allocate_output(i, shapes_[i], &output));
+
+    xla::MutableBorrowingLiteral literal;
+    TF_RETURN_IF_ERROR(
+        HostTensorToMutableBorrowingLiteral(xla_shapes_[i], output, &literal));
+    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
+        stream_executor, xla_shapes_[i], literal));
+  }
+  return Status::OK();
+}
+
+// These ops execute on either the TPU device or the CPU device. When
+// running on CPU they must specify a non-negative value for
+// device_ordinal to indicate which TPU to receive outfeed from.
+REGISTER_KERNEL_BUILDER(
+    Name("OutfeedDequeue").Device(DEVICE_TPU_NODE).HostMemory("output"),
+    TpuOutfeedDequeueOp);
+REGISTER_KERNEL_BUILDER(Name("OutfeedDequeue").Device(DEVICE_CPU),
+                        TpuOutfeedDequeueOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("OutfeedDequeueTuple").Device(DEVICE_TPU_NODE).HostMemory("outputs"),
+    TpuOutfeedDequeueTupleOp);
+REGISTER_KERNEL_BUILDER(Name("OutfeedDequeueTuple").Device(DEVICE_CPU),
+                        TpuOutfeedDequeueTupleOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/outfeed_ops.h b/tensorflow/core/tpu/kernels/outfeed_ops.h
new file mode 100644
index 00000000000..5e3ed87c04b
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/outfeed_ops.h
@@ -0,0 +1,69 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/tpu/kernels/transfer_ops.h"
+
+namespace tensorflow {
+
+// The OutfeedDequeue op is used to retrieve a single tensor from the device
+// outfeed queue.
+class TpuOutfeedDequeueOp : public TpuTransferAsyncOpKernel {
+ public:
+  explicit TpuOutfeedDequeueOp(OpKernelConstruction* ctx);
+
+  Status DoWork(OpKernelContext* ctx,
+                xla::TpuTransferManagerInterface* transfer_manager,
+                stream_executor::StreamExecutor* stream_executor) override;
+
+ private:
+  TensorShape shape_;
+  DataType dtype_;
+  xla::Shape xla_shape_;
+
+  // OutfeedDequeueOp is neither copyable nor movable.
+  TpuOutfeedDequeueOp(const TpuOutfeedDequeueOp&) = delete;
+  TpuOutfeedDequeueOp& operator=(const TpuOutfeedDequeueOp&) = delete;
+};
+
+// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the
+// device outfeed queue.
+class TpuOutfeedDequeueTupleOp : public TpuTransferAsyncOpKernel {
+ public:
+  explicit TpuOutfeedDequeueTupleOp(OpKernelConstruction* ctx);
+
+  Status DoWork(OpKernelContext* ctx,
+                xla::TpuTransferManagerInterface* transfer_manager,
+                stream_executor::StreamExecutor* stream_executor) override;
+
+ private:
+  std::vector<TensorShape> shapes_;
+  DataTypeVector dtypes_;
+  std::vector<xla::Shape> xla_shapes_;
+  xla::Shape tuple_shape_;
+
+  // OutfeedDequeueTupleOp is neither copyable nor movable.
+  TpuOutfeedDequeueTupleOp(const TpuOutfeedDequeueTupleOp&) = delete;
+  TpuOutfeedDequeueTupleOp& operator=(const TpuOutfeedDequeueTupleOp&) = delete;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_
diff --git a/tensorflow/core/tpu/kernels/replication_ops.cc b/tensorflow/core/tpu/kernels/replication_ops.cc
new file mode 100644
index 00000000000..4c986e880e7
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/replication_ops.cc
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_device_ops.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("_TPUReplicate").Device(DEVICE_TPU_SYSTEM),
+                        XlaDeviceDummyOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_handle_to_key_op.cc b/tensorflow/core/tpu/kernels/tpu_handle_to_key_op.cc
new file mode 100644
index 00000000000..ec2ae91d3eb
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_handle_to_key_op.cc
@@ -0,0 +1,62 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
+#include "tensorflow/core/tpu/tpu_configuration.h"
+
+namespace tensorflow {
+
+class TpuHandleToProtoKeyOp : public OpKernel {
+ public:
+  explicit TpuHandleToProtoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  ~TpuHandleToProtoKeyOp() override = default;
+  TpuHandleToProtoKeyOp(const TpuHandleToProtoKeyOp&) = delete;
+  TpuHandleToProtoKeyOp& operator=(const TpuHandleToProtoKeyOp&) = delete;
+
+  void Compute(OpKernelContext* ctx) override {
+    VLOG(1) << "TpuHandleToProtoKeyOp::Compute " << ctx->op_kernel().name()
+            << " on device " << ctx->op_kernel().requested_device();
+    const Tensor& uid = ctx->input(0);
+
+    ResourceMgr* rm = GetTPUConfigResourceMgr();
+    tpu::TpuCompilationCacheInterface* cache;
+    OP_REQUIRES_OK(ctx, rm->Lookup<tpu::TpuCompilationCacheInterface>(
+                            rm->default_container(),
+                            tpu::kCompilationCacheResourceName, &cache));
+    core::ScopedUnref cache_unref(cache);
+
+    std::vector<std::string> keys;
+    OP_REQUIRES_OK(ctx, cache->GetKeysFromUid(uid.scalar<int64>()(), &keys));
+
+    TensorShape output_shape;
+    output_shape.AddDim(keys.size());
+    Tensor* result = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
+    for (int i = 0; i < keys.size(); ++i) {
+      result->vec<tstring>()(i) = keys[i];
+    }
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("TpuHandleToProtoKey").Device(DEVICE_CPU),
+                        TpuHandleToProtoKeyOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/transfer_ops.cc b/tensorflow/core/tpu/kernels/transfer_ops.cc
new file mode 100644
index 00000000000..40b85e2cfbd
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/transfer_ops.cc
@@ -0,0 +1,98 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/tpu/kernels/transfer_ops.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
+
+namespace tensorflow {
+
+TpuTransferAsyncOpKernel::TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
+                                                   const string& transfer_type,
+                                                   int number_of_threads)
+    : AsyncOpKernel(ctx),
+      thread_pool_(new thread::ThreadPool(
+          ctx->env(),
+          strings::StrCat(transfer_type, "_thread_",
+                          SanitizeThreadSuffix(def().name())),
+          /*num_threads=*/8)) {
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("device_ordinal", &device_ordinal_));
+  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+    OP_REQUIRES(
+        ctx, device_ordinal_ >= 0,
+        errors::InvalidArgument(transfer_type,
+                                " ops must specify a device_ordinal when "
+                                "placed on CPU."));
+  }
+}
+
+void TpuTransferAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
+                                            DoneCallback done) {
+  CancellationToken token =
+      ctx->cancellation_manager()->get_cancellation_token();
+  bool already_cancelled;
+  {
+    // Only protect registering the cancellation callback as mu_ cannot be held
+    // at a point where `done` could be called.
+    mutex_lock lock(mu_);
+    already_cancelled = !ctx->cancellation_manager()->RegisterCallback(
+        token, [this]() { Cancel(); });
+  }
+  OP_REQUIRES_ASYNC(ctx, !already_cancelled,
+                    errors::Cancelled("Infeed was cancelled."), done);
+  thread_pool_->Schedule([this, ctx, done, token]() {
+    Status s = RunTransfer(ctx);
+    ctx->cancellation_manager()->DeregisterCallback(token);
+    OP_REQUIRES_OK_ASYNC(ctx, s, done);
+    done();
+  });
+}
+
+Status TpuTransferAsyncOpKernel::RunTransfer(OpKernelContext* ctx) {
+  auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform();
+
+  int real_device_ordinal = device_ordinal_;
+  if (real_device_ordinal < 0) {
+    const XlaDevice::Metadata* metadata;
+    TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
+    real_device_ordinal = metadata->device_ordinal();
+  }
+  stream_executor::StreamExecutor* stream_executor =
+      tpu_platform->ExecutorForDevice(real_device_ordinal).ValueOrDie();
+
+  // When Xprof profiling is off (which is the default), constructing the
+  // activity is simple enough that its overhead is negligible.
+  profiler::TraceMe activity(
+      [this] { return profiler::TraceMeOp(name(), type_string()); },
+      profiler::TraceMeLevel::kInfo);
+  return DoWork(
+      ctx, xla::TpuTransferManagerInterface::GetRegisteredTpuTransferManager(),
+      stream_executor);
+}
+
+void TpuTransferAsyncOpKernel::Cancel() {
+  mutex_lock lock(mu_);
+  TF_CHECK_OK(tpu::TpuNodeContext::CloseTpuHost());
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/transfer_ops.h b/tensorflow/core/tpu/kernels/transfer_ops.h
new file mode 100644
index 00000000000..d98d743f569
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/transfer_ops.h
@@ -0,0 +1,56 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
+
+namespace tensorflow {
+
+// Base class providing common functionality for async ops that transfer from
+// host to TPU.
+class TpuTransferAsyncOpKernel : public AsyncOpKernel {
+ public:
+  explicit TpuTransferAsyncOpKernel(OpKernelConstruction* ctx,
+                                    const string& transfer_type,
+                                    int number_of_threads);
+
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
+
+ protected:
+  virtual Status DoWork(OpKernelContext* context,
+                        xla::TpuTransferManagerInterface* transfer_manager,
+                        stream_executor::StreamExecutor* stream_executor) = 0;
+
+ private:
+  Status RunTransfer(OpKernelContext* ctx);
+  void Cancel();
+
+  std::unique_ptr<thread::ThreadPool> thread_pool_;
+  int device_ordinal_;
+  mutex mu_;
+
+  // TpuTransferAsyncOpKernel is neither copyable nor movable.
+  TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete;
+  TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_
diff --git a/tensorflow/core/tpu/tpu_defs.cc b/tensorflow/core/tpu/tpu_defs.cc
index 69669bfdb7b..69d4989773a 100644
--- a/tensorflow/core/tpu/tpu_defs.cc
+++ b/tensorflow/core/tpu/tpu_defs.cc
@@ -15,6 +15,10 @@ limitations under the License.
 
 #include "tensorflow/core/tpu/tpu_defs.h"
 
+#include "tensorflow/core/tpu/tpu_api.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/c_api_decl.h"
+
 namespace tensorflow {
 
 const char* const DEVICE_TPU_NODE = "TPU";
@@ -27,4 +31,18 @@ const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR =
 const char* const kTPUReplicateAttr = "_tpu_replicate";
 const char* const kOutsideCompilationAttr = "_xla_outside_compilation";
 
+xla::Shape GetTPUInfeedLayout(const xla::Shape& shape) {
+  XLA_Shape c_shape;
+  XLA_Shape c_infeed_shape;
+
+  ApiConverter::ToC(shape, &c_shape);
+
+  tpu::ExecutorApiFn()->TpuTransferManager_GetInfeedLayoutFn(&c_shape,
+                                                             &c_infeed_shape);
+  xla::Shape infeed_shape = ApiConverter::FromC(&c_infeed_shape);
+  ApiConverter::Free(&c_shape);
+  ApiConverter::Free(&c_infeed_shape);
+  return infeed_shape;
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/tpu/tpu_defs.h b/tensorflow/core/tpu/tpu_defs.h
index 008e386dde6..29954b2289f 100644
--- a/tensorflow/core/tpu/tpu_defs.h
+++ b/tensorflow/core/tpu/tpu_defs.h
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include <array>
 
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/core/framework/types.pb.h"
 
 namespace tensorflow {
@@ -56,6 +57,11 @@ static constexpr std::array<DataType, 16> kTpuAllTypes = {
      DT_COMPLEX64, DT_INT64, DT_UINT64, DT_QINT8, DT_QUINT8, DT_INT8, DT_UINT8,
      DT_INT16, DT_UINT16}};
 
+// For the given shape, chooses a layout for infeed on TPU. The returned shape
+// has the same dimensions as the original shape, and only the layout is
+// changed.
+xla::Shape GetTPUInfeedLayout(const xla::Shape& shape);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_TPU_TPU_DEFS_H_
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index be9d594685e..40130bd46dd 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -161,6 +161,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
   TFTPU_SET_FN(executor_fn, TpuTransferManager_TransferLiteralFromDevice);
   TFTPU_SET_FN(executor_fn, TpuTransferManager_GetByteSizeRequirement);
   TFTPU_SET_FN(executor_fn, TpuTransferManager_WriteSingleTupleIndexTable);
+  TFTPU_SET_FN(executor_fn, TpuTransferManager_GetInfeedLayout);
   TFTPU_SET_FN(executor_fn, TpuTransferManager_LinearizeToBuffers);
   TFTPU_SET_FN(executor_fn, TpuTransferManager_FreeBuffers);
 
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index a52f9919e6e..a8178404dff 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -203,10 +203,12 @@ cc_library(
 
 cc_library(
     name = "tpu_transfer_manager_interface",
+    srcs = ["tpu_transfer_manager_interface.cc"],
     hdrs = ["tpu_transfer_manager_interface.h"],
     visibility = ["//visibility:public"],
     deps = [
         ":noncopyable_buffer",
+        ":tpu_platform_interface",
         "//tensorflow/compiler/xla/service:transfer_manager",
     ],
 )
diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
index 2b66c2ce4c5..013e7fe4e0c 100644
--- a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
+++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
@@ -182,6 +182,8 @@ void TpuTransferManager_WriteSingleTupleIndexTable(
     XLA_TransferManager* manager, SE_Stream* stream,
     SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
     SE_DeviceMemoryBase* region, SE_Status* status);
+void TpuTransferManager_GetInfeedLayout(XLA_Shape* shape,
+                                        XLA_Shape* infeed_shape);
 void TpuTransferManager_LinearizeToBuffers(
     XLA_TransferManager* manager, XLA_Literal* c_literal, char*** buffers_array,
     int64_t** buffers_size, int64_t* buffers_array_size, SE_Status* status);
@@ -341,6 +343,7 @@ struct TfTpu_ExecutorApiFn {
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_TransferLiteralFromDevice);
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetByteSizeRequirement);
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_WriteSingleTupleIndexTable);
+  TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_GetInfeedLayout);
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_LinearizeToBuffers);
   TFTPU_ADD_FN_IN_STRUCT(TpuTransferManager_FreeBuffers);
 
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.h b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
index c201d63d2d5..e758c702204 100644
--- a/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
@@ -81,6 +81,12 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
       const xla::Shape& shape,
       stream_executor::DeviceMemoryBase* region) override;
 
+  Status LinearizeToBuffers(
+      const xla::LiteralSlice& literal,
+      std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) override {
+    LOG(FATAL) << "Not yet implemented.";
+  }
+
  private:
   XLA_TransferManager* manager_;
 };
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.cc
new file mode 100644
index 00000000000..746093972a4
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
+
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace xla {
+
+/*static*/ TpuTransferManagerInterface*
+TpuTransferManagerInterface::GetRegisteredTpuTransferManager() {
+  auto* platform = tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(
+      /*initialize_platform=*/false);
+  if (platform == nullptr) {
+    LOG(ERROR) << "Unable to retrieve registered TPU platform.";
+    return nullptr;
+  }
+  auto tm = xla::TransferManager::GetForPlatform(platform);
+  if (!tm.ok()) {
+    LOG(ERROR) << "Unable to retrieve TpuTransferManager. No TPU platform is "
+                  "registered for platform "
+               << platform->Name() << " and ID " << platform->id();
+    return nullptr;
+  }
+  return static_cast<TpuTransferManagerInterface*>(tm.ValueOrDie());
+}
+
+}  // namespace xla
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h b/tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h
index 3f34ed8064d..b7e000b89ac 100644
--- a/tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h
@@ -24,9 +24,16 @@ limitations under the License.
 namespace xla {
 
 class TpuTransferManagerInterface : public xla::TransferManager {
+ public:
   virtual Status TransferBuffersToInfeed(
       se::StreamExecutor* executor,
       const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) = 0;
+
+  virtual Status LinearizeToBuffers(
+      const LiteralSlice& literal,
+      std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) = 0;
+
+  static TpuTransferManagerInterface* GetRegisteredTpuTransferManager();
 };
 
 }  // namespace xla