diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b1ebcdbe5b9..8970fce1460 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -72,7 +72,6 @@ load(
     "if_ios",
     "if_mobile",
     "if_not_windows",
-    "if_tpu",
     "tf_android_core_proto_headers",
     "tf_cc_test",
     "tf_cc_test_mkl",
@@ -117,6 +116,7 @@ load(
     "tf_protos_all_impl",
     "tf_protos_grappler_impl",
     "tf_protos_profiler_impl",
+    "tf_tpu_dependencies",
 )
 load(
     "//tensorflow/core/platform:rules_cc.bzl",
@@ -1086,9 +1086,7 @@ cc_library(
     ]) + if_tensorrt([
         "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
         "//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
-    ]) + if_tpu([
-        "//tensorflow/core/tpu/kernels",
-    ]),
+    ]) + tf_tpu_dependencies(),
 )
 
 cc_library(
diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl
index 3bfbe617122..cd902ac3353 100644
--- a/tensorflow/core/platform/build_config.bzl
+++ b/tensorflow/core/platform/build_config.bzl
@@ -43,6 +43,7 @@ load(
     _tf_py_clif_cc = "tf_py_clif_cc",
     _tf_pyclif_proto_library = "tf_pyclif_proto_library",
     _tf_resource_deps = "tf_resource_deps",
+    _tf_tpu_dependencies = "tf_tpu_dependencies",
     _tf_windows_aware_platform_deps = "tf_windows_aware_platform_deps",
 )
 
@@ -88,3 +89,4 @@ tf_py_clif_cc = _tf_py_clif_cc
 tf_pyclif_proto_library = _tf_pyclif_proto_library
 tf_resource_deps = _tf_resource_deps
 tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps
+tf_tpu_dependencies = _tf_tpu_dependencies
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 9f84b9205f1..78191bff8f9 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -1,7 +1,7 @@
 # Platform-specific build configurations.
 
 load("@com_google_protobuf//:protobuf.bzl", "proto_gen")
-load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows")
+load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows", "if_tpu")
 load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
@@ -800,3 +800,6 @@ def if_llvm_system_z_available(then, otherwise = []):
         "//tensorflow:linux_s390x": then,
         "//conditions:default": otherwise,
     })
+
+def tf_tpu_dependencies():
+    return if_tpu(["//tensorflow/core/tpu/kernels"])
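Note: `if_tpu` is the config-setting select that the new `tf_tpu_dependencies` macro wraps; its actual definition lives in tensorflow.bzl. A minimal sketch of its shape, assuming the `with_tpu_support` config label:

    def if_tpu(if_true, if_false = []):
        return select({
            "//tensorflow:with_tpu_support": if_true,
            "//conditions:default": if_false,
        })

Routing the dependency through the platform build_config layer keeps tensorflow/core/BUILD free of TPU-specific select logic.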
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index c47fdc0f9d2..f35f7151222 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -99,13 +99,22 @@ tf_kernel_library(
     name = "tpu_configuration_ops",
     srcs = ["tpu_configuration_ops.cc"],
     hdrs = ["tpu_configuration_ops.h"],
-    deps = [
+    copts = select({
+        WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
+        DEFAULT: [],
+    }),
+    deps = select({
+        WITH_TPU_SUPPORT: [":tpu_util"],
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
+    }) + [
         ":tpu_compilation_cache_factory",
         ":tpu_compilation_cache_interface",
         ":tpu_compilation_cache_local_lookup",
         ":tpu_compilation_cache_lookup",
+        ":tpu_compilation_cache_rpc_lookup",
         ":tpu_mesh_state_interface",
         ":tpu_op_consts",
+        ":tpu_pod_state",
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/compiler/xla:util",
@@ -116,6 +125,7 @@ tf_kernel_library(
         "//tensorflow/core/tpu:tpu_config_c_api",
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/tpu:proto_helper",
     ],
     alwayslink = 1,
@@ -447,6 +457,7 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/time",
+        tf_grpc_cc_dependency(),
     ],
 )
 
@@ -505,10 +516,18 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/tpu:tpu_api",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        tf_grpc_cc_dependency(),
     ],
     alwayslink = 1,
 )
 
+# An alias for ":tpu_compilation_cache_proto_cc".
+cc_library(
+    name = "tpu_compilation_cache_cc_proto",
+    deps = [":tpu_compilation_cache_proto_cc"],
+)
+
 cc_library(
     name = "tpu_compilation_cache_rpc_support_hdrs",
     hdrs = ["tpu_compilation_cache_rpc_support.h"],
@@ -518,7 +537,7 @@ cc_library(
     }),
     deps = select({
         WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],  # build_cleaner: keep
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],  # build_cleaner: keep
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],  # build_cleaner: keep
     }) + [
         ":tpu_compilation_cache_entry",
         ":tpu_compilation_cache_interface",
@@ -606,7 +625,7 @@ cc_library(
     }),
     deps = select({
         WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
-        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"],
     }) + [
         ":tpu_compilation_cache_common_proto_cc",
         tf_grpc_cc_dependency(),
@@ -628,7 +647,7 @@ cc_library(
         ],
         DEFAULT: [
             "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support",  # build_cleaner: keep
-            "//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc",  # build_cleaner: keep
+            "//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto",  # build_cleaner: keep
         ],
     }) + [
         ":tpu_compilation_cache_common_proto_cc",
@@ -939,10 +958,14 @@ cc_library(
         WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
         DEFAULT: [],
     }),
-    deps = [
+    deps = select({
+        WITH_TPU_SUPPORT: [":tpu_util"],
+        DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"],
+    }) + [
         ":tpu_compilation_cache_service",
-        ":tpu_util",
+        "//tensorflow/c:tf_status",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/core/tpu:tpu_api",
         "//tensorflow/core:framework",
-        tf_grpc_cc_dependency(),
     ],
 )
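Note: the `WITH_TPU_SUPPORT`/`DEFAULT` select keys used throughout this BUILD file are presumably declared near its top; a sketch of the assumed convention:

    # Assumed declarations (exact labels may differ):
    WITH_TPU_SUPPORT = "//tensorflow:with_tpu_support"
    DEFAULT = "//conditions:default"

Under this convention, `select({WITH_TPU_SUPPORT: [...], DEFAULT: [...]})` picks copts such as `-DLIBTFTPU` and alternate dependency labels per configuration, as in the hunks above.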
diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
index 5a8c283c7c2..271a9697f18 100644
--- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
@@ -27,8 +27,10 @@ limitations under the License.
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
+#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
 #include "tensorflow/core/tpu/tpu_api.h"
 #include "tensorflow/core/tpu/tpu_config_c_api.h"
 #include "tensorflow/core/tpu/tpu_configuration.h"
@@ -37,7 +39,6 @@ limitations under the License.
 
 namespace tensorflow {
 namespace {
-
 Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
                                 tpu::TpuMeshStateInterface** state) {
   if (!rmgr->Lookup(rmgr->default_container(),
@@ -69,7 +70,6 @@ Status DeleteIfExists(ResourceMgr* resource_manager,
   VLOG(1) << "Error removing resource " << resource_name << " : " << status;
   return status;
 }
-
 }  // namespace
 
 Status CreateTpuCompilationCache(
@@ -82,36 +82,39 @@ Status CreateTpuCompilationCache(
       });
 }
 
-void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
-  VLOG(1) << "ConfigureDistributedTpuOp";
-  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
-
+xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
+    OpKernelContext* ctx) {
   std::vector<int32_t> num_devices_per_host;
   int chips_per_host = -1;
   for (int i = 0; i < ctx->num_inputs(); ++i) {
     const Tensor& input_tensor = ctx->input(i);
-    OP_REQUIRES(
-        ctx, TensorShapeUtils::IsScalar(input_tensor.shape()),
-        errors::InvalidArgument("Input ", i, " should be a scalar but has ",
-                                input_tensor.dims(), " dimensions"));
+    if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
+      return errors::InvalidArgument("Input ", i,
+                                     " should be a scalar but has ",
+                                     input_tensor.dims(), " dimensions");
+    }
     if (chips_per_host == -1) {
       chips_per_host = input_tensor.scalar<int32_t>()();
     } else {
-      OP_REQUIRES(
-          ctx, chips_per_host == input_tensor.scalar<int32>()(),
-          errors::Internal("Host ", i, " has ", input_tensor.scalar<int32>()(),
-                           " TPU chips but host 0 has ", chips_per_host));
+      if (chips_per_host != input_tensor.scalar<int32>()()) {
+        return errors::Internal("Host ", i, " has ",
+                                input_tensor.scalar<int32>()(),
+                                " TPU chips but host 0 has ", chips_per_host);
+      }
     }
     num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
   }
+  return num_devices_per_host;
+}
 
-  TF_Status* status = TF_NewStatus();
-  size_t host_config_output_size;
-  char* host_config_output;
+void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "ConfigureDistributedTpuOp";
+  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
 
-  auto* rmgr = GetTPUConfigResourceMgr();
-  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
-                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));
+  xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
+      ConstructDevicesPerHost(ctx);
+  OP_REQUIRES_OK(ctx, num_devices_per_host.status());
+  ResourceMgr* rmgr = GetTPUConfigResourceMgr();
 
   // Create the subgraph compilation cache and put it in the local resource
   // manager.
@@ -119,9 +122,13 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
   core::ScopedUnref compilation_cache_ref(compilation_cache);
 
-  tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
-      num_devices_per_host.size(), num_devices_per_host.data(),
-      compilation_cache, &host_config_output_size, &host_config_output, status);
+  std::string host_config_output;
+  OP_REQUIRES_OK(
+      ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
+                                &host_config_output));
+
+  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
+                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));
 
   auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
   OP_REQUIRES_OK(
@@ -130,13 +137,7 @@ void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
 
   Tensor* ctx_output;
   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
-  ctx_output->scalar<tstring>()() =
-      std::string(host_config_output, host_config_output_size);
-
-  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
-  TF_DeleteStatus(status);
-
-  tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
+  ctx_output->scalar<tstring>()() = std::move(host_config_output);
 
   VLOG(1) << "ConfigureDistributedTpuOp done";
 }
@@ -186,30 +187,39 @@ void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
     mapping_arg.push_back(mapping[i].data());
   }
 
-  TF_Status* status = TF_NewStatus();
-  size_t tpu_topology_output_size;
-  char* tpu_topology_output;
-
   tpu::TpuMeshStateInterface* mesh_state;
   auto* rmgr = GetTPUConfigResourceMgr();
   OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
   core::ScopedUnref mesh_state_unref(mesh_state);
 
+  // TODO(b/166858751): this code checking whether `TpuPodState` exists was
+  // ported from a legacy library and may be stale; a candidate for cleanup.
+  TpuPodState* pod_state;
+  OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
+  core::ScopedUnref pod_state_unref(pod_state);
+
+  size_t tpu_topology_output_size;
+  char* tpu_topology_output = nullptr;
+  TF_Status* status = TF_NewStatus();
+  auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
+    TF_DeleteStatus(status);
+    tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
+        tpu_topology_output);
+  });
+
   auto* mesh_common_state = mesh_state->mesh_common_state();
   tpu::ConfigApiFn()->WaitForDistributedTpuOp_DoWorkFn(
       num_hosts, num_devices_per_host,
       const_cast<const int32_t**>(mapping_arg.data()), mesh_common_state,
       &tpu_topology_output_size, &tpu_topology_output, status);
 
+  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+
   Tensor* ctx_output;
   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
   ctx_output->scalar<tstring>()() =
       std::string(tpu_topology_output, tpu_topology_output_size);
 
-  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
-  TF_DeleteStatus(status);
-  tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
-
   VLOG(1) << "WaitForDistributedTpuOp done";
 }
 
@@ -217,17 +227,14 @@ void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "ShutdownDistributedTpuOp";
   XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");
 
-  TF_Status* status = TF_NewStatus();
+  auto* rmgr = GetTPUConfigResourceMgr();
   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
-                          GetTPUConfigResourceMgr(),
-                          tpu::kTpuMeshStateInterfaceResourceName));
-  tpu::ConfigApiFn()->ShutdownDistributedTpuOp_DoWorkFn(status);
-  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
-  TF_DeleteStatus(status);
+                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));
 
-  OP_REQUIRES_OK(
-      ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
-               GetTPUConfigResourceMgr(), tpu::kCompilationCacheResourceName));
+  OP_REQUIRES_OK(ctx,
+                 DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
+  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
+                          rmgr, tpu::kCompilationCacheResourceName));
 
   VLOG(1) << "ShutdownDistributedTpuOp done";
 }
@@ -239,10 +246,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
   auto* rmgr = GetTPUConfigResourceMgr();
   auto tpu_host_config = ctx->input(0).scalar<tstring>()();
 
-  size_t device_id_output_size;
-  int32_t* device_id_output;
-  TF_Status* status = TF_NewStatus();
-
   bool is_master_worker =
       tpu::ConfigApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
   if (!is_master_worker) {
@@ -275,10 +278,18 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
     local_compilation_cache = nullptr;
   }
 
+  TF_Status* status = TF_NewStatus();
+  size_t device_id_output_size;
+  int32_t* device_id_output = nullptr;
+  auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
+    TF_DeleteStatus(status);
+    tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
+  });
   tpu::ConfigApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(
       tpu_host_config.size(), tpu_host_config.data(),
-      enable_whole_mesh_compilations_, local_compilation_cache,
-      &device_id_output_size, &device_id_output, status);
+      enable_whole_mesh_compilations_, is_master_worker, &device_id_output_size,
+      &device_id_output, status);
+  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
 
   if (local_compilation_cache != nullptr) {
     local_compilation_cache->Unref();
@@ -289,6 +300,30 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
     OP_REQUIRES_OK(
         ctx, rmgr->Create(rmgr->default_container(),
                           tpu::kCompiledProtoCacheResourceName, proto_lookup));
+  } else {
+    int64_t cache_size_bytes;
+    tpu::ConfigApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
+        &cache_size_bytes);
+
+    char* server_address_output = nullptr;
+    auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
+      tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
+          server_address_output);
+    });
+    size_t server_address_output_size;
+    tpu::ConfigApiFn()
+        ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
+            tpu_host_config.size(), tpu_host_config.data(),
+            &server_address_output_size, &server_address_output, status);
+    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
+
+    std::string server_address(server_address_output,
+                               server_address_output_size);
+    tpu::TpuCompilationCacheLookup* proto_lookup =
+        new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes);
+    OP_REQUIRES_OK(
+        ctx, rmgr->Create(rmgr->default_container(),
+                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
   }
 
   Tensor* ctx_output;
@@ -301,10 +336,6 @@ void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
     ctx_output->flat<int32>()(i) = device_id_output[i];
   }
 
-  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
-  TF_DeleteStatus(status);
-  tpu::ConfigApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
-
   VLOG(1) << "InitializeHostForDistributedTpuOp done";
 }
 
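Note: the `xla::MakeCleanup` guards introduced above ensure `TF_DeleteStatus` and the `FreeCharArray`/`FreeInt32Array` calls run on every exit path, including the early `OP_REQUIRES_OK` returns. A minimal self-contained sketch of the scope-exit idiom (not the actual xla implementation):

    #include <utility>

    template <typename F>
    class Cleanup {
     public:
      explicit Cleanup(F f) : f_(std::move(f)) {}
      ~Cleanup() { f_(); }  // runs on every scope exit, incl. early returns
     private:
      F f_;
    };

    template <typename F>
    Cleanup<F> MakeCleanup(F f) {
      return Cleanup<F>(std::move(f));
    }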
diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
index d0bf5809842..d58712ae3dd 100644
--- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
+++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h
@@ -15,14 +15,22 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
 
+#include <stdint.h>
+
+#include <vector>
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace tensorflow {
 
 Status CreateTpuCompilationCache(
     ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache);
 
+xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
+    OpKernelContext* ctx);
+
 // The ConfigureDistributedTpu op is used to start a TPUDriver from
 // TensorFlow. It should be run on a TPU_SYSTEM device and returns the
 // connection host:port for the CompilationCacheServer. The
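Note: call sites consume the newly exposed `xla::StatusOr` helper as in `ConfigureDistributedTpuOp::Compute` above, e.g.:

    xla::StatusOr<std::vector<int32_t>> devices = ConstructDevicesPerHost(ctx);
    OP_REQUIRES_OK(ctx, devices.status());  // aborts the op on error
    const std::vector<int32_t>& per_host = *devices;  // safe after the check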
diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.cc b/tensorflow/core/tpu/kernels/tpu_pod_state.cc
index a45a4d63708..e7f13a657ed 100644
--- a/tensorflow/core/tpu/kernels/tpu_pod_state.cc
+++ b/tensorflow/core/tpu/kernels/tpu_pod_state.cc
@@ -14,12 +14,78 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
 
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/tpu/tpu_api.h"
+
+#if defined(LIBTFTPU)
 #include "tensorflow/core/tpu/kernels/tpu_util.h"
+#else
+#include "tensorflow/core/tpu/kernels/tpu_util.h"  // copybara"
+#endif
 
 namespace tensorflow {
-
 const char kTpuPodStateResourceName[] = "tpu_pod_state";
 
+namespace {
+Status GetServerAddressAndPort(std::string* server_address, int* serving_port) {
+  TF_Status* status = TF_NewStatus();
+  char* server_address_output = nullptr;
+  auto cleanup = xla::MakeCleanup([&status, &server_address_output]() {
+    TF_DeleteStatus(status);
+    tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(
+        server_address_output);
+  });
+  size_t server_address_output_size;
+  *serving_port = -1;
+  tpu::ConfigApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(
+      &server_address_output_size, &server_address_output, serving_port,
+      status);
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
+  *server_address =
+      std::string(server_address_output, server_address_output_size);
+  CHECK_NE(*serving_port, -1);
+  return Status::OK();
+}
+
+// Attempts to delete `resource_name` from `resource_manager`'s
+// default_container. Returns OK if the deletion succeeded or the resource was
+// not found; otherwise returns the deletion error.
+template <class ResourceT>
+Status DeleteIfExists(ResourceMgr* resource_manager,
+                      const char* resource_name) {
+  VLOG(1) << "Removing resource " << resource_name << " if it exists";
+  Status status = resource_manager->Delete<ResourceT>(
+      resource_manager->default_container(), resource_name);
+  if (status.ok()) {
+    VLOG(1) << "Removed existing resource " << resource_name;
+    return Status::OK();
+  }
+  if (status.code() == error::NOT_FOUND) {
+    VLOG(1) << "No resource " << resource_name << " to remove";
+    return Status::OK();
+  }
+  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
+  return status;
+}
+
+xla::StatusOr<std::unique_ptr<TpuCompilationCacheService>>
+ConstructCacheService(ResourceMgr* rmgr, int serving_port,
+                      tpu::TpuCompilationCacheInterface* compilation_cache) {
+  xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> server_builder;
+#if defined(LIBTFTPU)
+  server_builder = tpu::CreateServerBuilder(serving_port);
+#else
+  server_builder = tpu::CreateServerBuilderGoogle(serving_port);
+#endif
+  TF_RETURN_IF_ERROR(server_builder.status());
+
+  auto cache_service = absl::make_unique<TpuCompilationCacheService>(
+      server_builder.ValueOrDie().get(), compilation_cache);
+  cache_service->SetMemoryQuota(1ul << 31);  // 2 GiB
+  cache_service->Start();
+  return cache_service;
+}
+}  // namespace
+
 TpuPodState::TpuPodState(
     int service_port, std::unique_ptr<TpuCompilationCacheService> cache_service)
     : cache_service_(std::move(cache_service)), service_port_(service_port) {}
@@ -29,7 +95,7 @@ TpuPodState::~TpuPodState() {
     VLOG(1) << "Shutting down Compilation Cache Service.";
     if (cache_service_->Shutdown(20)) {
       if (service_port_ >= 0) {
-        tpu::RecycleUnusedPort(service_port_);
+        tpu::UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
       }
     } else {
       LOG(ERROR)
@@ -67,4 +133,38 @@ bool HasTPUPodState(const ResourceMgr* rmgr) {
   return true;
 }
 
+Status ConstructTpuPodState(
+    ResourceMgr* rmgr, const std::vector<int32_t>& num_devices_per_host,
+    tpu::TpuCompilationCacheInterface* compilation_cache,
+    std::string* host_config_proto) {
+  TF_Status* status = TF_NewStatus();
+  auto status_cleanup =
+      xla::MakeCleanup([&status]() { TF_DeleteStatus(status); });
+
+  int serving_port;
+  std::string server_address;
+  TF_RETURN_IF_ERROR(GetServerAddressAndPort(&server_address, &serving_port));
+
+  char* host_config_output = nullptr;
+  auto host_config_cleanup = xla::MakeCleanup([&host_config_output]() {
+    tpu::ConfigApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
+  });
+  size_t host_config_output_size;
+  tpu::ConfigApiFn()->ConfigureDistributedTpuOp_DoWorkFn(
+      num_devices_per_host.size(), num_devices_per_host.data(),
+      server_address.size(), server_address.data(), &host_config_output_size,
+      &host_config_output, status);
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
+  *host_config_proto = std::string(host_config_output, host_config_output_size);
+
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<TpuCompilationCacheService> cache_service,
+      ConstructCacheService(rmgr, serving_port, compilation_cache));
+
+  // Delete TpuPodState if it exists, and recreate below.
+  TF_RETURN_IF_ERROR(
+      DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
+  return rmgr->Create(rmgr->default_container(), kTpuPodStateResourceName,
+                      new TpuPodState(serving_port, std::move(cache_service)));
+}
 }  // namespace tensorflow
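Note: `TF_ASSIGN_OR_RETURN`, used in `ConstructTpuPodState` above, unwraps a `StatusOr` or propagates its error status. It expands approximately to the following (a sketch, modulo the unique temporary name the real macro generates):

    // TF_ASSIGN_OR_RETURN(lhs, rexpr) ~>
    auto statusor = (rexpr);
    if (!statusor.ok()) return statusor.status();
    lhs = std::move(statusor).ValueOrDie();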
diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.h b/tensorflow/core/tpu/kernels/tpu_pod_state.h
index 9f37e28f60f..07ad3bee553 100644
--- a/tensorflow/core/tpu/kernels/tpu_pod_state.h
+++ b/tensorflow/core/tpu/kernels/tpu_pod_state.h
@@ -15,7 +15,9 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
 
-#include "grpcpp/server_builder.h"
+#include <string>
+#include <vector>
+
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h"
 
@@ -49,6 +51,11 @@ Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state);
 // manager.
 bool HasTPUPodState(const ResourceMgr* rmgr);
 
+// Constructs a TpuPodState, registers it in `rmgr`'s default container, and
+// returns the serialized host config through `host_config_proto`.
+Status ConstructTpuPodState(
+    ResourceMgr* rmgr, const std::vector<int32_t>& num_devices_per_host,
+    tpu::TpuCompilationCacheInterface* compilation_cache,
+    std::string* host_config_proto);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_util.cc b/tensorflow/core/tpu/kernels/tpu_util.cc
index 837c23c6cf5..6f31d066db5 100644
--- a/tensorflow/core/tpu/kernels/tpu_util.cc
+++ b/tensorflow/core/tpu/kernels/tpu_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/tpu/kernels/tpu_util.h"
 
+#include "absl/strings/str_format.h"
 #include "absl/strings/str_split.h"
 #include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/tpu/tpu_api.h"
@@ -97,8 +98,13 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
   return Status::OK();
 }
 
-void RecycleUnusedPort(int port) {
-  UtilApiFn()->TpuNetUtil_RecycleUnusedPortFn(port);
+xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> CreateServerBuilder(
+    int serving_port) {
+  auto server_builder = absl::make_unique<::grpc::ServerBuilder>();
+  server_builder->AddListeningPort(
+      absl::StrFormat("[::]:%d", serving_port),
+      ::grpc::InsecureServerCredentials());  // NOLINT
+  return std::move(server_builder);
 }
 }  // namespace tpu
 }  // namespace tensorflow
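Note: a generic caller-side sketch of consuming the returned builder (standard grpc++ API; this is not the exact wiring in ConstructCacheService above, which hands the raw builder to the cache service instead):

    TF_ASSIGN_OR_RETURN(std::unique_ptr<::grpc::ServerBuilder> builder,
                        tpu::CreateServerBuilder(serving_port));
    builder->RegisterService(&service);  // any grpc::Service implementation
    std::unique_ptr<::grpc::Server> server = builder->BuildAndStart();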
diff --git a/tensorflow/core/tpu/kernels/tpu_util.h b/tensorflow/core/tpu/kernels/tpu_util.h
index 834db31c3d8..d45934f31b6 100644
--- a/tensorflow/core/tpu/kernels/tpu_util.h
+++ b/tensorflow/core/tpu/kernels/tpu_util.h
@@ -15,9 +15,11 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
 
+#include <memory>
 #include <string>
 #include <vector>
 
+#include "grpcpp/server_builder.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -55,10 +57,9 @@ Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
 Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
                                    std::vector<TensorShape>* shapes);
 
-// We only recycle ports which were given to us by the portserver. For ports
-// we obtained through local trial-and-error, there is no reason to expect the
-// port to remain available after it is unbound.
-void RecycleUnusedPort(int port);
+// Creates a gRPC ServerBuilder listening on `serving_port` (all interfaces,
+// insecure credentials).
+xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> CreateServerBuilder(
+    int serving_port);
 }  // namespace tpu
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h
index 08417dbf907..de4b2e25570 100644
--- a/tensorflow/core/tpu/tpu_config_c_api.h
+++ b/tensorflow/core/tpu/tpu_config_c_api.h
@@ -32,8 +32,9 @@ extern "C" {
 
 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
     const size_t num_cores_per_host_size, const int32_t* num_cores_per_host,
-    void* tpu_compilation_cache_interface, size_t* host_config_output_size,
-    char** host_config_output, TF_Status* status);
+    size_t server_address_size, const char* server_address,
+    size_t* host_config_output_size, char** host_config_output,
+    TF_Status* status);
 
 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     const size_t num_hosts, const size_t num_cores_per_host,
@@ -42,11 +43,9 @@ TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
     size_t* tpu_topology_output_size, char** tpu_topology_output,
     TF_Status* status);
 
-TFTPU_CAPI_EXPORT void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
-
 TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
     const size_t tpu_host_config_size, const char* tpu_host_config,
-    const bool enable_whole_mesh_compilations, void* local_compilation_cache,
+    const bool enable_whole_mesh_compilations, bool is_master_worker,
     size_t* core_id_output_size, int32_t** core_id_output, TF_Status* status);
 
 TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
@@ -65,12 +64,22 @@ TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
                                                        TF_Status* status);
 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
                                                           TF_Status* status);
+
+TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
+    int64_t* cache_size_in_bytes);
+TFTPU_CAPI_EXPORT
+void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
+    size_t tpu_host_config_size, const char* tpu_host_config,
+    size_t* server_address_output_size, char** server_address_output,
+    TF_Status* status);
+TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
+    size_t* server_address_output_size, char** server_address_output,
+    int* port_output, TF_Status* status);
 }
 
 struct TfTpu_ConfigApiFn {
   TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
   TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
-  TFTPU_ADD_FN_IN_STRUCT(ShutdownDistributedTpuOp_DoWork);
   TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
   TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
   TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
@@ -79,6 +88,10 @@ struct TfTpu_ConfigApiFn {
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
+  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
+  TFTPU_ADD_FN_IN_STRUCT(
+      TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
+  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);
 };
 
 #endif  // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
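Note: `TFTPU_ADD_FN_IN_STRUCT` turns each exported C symbol into a function-pointer member of the struct; its expansion is presumably along these lines (sketch only; the real macro lives in the TFTPU C API headers):

    // TFTPU_ADD_FN_IN_STRUCT(Foo) ~> decltype(&Foo) FooFn;
    // so the struct gains members such as:
    decltype(&TpuConfigurationApi_GetServerAddressAndPort)
        TpuConfigurationApi_GetServerAddressAndPortFn;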
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index cb8871a60c5..fde2712a2f0 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -11,7 +11,6 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
 
   TFTPU_SET_FN(config_fn, ConfigureDistributedTpuOp_DoWork);
   TFTPU_SET_FN(config_fn, WaitForDistributedTpuOp_DoWork);
-  TFTPU_SET_FN(config_fn, ShutdownDistributedTpuOp_DoWork);
   TFTPU_SET_FN(config_fn, InitializeHostForDistributedTpuOp_DoWork);
   TFTPU_SET_FN(config_fn, SetGlobalTPUArrayOp_DoWork);
   TFTPU_SET_FN(config_fn, DisconnectDistributedTpuChipsOp_DoWork);
@@ -20,6 +19,11 @@ tensorflow::Status SetTpuConfigStructFns(void* library_handle) {
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_HasTPUPodState);
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpusPerHost);
   TFTPU_SET_FN(config_fn, TpuConfigurationApi_TpuMemoryLimit);
+  TFTPU_SET_FN(config_fn,
+               TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
+  TFTPU_SET_FN(config_fn,
+               TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
+  TFTPU_SET_FN(config_fn, TpuConfigurationApi_GetServerAddressAndPort);
 
   return tensorflow::Status::OK();
 }
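Note: the matching `TFTPU_SET_FN` presumably resolves each struct member from the dynamically loaded libtpu handle via dlsym; a hedged sketch of the pattern (the real macro also reports missing symbols):

    #include <dlfcn.h>

    // Approximate expansion of TFTPU_SET_FN(config_fn, Foo):
    config_fn->FooFn =
        reinterpret_cast<decltype(&Foo)>(dlsym(library_handle, "Foo"));
    if (config_fn->FooFn == nullptr) {
      return tensorflow::errors::NotFound("Foo not available in this library.");
    }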