diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD
index d81b9a4b84c..664780abcc7 100644
--- a/tensorflow/compiler/xla/pjrt/BUILD
+++ b/tensorflow/compiler/xla/pjrt/BUILD
@@ -168,6 +168,46 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tpu_client",
+    srcs = ["tpu_client.cc"],
+    hdrs = ["tpu_client.h"],
+    visibility = [
+        "//learning/brain/research/jax:__subpackages__",
+        "//learning/deepmind/tensorflow/tensorfn:__subpackages__",
+        "//learning/pathways:__subpackages__",
+    ],
+    deps = [
+        ":local_device_state",
+        ":pjrt_client",
+        ":tracked_device_buffer",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/compiler/xla/service:computation_placer",
+        "//tensorflow/compiler/xla/service:shaped_buffer",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/tpu:tpu_executor_dlsym_initializer",
+        "//tensorflow/core/tpu:tpu_on_demand_compiler",
+        "//tensorflow/stream_executor:device_memory",
+        "//tensorflow/stream_executor:stream",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/tpu:tpu_computation_placer",
+        "//tensorflow/stream_executor/tpu:tpu_executable_interface",
+        "//tensorflow/stream_executor/tpu:tpu_executor",
+        "//tensorflow/stream_executor/tpu:tpu_executor_interface",
+        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "//tensorflow/stream_executor/tpu:tpu_topology_external",
+        "//tensorflow/stream_executor/tpu:tpu_transfer_manager",
+        "@com_google_absl//absl/container:inlined_vector",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
+    ],
+)
+
 cc_library(
     name = "interpreter_device",
     srcs = ["interpreter_device.cc"],
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc
new file mode 100644
index 00000000000..a8711631605
--- /dev/null
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc
@@ -0,0 +1,246 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
+
+#include <memory>
+#include <vector>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h"
+#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+#include "tensorflow/stream_executor/tpu/tpu_stream.h"
+
+namespace tf_tpu = tensorflow::tpu;
+
+namespace xla {
+namespace {
+
+class TpuDeviceState : public LocalDeviceState {
+ public:
+  TpuDeviceState(se::StreamExecutor* executor, LocalClient* client,
+                 bool asynchronous);
+
+  Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
+                                  se::Stream* dst_stream,
+                                  se::DeviceMemoryBase src_buffer,
+                                  se::DeviceMemoryBase dst_buffer) override;
+};
+
+TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor,
+                               LocalClient* client, bool asynchronous)
+    : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous,
+                       asynchronous,
+                       /*allow_event_reuse=*/false) {}
+
+Status TpuDeviceState::ThenMemcpyDeviceToDevice(
+    se::Stream* transfer_stream, se::Stream* dst_stream,
+    se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
+  auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
+      transfer_stream->implementation());
+  tf_tpu::TpuTopologyExternal topology =
+      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
+  // TODO(b/157179600): use device-to-device transfers when implemented instead
+  // of copying via host.
+  if (topology.version() == kTpuV4) {
+    LOG(WARNING)
+        << "device-to-device transfers not yet implemented, copying via host";
+    auto* dst_tpu_stream =
+        tensorflow::down_cast<tf_tpu::TpuStream*>(dst_stream->implementation());
+    TF_RET_CHECK(src_buffer.size() == dst_buffer.size());
+    auto host_tmp = std::make_unique<char[]>(src_buffer.size());
+    TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueTransferDeviceToHost(
+        src_buffer, host_tmp.get(), src_buffer.size()));
+    dst_stream->ThenWaitFor(transfer_stream);
+    TF_RETURN_IF_ERROR(dst_tpu_stream->EnqueueTransferHostToDevice(
+        dst_buffer, host_tmp.get(), dst_buffer.size()));
+    transfer_stream->ThenWaitFor(dst_stream);
+    char* tmp = host_tmp.release();
+    dst_stream->ThenDoHostCallback([tmp] { delete[] tmp; });
+  } else {
+    TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
+        src_buffer, dst_buffer));
+  }
+  return Status::OK();
+}
+
+class PjRtTpuClient : public PjRtClient {
+ public:
+  PjRtTpuClient(LocalClient* client,
+                std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
+                tf_tpu::TpuPlatformInterface* tpu_platform);
+
+  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
+      int num_replicas, int num_partitions) const override;
+
+  bool EnqueueD2DTransfersOnSrcStream() const override {
+    return tpu_platform_->topology().version() == kTpuV4;
+  }
+
+  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const override;
+
+ private:
+  tf_tpu::TpuPlatformInterface* tpu_platform_;
+};
+
+PjRtTpuClient::PjRtTpuClient(LocalClient* client,
+                             std::vector<std::unique_ptr<PjRtDevice>> devices,
+                             int host_id,
+                             tf_tpu::TpuPlatformInterface* tpu_platform)
+    : PjRtClient("tpu", client, std::move(devices), host_id,
+                 /*allocator=*/nullptr,
+                 /*host_memory_allocator=*/nullptr,
+                 /*should_stage_host_to_device_transfers=*/false,
+                 /*gpu_run_options=*/nullptr),
+      tpu_platform_(tpu_platform) {}
+
+StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
+    int num_replicas, int num_partitions) const {
+  tf_tpu::TpuPlatformInterface* platform =
+      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
+  tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation();
+  int num_local_devices = host.Cores(kTensorCore).size();
+  if (num_replicas * num_partitions <= num_local_devices) {
+    return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas,
+                                                            num_partitions);
+  }
+  // Fallback to default global device assignment if we can't run locally.
+  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+}
+
+StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
+    const PjRtExecutable& executable) const {
+  if (executable.client() != this) {
+    return InvalidArgument(
+        "Passed executable from different client (platform '%s') to "
+        "PjRtTpuClient::ExecutableFingerprint",
+        executable.client()->platform_name());
+  }
+  if (executable.executables().size() > 1) {
+    LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
+                 "executables, fingerprint may not be unique.";
+  }
+  xla::TpuExecutableInterface* tpu_executable =
+      tensorflow::down_cast<xla::TpuExecutableInterface*>(
+          executable.executables()[0]->executable());
+  return absl::optional<std::string>(tpu_executable->fingerprint());
+}
+
+StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
+    LocalClient* client,
+    std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
+  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  tf_tpu::TpuTopologyExternal topology =
+      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
+
+  std::map<int, int> core_id_to_device_ordinal;
+  for (int i = 0; i < client->device_count(); ++i) {
+    se::StreamExecutor* executor =
+        client->backend().stream_executor(i).ValueOrDie();
+    tf_tpu::TpuExecutorInterface* tpu_executor =
+        tensorflow::down_cast<tf_tpu::TpuExecutorInterface*>(
+            executor->implementation());
+    core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i;
+  }
+
+  for (const tf_tpu::TpuCoreLocationExternal& core :
+       topology.cores(TpuCoreTypeEnum::kTensorCore)) {
+    auto it = core_id_to_device_ordinal.find(core.Id());
+    int device_ordinal =
+        (it != core_id_to_device_ordinal.end()) ? it->second : -1;
+    int host_id = topology.IdForHost(core.host_coordinates());
+    const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
+    std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
+    std::unique_ptr<LocalDeviceState> local_device_state;
+    if (device_ordinal >= 0) {
+      local_device_state = std::move(local_device_states[device_ordinal]);
+    }
+    auto device = absl::make_unique<PjRtTpuDevice>(
+        core, std::move(local_device_state), host_id, coords_array,
+        std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
+    devices.push_back(std::move(device));
+  }
+  return devices;
+}
+
+}  // namespace
+
+StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
+    bool asynchronous, absl::Duration init_retry_timeout) {
+  tf_tpu::TpuPlatformInterface* platform =
+      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
+  if (platform == nullptr) {
+    return InvalidArgument("TpuPlatform is not available.");
+  }
+  // NOTE: We retry in a loop since some pod failures are transient (e.g. some
+  // RPCs may timeout waiting for other hosts to come up, but will succeed
+  // at a later point if retried).
+  auto start = absl::Now();
+  // TODO(b/165870356): TpuPlatform::Initialized() always returns true!
+  auto status = platform->Initialize({});
+  while (!platform->Initialized()) {
+    status = platform->Initialize({});
+    if (!status.ok()) {
+      LOG(ERROR) << "Platform initialization failed: " << status;
+      if ((absl::Now() - start) >= init_retry_timeout) {
+        return status;
+      }
+    }
+  }
+  if (platform->VisibleDeviceCount() <= 0) {
+    return InvalidArgument("No TPU devices found.");
+  }
+  LocalClientOptions options;
+  options.set_platform(platform);
+  TF_ASSIGN_OR_RETURN(LocalClient * client,
+                      ClientLibrary::GetOrCreateLocalClient(options));
+
+  std::vector<std::unique_ptr<LocalDeviceState>> local_device_states;
+  local_device_states.reserve(client->device_count());
+  for (int i = 0; i < client->device_count(); ++i) {
+    se::StreamExecutor* executor =
+        client->backend().stream_executor(i).ValueOrDie();
+    local_device_states.push_back(
+        absl::make_unique<TpuDeviceState>(executor, client, asynchronous));
+  }
+
+  TF_ASSIGN_OR_RETURN(auto devices,
+                      GetTpuDevices(client, std::move(local_device_states)));
+  int host_id = platform->GetTpuHostLocation().Id();
+
+  return std::shared_ptr<PjRtClient>(absl::make_unique<PjRtTpuClient>(
+      client, std::move(devices), host_id, platform));
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h
new file mode 100644
index 00000000000..1a458c1480b
--- /dev/null
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.h
@@ -0,0 +1,60 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_
+#define TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_
+
+#include <array>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/stream_executor/tpu/tpu_topology.h"
+
+namespace xla {
+
+class PjRtTpuDevice : public PjRtDevice {
+ public:
+  PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
+                std::unique_ptr<LocalDeviceState> local_device_state,
+                int host_id, const std::array<int, 3>& coords,
+                std::string device_kind)
+      : PjRtDevice(core.Id(), std::move(local_device_state),
+                   /*platform_name=*/"tpu", std::move(device_kind), host_id),
+        core_(core),
+        coords_(coords) {}
+
+  const std::array<int, 3>& coords() const { return coords_; }
+  int core_on_chip() const { return core_.index(); }
+  const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; }
+
+  std::string DebugString() const override {
+    return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
+                           coords_[0], coords_[1], coords_[2], core_.index());
+  }
+
+ private:
+  const tensorflow::tpu::TpuCoreLocationExternal core_;
+  const std::array<int, 3> coords_;
+};
+
+StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
+    bool asynchronous,
+    absl::Duration init_retry_timeout = absl::ZeroDuration());
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 484c9a47e60..85586809014 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -177,6 +177,23 @@ cc_library(
     ],
 )
 
+# This is an alternative to "tpu_api_dlsym_initializer" that only initializes
+# methods needed for the base TPU executor APIs (and thus has fewer deps). Do
+# not link in both this and "tpu_api_dlsym_initializer".
+cc_library(
+    name = "tpu_executor_dlsym_initializer",
+    srcs = ["tpu_executor_dlsym_initializer.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":tpu_api_dlsym_set_fn",
+        ":tpu_executor_init_fns",
+        "//tensorflow/core:lib",
+        "//tensorflow/stream_executor/tpu:tpu_computation_placer",
+        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
+    ],
+    alwayslink = True,
+)
+
 cc_library(
     name = "tpu_api_dlsym_set_fn",
     hdrs = ["tpu_api_dlsym_set_fn.h"],
@@ -294,6 +311,7 @@ cc_library(
 cc_library(
     name = "tpu_on_demand_compiler",
     srcs = ["tpu_on_demand_compiler.cc"],
+    visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:util",
diff --git a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
new file mode 100644
index 00000000000..4d84781f4e3
--- /dev/null
+++ b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc
@@ -0,0 +1,70 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TODO(skye): this is largely a copy of tpu_api_dlsym_initializer.cc. Figure
+// out how to deduplicate these files a little.
+
+#include <dlfcn.h>
+
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/tpu/tpu_api_dlsym_set_fn.h"
+#if !defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/tpu/tpu_executor_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+#endif
+
+namespace tensorflow {
+namespace tpu {
+
+#if defined(PLATFORM_GOOGLE)
+Status InitializeTpuLibrary(void* library_handle) {
+  return errors::Unimplemented("You must statically link in a TPU library.");
+}
+#else  // PLATFORM_GOOGLE
+#include "tensorflow/core/tpu/tpu_executor_init_fns.inc"
+
+Status InitializeTpuLibrary(void* library_handle) {
+  Status s = SetExecutorStructFn(library_handle);
+
+  // TPU platform registration must only be performed after the library is
+  // loaded. We do not want to register a TPU platform in XLA without the
+  // supporting library providing the necessary APIs.
+  if (s.ok()) {
+    void (*initialize_fn)();
+    initialize_fn = reinterpret_cast<decltype(initialize_fn)>(
+        dlsym(library_handle, "TfTpu_Initialize"));
+    (*initialize_fn)();
+
+    RegisterTpuPlatform();
+  }
+
+  return s;
+}
+
+bool FindAndLoadTpuLibrary() {
+  void* library = dlopen("libtpu.so", RTLD_NOW);
+  if (library) {
+    InitializeTpuLibrary(library);
+  }
+  return true;
+}
+
+static bool tpu_library_finder = FindAndLoadTpuLibrary();
+#endif  // PLATFORM_GOOGLE
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index df0867042f2..540a0a234ff 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -222,6 +222,7 @@ cc_library(
 cc_library(
     name = "tpu_transfer_manager",
     srcs = ["tpu_transfer_manager_registration.cc"],
+    visibility = ["//visibility:public"],
     deps = [
         ":tpu_executor",
         ":tpu_transfer_manager_base",
@@ -256,6 +257,7 @@ cc_library(
     name = "tpu_computation_placer",
     srcs = ["tpu_computation_placer.cc"],
     hdrs = ["tpu_computation_placer.h"],
+    visibility = ["//visibility:public"],
     deps = [
         ":status_helper",
         ":tpu_executor",
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
index fee9d92b42d..240148977e3 100644
--- a/tensorflow/stream_executor/tpu/tpu_platform_interface.h
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
@@ -56,6 +56,10 @@ class TpuPlatformInterface : public stream_executor::Platform {
   virtual const TpuTopologyPtr GetTopologyPtr() = 0;
 
   virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0;
+
+  TpuTopologyExternal topology() {
+    return TpuTopologyExternal(GetTopologyPtr());
+  }
 };
 
 }  // namespace tpu