diff --git a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
index c56b41861b0..f43ec5a9216 100644
--- a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
+++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
@@ -54,9 +54,9 @@ TEST(GpuMultiStream, Basics) {
   device_assignment(0, 0) = device->id();
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtExecutable> executable,
-                          PjRtExecutable::Compile(computation, client.get(),
-                                                  std::move(compile_options)));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<PjRtExecutable> executable,
+      client->Compile(computation, std::move(compile_options)));
 
   int64 dummy_size = 1 << 20;
   std::vector<int32> dummy_inputs(dummy_size);
@@ -71,22 +71,22 @@ TEST(GpuMultiStream, Basics) {
   // must wait.
   TF_ASSERT_OK_AND_ASSIGN(
       auto dummy_buffer,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
          dummy_inputs.data(), dummy_shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer0,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
          inputs.data(), shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer1,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
          inputs.data(), shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
   ExecuteOptions options;
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc
index 02ae37b71db..41afcb01511 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc
@@ -576,24 +576,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
   }
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
-    std::shared_ptr<void> buffer_reference, PjRtClient* client,
-    PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
-  VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString()
+    std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
+  VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
           << " device: " << device->DebugString();
   if (shape.IsTuple()) {
-    return InvalidArgument("Use FromHostLiteral to transfer a tuple");
+    return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
   }
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
   int64 size = ShapeUtil::ByteSizeOf(shape);
 
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
 
@@ -628,10 +625,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
     };
     buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
   } else {
-    void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
+    void* staging_buffer = host_memory_allocator()->AllocateRaw(
        cpu_function_runtime::kMinAlign, size);
-    on_delete_callback = [staging_buffer, client]() {
-      client->host_memory_allocator()->DeallocateRaw(staging_buffer);
+    on_delete_callback = [staging_buffer, host_memory_allocator =
+                                              host_memory_allocator()]() {
+      host_memory_allocator->DeallocateRaw(staging_buffer);
     };
     buffer = se::DeviceMemoryBase(staging_buffer, size);
     std::memcpy(staging_buffer, data, size);
@@ -643,7 +641,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
         std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
         std::move(on_delete_callback));
     return absl::make_unique<PjRtBuffer>(
-        shape, shape, std::move(device_buffer), client, device);
+        shape, shape, std::move(device_buffer), this, device);
    }
  }
 
@@ -651,21 +649,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
                                 local_device->host_to_device_stream(),
-                                /*is_uninitialized_create=*/false, client));
+                                /*is_uninitialized_create=*/false, this));
 
-  ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
+  PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
   // If necessary, allocate a host-side buffer for staging host-to-device
   // transfers. On GPU this is a buffer in pinned memory.
   std::shared_ptr<void> staging_buffer;
   if (host_buffer_semantics ==
           HostBufferSemantics::kImmutableOnlyDuringCall ||
-      client->should_stage_host_to_device_transfers()) {
-    void* ptr = client->host_memory_allocator()->AllocateRaw(
+      should_stage_host_to_device_transfers()) {
+    void* ptr = host_memory_allocator()->AllocateRaw(
        tensorflow::Allocator::kAllocatorAlignment, size);
-    staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
-      client->host_memory_allocator()->DeallocateRaw(ptr);
-    });
+    staging_buffer = std::shared_ptr<void>(
+        ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
+          host_memory_allocator->DeallocateRaw(ptr);
+        });
   }
 
   // Copy the buffer into a staging buffer before returning control to the
@@ -684,14 +683,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device, data, size,
+  auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
+                       data, size,
                        movable_device_buffer{device_buffer.ToClosure()}, shape,
                        py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()},
                        staging_buffer{std::move(staging_buffer)},
                        buffer_reference{std::move(buffer_reference)},
                        host_buffer_semantics]() {
-    ScopedHold device_buffer(movable_device_buffer);
+    PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
     // to report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -699,7 +699,7 @@
     // allocation.
 
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
-        compact_shape, on_device_shape, client->client()->platform());
+        compact_shape, on_device_shape, local_client->platform());
     // If applicable on the backend, stage the transfer via host memory
     // allocated via the host_memory_allocator. On GPU, this is pinned
     // memory.
@@ -736,41 +736,38 @@
     // already defers its work onto a stream (= thread on CPU).
     transfer_h2d();
   } else {
-    client->h2d_transfer_pool()->Schedule(transfer_h2d);
+    h2d_transfer_pool()->Schedule(transfer_h2d);
   }
   return py_buffer;
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
-    const Shape& shape, PjRtClient* client, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized");
-  VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString()
-          << " device: " << device->DebugString();
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
+    const Shape& shape, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtClient::CreateUninitializedBuffer");
+  VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
+          << shape.ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
 
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
 
   return AllocateDestinationBuffer(compact_shape, device, local_device,
                                    /*copy_stream=*/nullptr,
-                                   /*is_uninitialized_create=*/true, client);
+                                   /*is_uninitialized_create=*/true, this);
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
-    const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral");
-  VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: "
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
+    const LiteralSlice& literal, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
+  VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
           << literal.shape().ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(
       Shape compact_shape,
       transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
@@ -778,9 +775,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
                                 local_device->host_to_device_stream(),
-                                /*is_uninitialized_create=*/false, client));
+                                /*is_uninitialized_create=*/false, this));
 
-  ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
+  PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
   // The host to device transfer is performed on a thread pool, mostly because
@@ -789,11 +786,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device,
+  auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        movable_device_buffer{device_buffer.ToClosure()},
                        literal, py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()}]() {
-    ScopedHold device_buffer(movable_device_buffer);
+    PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
     // to report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -802,7 +799,7 @@
     se::Stream* h2d_stream = local_device->host_to_device_stream();
 
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
-        compact_shape, on_device_shape, client->client()->platform());
+        compact_shape, on_device_shape, local_client->platform());
     TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
         h2d_stream, literal, buffer));
 
@@ -817,12 +814,12 @@
         .IgnoreError();  // Can return error::Unimplemented
     QCHECK(h2d_stream->ok());
   };
-  client->h2d_transfer_pool()->Schedule(transfer_h2d);
+  h2d_transfer_pool()->Schedule(transfer_h2d);
   return py_buffer;
 }
 
-/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers(
-    absl::Span<const Shape> shapes, PjRtClient* client, PjRtDevice* device,
+void PjRtClient::MakeCrossHostReceiveBuffers(
+    absl::Span<const Shape> shapes, PjRtDevice* device,
     PjRtCrossHostRecvNotifier&& notifier) {
   if (shapes.empty()) {
     notifier(InvalidArgument(
@@ -843,7 +840,7 @@
     StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
         AllocateDestinationBuffer(shape, device, local_device,
                                   /*copy_stream=*/nullptr,
-                                  /*is_uninitialized_create=*/false, client);
+                                  /*is_uninitialized_create=*/false, this);
     if (!buffer_or.ok()) {
       notifier(buffer_or.status());
       return;
@@ -851,7 +848,7 @@
     buffers.push_back(buffer_or.ConsumeValueOrDie());
   }
 
-  client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
+  EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
 }
 
 PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -1159,7 +1156,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
 
 StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
     const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
+  tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
   TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
                       CopyToHostAsyncInternal(discard_cached_copy, layout));
   if (host_value == nullptr) {
@@ -1267,9 +1264,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
   // Copying across PjRtClients involves a copy through the host.
   if (dst_device->client() != client_) {
     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
-    return FromHostBuffer(literal->untyped_data(), literal->shape(),
-                          HostBufferSemantics::kZeroCopy, nullptr,
-                          dst_device->client(), dst_device);
+    return dst_device->client()->BufferFromHostBuffer(
+        literal->untyped_data(), literal->shape(),
+        PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
   }
 
   TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
@@ -2061,14 +2058,13 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
 
 }  // namespace
 
-/*static*/ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtExecutable::Compile(
-    const XlaComputation& computation, PjRtClient* client,
-    CompileOptions options) {
-  tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
+    const XlaComputation& computation, CompileOptions options) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
 
   ExecutableBuildOptions& build_options = options.executable_build_options;
   if (!build_options.device_allocator()) {
-    build_options.set_device_allocator(client->allocator());
+    build_options.set_device_allocator(allocator());
   }
 
   int num_replicas;
@@ -2084,14 +2080,14 @@
     num_partitions = 1;
   } else {
     if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
+      VLOG(2) << "PjRtClient::Compile using default device_assignment.";
       TF_ASSIGN_OR_RETURN(
           DeviceAssignment device_assignment,
-          client->GetDefaultDeviceAssignment(build_options.num_replicas(),
-                                             build_options.num_partitions()));
+          GetDefaultDeviceAssignment(build_options.num_replicas(),
+                                     build_options.num_partitions()));
       build_options.set_device_assignment(device_assignment);
     }
-    VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
+    VLOG(2) << "PjRtClient::Compile device_assignment:\n"
            << build_options.device_assignment().ToString();
     num_replicas = build_options.device_assignment().replica_count();
     num_partitions = build_options.device_assignment().computation_count();
@@ -2118,7 +2114,8 @@
 
   // Assign a default layout based on `sharded_shape` to any array subshapes in
   // `dst_shape` that are missing layouts.
-  auto assign_layouts = [client](const Shape& sharded_shape, Shape* dst_shape) {
+  auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
+                                                  Shape* dst_shape) {
     return ShapeUtil::ForEachMutableSubshapeWithStatus(
         dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
           if (subshape->IsArray() && !subshape->has_layout()) {
@@ -2126,8 +2123,7 @@
             const Shape& sharded_subshape =
                 ShapeUtil::GetSubshape(sharded_shape, idx);
             LayoutUtil::SetToDefaultLayout(subshape);
-            TF_ASSIGN_OR_RETURN(Shape layout, client->client()
-                                                  ->backend()
+            TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
                                                   .transfer_manager()
                                                   ->ChooseCompactLayoutForShape(
                                                       sharded_subshape));
@@ -2162,8 +2158,8 @@
   for (int replica = 0; replica < num_replicas; ++replica) {
     for (int partition = 0; partition < num_partitions; ++partition) {
       int device_id = (*device_assignment)(replica, partition);
-      PjRtDevice* device = LookupDevice(*client, device_id);
-      if (device->host_id() != client->host_id()) {
+      PjRtDevice* device = LookupDevice(*this, device_id);
+      if (device->host_id() != host_id()) {
        VLOG(3) << "Non-local device: " << device_id;
        continue;
      }
@@ -2185,15 +2181,14 @@
 
   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<LocalExecutable>> local_executables,
-      client->client()->Compile(computation, argument_layout_pointers,
-                                build_options));
+      client()->Compile(computation, argument_layout_pointers, build_options));
 
   auto executable = absl::make_unique<PjRtExecutable>(
       std::move(local_executables), options.parameter_is_tupled_arguments,
       std::move(device_assignment), std::move(local_logical_device_ids),
-      std::move(local_devices), client);
+      std::move(local_devices), this);
   TF_RETURN_IF_ERROR(
-      executable->SetUpDonation(client, options.parameter_is_tupled_arguments));
+      executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
 
   return executable;
 }
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h
index cb4ef9da85b..c10470f7d60 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h
@@ -120,6 +120,24 @@ struct PjRtCrossHostRecvBuffer {
 using PjRtCrossHostRecvNotifier =
     std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
 
+struct CompileOptions {
+  // The layouts of the arguments that the computation should expect.
+  absl::optional<std::vector<Shape>> argument_layouts;
+
+  // If true, the supplied computation expects its arguments to be wrapped in a
+  // tuple and passed as a single parameter.
+  bool parameter_is_tupled_arguments = false;
+
+  // XLA's compilation time options.
+  ExecutableBuildOptions executable_build_options;
+
+  // If true, the executable can be run on any device. May only be true if
+  // !executable_build_options.has_device_assignment(), so only applies to
+  // single-device executables. Beware: on GPUs, sometimes an executable
+  // compiled for one device doesn't run on another.
+  bool compile_portable_executable = false;
+};
+
 class PjRtExecutable;
 
 // Encapsulates the state of Python session with XLA.
@@ -198,6 +216,63 @@ class PjRtClient {
   // Returns a backend-specific HLO cost analysis visitor.
   virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
 
+  virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      const XlaComputation& computation, CompileOptions options);
+
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device);
+
+  // Describes the semantics the caller to BufferFromHostBuffer expects from the
+  // runtime, in a total order from most restrictive to least restrictive.
+  enum class HostBufferSemantics {
+    // The runtime may not hold references to `data` after the call to
+    // `BufferFromHostBuffer` completes. The caller promises that `data` is
+    // immutable and will not be freed only for the duration of the
+    // BufferFromHostBuffer call. `buffer_reference` will be freed by the time
+    // `BufferFromHostBuffer` returns.
+    kImmutableOnlyDuringCall,
+
+    // The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
+    // returns while the runtime completes a transfer to the device. The caller
+    // promises not to mutate or free `data` until the transfer completes, at
+    // which point the runtime will release `buffer_reference`. It is also
+    // correct to wait on the host (directly or indirectly) for the buffer's
+    // definition event to complete.
+    kImmutableUntilTransferCompletes,
+
+    // The PjRtBuffer may alias `data` internally and the runtime may use the
+    // `data` contents as long as the buffer is alive. The caller promises to
+    // keep `data` alive and not to mutate its contents as long as the buffer is
+    // alive; to notify the caller that the buffer may be freed, the runtime
+    // will release its `buffer_reference` when the PjRtBuffer is freed. On
+    // non-CPU platforms this acts identically to
+    // kImmutableUntilTransferCompletes.
+    kZeroCopy,
+  };
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device);
+
+  // Note that literal must remain in scope until the transfer has completed, so
+  // the caller should, for example, wait for BlockHostUntilReady() completes on
+  // the return value before letting literal go out of scope.
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
+      const LiteralSlice& literal, PjRtDevice* device);
+
+  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
+  // cross host transfers using `client` on `device'. `shapes` must be the exact
+  // shapes, with identical layouts, corresponding to the buffers that will be
+  // sent. When resources for the transfer are available, notifier will be
+  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
+  // shape in `shapes`. Each struct contains a buffer that will contain the
+  // received value, and an opaque string that should be transmitted to the
+  // sending host and used in a call to CopyToRemoteDevice. None of the recv
+  // buffers will become ready until *all* of the sends have completed.
+  virtual void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, PjRtDevice* device,
+      PjRtCrossHostRecvNotifier&& notifier);
+
  protected:
   friend class PjRtBuffer;
   virtual void EnqueueCrossHostReceive(
@@ -385,6 +460,7 @@ class PjRtBuffer {
 
    private:
     friend class PjRtBuffer;
+    friend class PjRtClient;
 
     // Helper struct that makes it possible to move a ScopedHold through a
     // closure.
@@ -423,62 +499,6 @@ class PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };
 
-  // Returns a buffer with uninitialized contents.
-  static StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitialized(
-      const Shape& shape, PjRtClient* client, PjRtDevice* device);
-
-  // Describes the semantics the caller to FromHostBuffer expects from the
-  // runtime, in a total order from most restrictive to least restrictive.
-  enum class HostBufferSemantics {
-    // The runtime may not hold references to `data` after the call to
-    // `FromHostBuffer` completes. The caller promises that `data` is immutable
-    // and will not be freed only for the duration of the FromHostBuffer call.
-    // `buffer_reference` will be freed by the time `FromHostBuffer` returns.
-    kImmutableOnlyDuringCall,
-
-    // The runtime may hold onto `data` after the call to `FromHostBuffer`
-    // returns while the runtime completes a transfer to the device. The caller
-    // promises not to mutate or free `data` until the transfer completes, at
-    // which point the runtime will release `buffer_reference`. It is also
-    // correct to wait on the host (directly or indirectly) for the buffer's
-    // definition event to complete.
-    kImmutableUntilTransferCompletes,
-
-    // The PjRtBuffer may alias `data` internally and the runtime may use the
-    // `data` contents as long as the buffer is alive.
-    // The caller promises to keep `data` alive and not to mutate its contents
-    // as long as the buffer is alive; to notify the caller that the buffer may
-    // be freed, the runtime will release its `buffer_reference` when the
-    // PjRtBuffer is freed. On non-CPU platforms this acts identically to
-    // kImmutableUntilTransferCompletes.
-    kZeroCopy,
-  };
-  static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
-      const void* data, const Shape& shape,
-      HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtClient* client,
-      PjRtDevice* device);
-
-  // Note that literal must remain in scope until the transfer has completed, so
-  // the caller should, for example, wait for BlockHostUntilReady() completes on
-  // the return value before letting literal go out of scope.
-  static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostLiteral(
-      const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device);
-
-  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
-  // cross host transfers using `client` on `device'. `shapes` must be the exact
-  // shapes, with identical layouts, corresponding to the buffers that will be
-  // sent. When resources for the transfer are available, notifier will be
-  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
-  // shape in `shapes`. Each struct contains a buffer that will contain the
-  // received value, and an opaque string that should be transmitted to the
-  // sending host and used in a call to CopyToRemoteDevice. None of the recv
-  // buffers will become ready until *all* of the sends have completed.
-  static void MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
-                                          PjRtClient* client,
-                                          PjRtDevice* device,
-                                          PjRtCrossHostRecvNotifier&& notifier);
-
   PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
              std::shared_ptr<TrackedDeviceBuffer> device_buffer,
              PjRtClient* client, PjRtDevice* device);
@@ -661,24 +681,6 @@ class PjRtBuffer {
   Semaphore donation_semaphore_;
 };
 
-struct CompileOptions {
-  // The layouts of the arguments that the computation should expect.
-  absl::optional<std::vector<Shape>> argument_layouts;
-
-  // If true, the supplied computation expects its arguments to be wrapped in a
-  // tuple and passed as a single parameter.
-  bool parameter_is_tupled_arguments = false;
-
-  // XLA's compilation time options.
-  ExecutableBuildOptions executable_build_options;
-
-  // If true, the executable can be run on any device. May only be true if
-  // !executable_build_options.has_device_assignment(), so only applies to
-  // single-device executables. Beware: on GPUs, sometimes an executable
-  // compiled for one device doesn't run on another.
-  bool compile_portable_executable = false;
-};
-
 class ExecuteContext {
  public:
   virtual ~ExecuteContext() = default;
@@ -710,10 +712,6 @@ struct ExecuteOptions {
 // buffer will be donated when passed to the execution.
 class PjRtExecutable {
  public:
-  static StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      const XlaComputation& computation, PjRtClient* client,
-      CompileOptions options);
-
   PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
                  bool parameter_is_tupled_arguments,
                  std::shared_ptr<DeviceAssignment> device_assignment,
@@ -783,6 +781,7 @@ class PjRtExecutable {
   }
 
  private:
+  friend class PjRtClient;
   // Initializes information about which arguments to which executables must be
   // donated due to aliases that were specified by the computation.
   Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index 944b4c20a8a..f4202045a66 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -465,10 +465,10 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
     xla::PjRtDevice* device) {
   CppType data = py::cast<Pybind11Type>(scalar);
   xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
-  return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+  return ValueOrThrow(client->BufferFromHostBuffer(
       &data, shape,
-      xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
-      client, device));
+      xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
+      device));
 }
 
 // Convert a scalar to the associated PjRtBuffer or raises an error if it is
@@ -502,17 +502,17 @@ StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
     if (jax_enable_x64) {
       xla::complex128 data(result.real, result.imag);
       xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
-      return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+      return ValueOrThrow(client->BufferFromHostBuffer(
          &data, shape,
-          xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr, client, device));
+          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          nullptr, device));
     } else {
       xla::complex64 data(result.real, result.imag);
       xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
-      return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+      return ValueOrThrow(client->BufferFromHostBuffer(
          &data, shape,
-          xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr, client, device));
+          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          nullptr, device));
     }
   }
   return InvalidArgument(
@@ -678,7 +678,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
           ValueOrThrow(pyclient.BufferFromPyval(
               numpy_array, data_device,
              /*force_copy=*/false, /*host_buffer_semantics=*/
-              xla::PjRtBuffer::HostBufferSemantics::kZeroCopy));
+              xla::PjRtClient::HostBufferSemantics::kZeroCopy));
       arg_buffers.push_back(buffer->buffer());
 
       ArgSignature sig;
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc
index f6067e650c0..2535d62ee7e 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc
@@ -409,10 +409,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtExecutable> executable,
-      PjRtExecutable::Compile(computation, devices_[device_idx]->client(),
-                              std::move(compile_options)));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+                      devices_[device_idx]->client()->Compile(
+                          computation, std::move(compile_options)));
 
   ExecuteOptions execute_options;
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
                       executable->Execute({}, execute_options));
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
index 919dafe2e0b..5422a4b3056 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
@@ -40,9 +40,8 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id,
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtExecutable> executable,
-      PjRtExecutable::Compile(computation, client, std::move(compile_options)));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+                      client->Compile(computation, std::move(compile_options)));
 
   ExecuteOptions execute_options;
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
                       executable->Execute({}, execute_options));
diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc
index 07b915c640c..d42bbdca154 100644
--- a/tensorflow/compiler/xla/python/py_client.cc
+++ b/tensorflow/compiler/xla/python/py_client.cc
@@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
 
 StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
     const pybind11::object& argument, PjRtDevice* device, bool force_copy,
-    PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
+    PjRtClient::HostBufferSemantics host_buffer_semantics) {
   if (device == nullptr) {
     TF_RET_CHECK(!pjrt_client_->local_devices().empty());
     device = pjrt_client_->local_devices().front();
@@ -114,10 +114,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
   std::unique_ptr<PjRtBuffer> buffer;
   {
     py::gil_scoped_release gil_release;
-    TF_ASSIGN_OR_RETURN(
-        buffer, PjRtBuffer::FromHostBuffer(
-                    c->buf_ptr, c->shape, host_buffer_semantics,
-                    std::move(py_buffer_ref), pjrt_client_.get(), device));
+    TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer(
+                                    c->buf_ptr, c->shape, host_buffer_semantics,
+                                    std::move(py_buffer_ref), device));
   }
   auto traceback = Traceback::Get();
   return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
@@ -131,8 +130,7 @@ StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
   {
     py::gil_scoped_release gil_release;
     TF_ASSIGN_OR_RETURN(executable,
-                        PjRtExecutable::Compile(computation, pjrt_client_.get(),
-                                                std::move(options)));
+                        pjrt_client_->Compile(computation, std::move(options)));
     TF_ASSIGN_OR_RETURN(fingerprint,
                         pjrt_client_->ExecutableFingerprint(*executable));
   }
diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h
index 08249722d6c..224f8278bb1 100644
--- a/tensorflow/compiler/xla/python/py_client.h
+++ b/tensorflow/compiler/xla/python/py_client.h
@@ -123,7 +123,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
 
   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
       const pybind11::object& argument, PjRtDevice* device, bool force_copy,
-      PjRtBuffer::HostBufferSemantics host_buffer_semantics);
+      PjRtClient::HostBufferSemantics host_buffer_semantics);
 
   StatusOr<std::shared_ptr<PyExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options);
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index b84dfa92e47..b0948fab2b7 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -535,12 +535,12 @@ PYBIND11_MODULE(xla_extension, m) {
       .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
       .value("BFC", GpuAllocatorConfig::Kind::kBFC);
 
-  py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
+  py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
       .value("IMMUTABLE_ONLY_DURING_CALL",
-             PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
+             PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
       .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
-             PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
-      .value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
+             PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
+      .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
 
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
@@ -562,7 +562,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
           py::arg("device") = nullptr, py::arg("force_copy") = false,
           py::arg("host_buffer_semantics") =
-               PjRtBuffer::HostBufferSemantics::kZeroCopy)
+               PjRtClient::HostBufferSemantics::kZeroCopy)
       .def("compile", &PyClient::Compile, py::arg("computation"),
           py::arg("compile_options") = CompileOptions())
       .def("heap_profile", &PyClient::HeapProfile);
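
Usage note (illustrative, not part of the patch): the sketch below shows the call pattern after this change, with buffer creation and compilation going through PjRtClient instead of the old static helpers on PjRtBuffer and PjRtExecutable. It assumes the CPU client factory GetCpuClient() from tensorflow/compiler/xla/pjrt/cpu_device.h; the file name, RunExample, and the exact set of includes are hypothetical and may need adjusting to your build target.

// usage_sketch.cc -- hypothetical example; mirrors the new PjRtClient API.
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

Status RunExample() {
  // All factory entry points now hang off PjRtClient.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                      GetCpuClient(/*asynchronous=*/true));
  PjRtDevice* device = client->local_devices().front();

  // client->BufferFromHostBuffer replaces PjRtBuffer::FromHostBuffer. With
  // kImmutableUntilTransferCompletes the host data must stay valid until the
  // transfer finishes.
  std::vector<float> host_data = {1.0f, 2.0f, 3.0f, 4.0f};
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer> buffer,
      client->BufferFromHostBuffer(
          host_data.data(), shape,
          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
          /*buffer_reference=*/nullptr, device));

  // client->Compile replaces PjRtExecutable::Compile.
  XlaBuilder builder("add_one");
  XlaOp param = Parameter(&builder, 0, shape, "x");
  Add(param, ConstantR0<float>(&builder, 1.0f));  // scalar broadcasts
  TF_ASSIGN_OR_RETURN(XlaComputation computation, builder.Build());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                      client->Compile(computation, CompileOptions()));

  ExecuteOptions execute_options;
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> outputs,
                      executable->Execute({buffer.get()}, execute_options));
  TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> result,
                      outputs[0]->ToLiteral());
  return Status::OK();
}

}  // namespace xla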