diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index cdbe69d617e..c01f906fe85 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -140,9 +140,9 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "device_state",
-    srcs = ["device_state.cc"],
-    hdrs = ["device_state.h"],
+    name = "local_device_state",
+    srcs = ["local_device_state.cc"],
+    hdrs = ["local_device_state.h"],
     deps = [
         ":event_pool",
         ":semaphore",
@@ -161,7 +161,7 @@ cc_library(
     srcs = ["local_client.cc"],
     hdrs = ["local_client.h"],
     deps = [
-        ":device_state",
+        ":local_device_state",
         ":shared_device_buffer",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc
index ef8ff4275a6..237f10c39ae 100644
--- a/tensorflow/compiler/xla/python/local_client.cc
+++ b/tensorflow/compiler/xla/python/local_client.cc
@@ -105,6 +105,13 @@ limitations under the License.
 
 namespace xla {
 
+StatusOr<LocalDeviceState*> Device::GetLocalDeviceState() const {
+  if (local_device_state_) {
+    return local_device_state_.get();
+  }
+  return InvalidArgument("Device %s is not a local device.", DebugString());
+}
+
 std::string CpuDevice::DebugString() const {
   return absl::StrCat("CPU_", id());
 }
@@ -115,7 +122,7 @@ std::string GpuDevice::DebugString() const {
 
 static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
     se::Platform* platform,
-    absl::Span<const std::unique_ptr<DeviceState>> device_states,
+    absl::Span<const std::shared_ptr<Device>> local_devices,
     LocalClient* client, double memory_fraction, bool preallocate) {
   CHECK_GT(client->backend().device_count(), 0);
   std::vector<se::MultiDeviceAdapter::AllocatorWithStream> allocators;
@@ -148,19 +155,24 @@ static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
         /*allow_growth=*/!preallocate,
         absl::StrCat("GPU_", device_ordinal, "_bfc"));
     allocators.emplace_back(std::move(gpu_bfc_allocator),
-                            device_states.at(device_ordinal)->compute_stream());
+                            local_devices.at(device_ordinal)
+                                ->local_device_state()
+                                ->compute_stream());
   }
   return absl::make_unique<se::MultiDeviceAdapter>(platform,
                                                    std::move(allocators));
 }
 
-static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
-                                          int id, int local_device_ordinal) {
+static std::shared_ptr<Device> MakeDevice(
+    const std::string& platform_name, int id,
+    std::unique_ptr<LocalDeviceState> local_device_state) {
   if (platform_name == "cpu") {
-    return std::make_shared<CpuDevice>(id, local_device_ordinal, platform_name);
+    return std::make_shared<CpuDevice>(id, std::move(local_device_state),
+                                       platform_name);
   } else {
     CHECK_EQ(platform_name, "gpu");
-    return std::make_shared<GpuDevice>(id, local_device_ordinal, platform_name);
+    return std::make_shared<GpuDevice>(id, std::move(local_device_state),
+                                       platform_name);
   }
 }
 
@@ -179,16 +191,15 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
                       ClientLibrary::GetOrCreateLocalClient(options));
 
   bool gpu_platform = platform_name == "gpu";
-  std::vector<std::unique_ptr<DeviceState>> device_states;
   std::vector<std::shared_ptr<Device>> devices;
   bool synchronous_deallocation = platform_name == "cpu";
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutor* executor =
         client->backend().stream_executor(i).ValueOrDie();
-    device_states.push_back(absl::make_unique<DeviceState>(
+    auto device_state = absl::make_unique<LocalDeviceState>(
         executor, synchronous_deallocation, asynchronous,
-        /*allow_event_reuse=*/gpu_platform));
-    devices.push_back(MakeDevice(platform_name, i, i));
+        /*allow_event_reuse=*/gpu_platform);
+    devices.push_back(MakeDevice(platform_name, i, std::move(device_state)));
   }
 
   std::unique_ptr<se::DeviceMemoryAllocator> allocator;
@@ -196,7 +207,7 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
   if (gpu_platform) {
     if (allocator_config.kind != AllocatorConfig::Kind::kPlatform) {
       TF_ASSIGN_OR_RETURN(allocator,
-                          CreateBFCAllocator(platform, device_states, client,
+                          CreateBFCAllocator(platform, devices, client,
                                              allocator_config.memory_fraction,
                                              allocator_config.preallocate));
     }
@@ -217,21 +228,18 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
 
   return std::make_shared<PyLocalClient>(
       platform_name, client, std::move(devices), /*host_id=*/0,
-      std::move(device_states), std::move(allocator),
-      std::move(host_memory_allocator));
+      std::move(allocator), std::move(host_memory_allocator));
 }
 
 PyLocalClient::PyLocalClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::shared_ptr<Device>> devices, int host_id,
-    std::vector<std::unique_ptr<DeviceState>> device_states,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator)
     : platform_name_(std::move(platform_name)),
       client_(client),
       devices_(std::move(devices)),
       host_id_(host_id),
-      device_states_(std::move(device_states)),
       owned_allocator_(std::move(allocator)),
       host_memory_allocator_(std::move(host_memory_allocator)),
       h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
@@ -242,15 +250,16 @@ PyLocalClient::PyLocalClient(
     allocator_ = client_->backend().memory_allocator();
   }
 
-  local_devices_.resize(device_states_.size());
   for (const std::shared_ptr<Device>& device : devices_) {
     CHECK(id_to_device_.insert({device->id(), device}).second)
         << "Duplicate device id: " << device->id();
 
-    if (device->local_device_ordinal() != -1) {
-      int idx = device->local_device_ordinal();
+    if (device->local_device_state()) {
+      int idx = device->local_device_state()->device_ordinal();
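+      // Grow local_devices_ on demand; devices may be visited in any order,
+      // so the vector is sized to fit the largest ordinal seen so far.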
+      if (idx >= local_devices_.size()) {
+        local_devices_.resize(idx + 1);
+      }
       CHECK(local_devices_[idx] == nullptr) << idx;
-      CHECK_LT(idx, local_devices_.size());
       local_devices_[idx] = device;
     }
   }
@@ -274,17 +283,19 @@ PyLocalClient::DeserializeExecutable(
 }
 
 Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
-                                       int device_ordinal) {
-  TF_RETURN_IF_ERROR(
-      CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferToInfeed"));
-  return client_->TransferToInfeedLocal(literal, device_ordinal);
+                                       std::shared_ptr<Device> device) {
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
+  return client_->TransferToInfeedLocal(literal,
+                                        local_device->device_ordinal());
 }
 
-StatusOr<Literal> PyLocalClient::TransferFromOutfeed(const Shape& shape,
-                                                     int device_ordinal) {
-  TF_RETURN_IF_ERROR(
-      CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferFromOutfeed"));
-  return client_->TransferFromOutfeedLocal(shape, device_ordinal);
+StatusOr<Literal> PyLocalClient::TransferFromOutfeed(
+    const Shape& shape, std::shared_ptr<Device> device) {
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
+  return client_->TransferFromOutfeedLocal(shape,
+                                           local_device->device_ordinal());
 }
 
 StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
@@ -293,36 +304,26 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
       num_replicas, /*computation_count=*/1);
 }
 
-Status PyLocalClient::CheckDeviceOrdinal(int device_ordinal,
-                                         absl::string_view caller_name) {
-  if (device_ordinal < 0 || device_ordinal >= local_device_count()) {
-    return InvalidArgument(
-        "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name,
-        device_ordinal, local_device_count());
-  }
-  return Status::OK();
-}
-
 /* static */
 StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
     std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
     std::shared_ptr<void> leaves_reference,
-    std::shared_ptr<PyLocalClient> client, int device_ordinal) {
+    std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
   VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
-          << " device ordinal: " << device_ordinal;
-  TF_RETURN_IF_ERROR(client->CheckDeviceOrdinal(device_ordinal,
-                                                "PyLocalBuffer::FromLiterals"));
-  DeviceState* device = &client->device_state(device_ordinal);
+          << " device: " << device->DebugString();
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
   se::DeviceMemoryAllocator* allocator = client->allocator();
   TF_ASSIGN_OR_RETURN(
       Shape compact_shape,
       transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer,
-                      transfer_manager->AllocateScopedShapedBuffer(
-                          compact_shape, allocator, device_ordinal));
+  TF_ASSIGN_OR_RETURN(
+      ScopedShapedBuffer scoped_buffer,
+      transfer_manager->AllocateScopedShapedBuffer(
+          compact_shape, allocator, local_device->device_ordinal()));
 
   // Make the host to device stream wait for the newly allocated buffer to be
   // available on the compute stream. We schedule this wait synchronously; while
@@ -331,8 +332,9 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
   // computations that depend on this transfer being enqueued on the compute
   // stream.
   if (!transfer_manager->CanShapedBufferBeAccessedNow(
-          device->host_to_device_stream()->parent(), scoped_buffer)) {
-    device->host_to_device_stream()->ThenWaitFor(device->compute_stream());
+          local_device->host_to_device_stream()->parent(), scoped_buffer)) {
+    local_device->host_to_device_stream()->ThenWaitFor(
+        local_device->compute_stream());
   }
 
   std::shared_ptr<BufferDefinitionEvent> definition_event =
@@ -344,16 +346,15 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
   // TODO(makro): Use move capture once C++ 14 features are available.
   auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
       std::move(leaves_literals));
-  auto transfer_h2d = [client, transfer_manager, device, device_ordinal,
-                       device_buffer, compact_shape, leaves,
-                       leaves_reference]() {
+  auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
+                       compact_shape, leaves, leaves_reference]() {
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
     // report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
     // memory that has already been allocated, and a possible Event allocation.
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape);
     TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
-        device->host_to_device_stream(), buffer));
+        local_device->host_to_device_stream(), buffer));
     std::vector<std::shared_ptr<void>> staging_buffers;
     staging_buffers.reserve(leaves->size());
     auto it = leaves->begin();
@@ -363,7 +364,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
       ShapedBuffer leaf(
           indexed_shape.shape,
           transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
-          client->client()->platform(), device_ordinal);
+          client->client()->platform(), local_device->device_ordinal());
       leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
 
       // If applicable on the backend, stage the transfer via host memory
@@ -379,51 +380,53 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
         BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
                                  it->shape());
         TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
-            device->host_to_device_stream(), literal, leaf));
+            local_device->host_to_device_stream(), literal, leaf));
         staging_buffers.push_back(std::move(staging_buffer));
       } else {
         // Otherwise, just transfer the literal.
         TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
-            device->host_to_device_stream(), *it, leaf));
+            local_device->host_to_device_stream(), *it, leaf));
       }
       ++it;
     }
 
     EventPool::Handle event =
-        device->event_pool()
-            .ThenAllocateAndRecordEvent(device->host_to_device_stream())
+        local_device->event_pool()
+            .ThenAllocateAndRecordEvent(local_device->host_to_device_stream())
             .ValueOrDie();
 
     // Sets the buffer definition event. Note: this has the side effect of
     // unblocking any host threads that may have been waiting to consume the
     // buffer.
     device_buffer->definition_event()->SetDefinitionEvent(
-        std::move(event), device->host_to_device_stream());
+        std::move(event), local_device->host_to_device_stream());
 
-    if (device->synchronous_deallocation()) {
-      device->ThenRelease(device->host_to_device_stream(), device_buffer);
+    if (local_device->synchronous_deallocation()) {
+      local_device->ThenRelease(local_device->host_to_device_stream(),
+                                device_buffer);
     }
 
-    device->ThenRelease(
-        device->host_to_device_stream(),
+    local_device->ThenRelease(
+        local_device->host_to_device_stream(),
         std::make_pair(leaves_reference, std::move(staging_buffers)));
   };
   client->h2d_transfer_pool()->Schedule(transfer_h2d);
-  return absl::make_unique<PyLocalBuffer>(
-      compact_shape, std::move(device_buffer), std::move(client));
+  return absl::make_unique<PyLocalBuffer>(compact_shape,
+                                          std::move(device_buffer),
+                                          std::move(client), std::move(device));
 }
 
 /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
     const std::vector<PyLocalBuffer*> buffers,
-    std::shared_ptr<PyLocalClient> client, int device_ordinal) {
-  TF_RETURN_IF_ERROR(
-      client->CheckDeviceOrdinal(device_ordinal, "PyLocalBuffer::MakeTuple"));
+    std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
   std::vector<Shape> host_shapes;
   std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers;
   host_shapes.reserve(buffers.size());
   device_buffers.reserve(buffers.size());
   for (const PyLocalBuffer* buffer : buffers) {
-    TF_RET_CHECK(buffer->device_ordinal() == device_ordinal);
+    TF_RET_CHECK(buffer->device().get() == device.get());
     std::shared_ptr<SharedDeviceBuffer> device_buffer = buffer->DeviceBuffer();
     if (!device_buffer) {
       return InvalidArgument(
@@ -436,45 +439,48 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
   se::DeviceMemoryAllocator* allocator = client->allocator();
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
-  DeviceState& device = client->device_state(device_ordinal);
 
   auto definition_event = std::make_shared<BufferDefinitionEvent>();
-  TF_ASSIGN_OR_RETURN(
-      std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
-      SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator,
-                                    device_ordinal, definition_event));
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
+                      SharedDeviceBuffer::MakeTuple(
+                          device_buffers, transfer_manager, allocator,
+                          local_device->device_ordinal(), definition_event));
   auto buffer = absl::make_unique<PyLocalBuffer>(
-      ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client));
+      ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client),
+      std::move(device));
 
   // TODO(phawkins): extend TransferManager so we do not need to form a full
   // ShapedBuffer just to write the root tuple index table.
   TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer());
   if (!transfer_manager->CanShapedBufferBeAccessedNow(
-          device.host_to_device_stream()->parent(), shaped_buffer)) {
+          local_device->host_to_device_stream()->parent(), shaped_buffer)) {
     // Wait for the compute stream so that memory allocations are synchronized.
-    device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
+    local_device->host_to_device_stream()->ThenWaitFor(
+        local_device->compute_stream());
   }
   TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
-      device.host_to_device_stream(), shaped_buffer));
+      local_device->host_to_device_stream(), shaped_buffer));
 
   TF_ASSIGN_OR_RETURN(EventPool::Handle event,
-                      device.event_pool().ThenAllocateAndRecordEvent(
-                          device.host_to_device_stream()));
+                      local_device->event_pool().ThenAllocateAndRecordEvent(
+                          local_device->host_to_device_stream()));
   definition_event->SetDefinitionEvent(std::move(event),
-                                       device.host_to_device_stream());
+                                       local_device->host_to_device_stream());
 
-  if (device.synchronous_deallocation()) {
-    device.ThenRelease(device.host_to_device_stream(), std::move(tuple_buffer));
+  if (local_device->synchronous_deallocation()) {
+    local_device->ThenRelease(local_device->host_to_device_stream(),
+                              std::move(tuple_buffer));
   }
   return buffer;
 }
 
 PyLocalBuffer::PyLocalBuffer(Shape on_host_shape,
                              std::shared_ptr<SharedDeviceBuffer> device_buffer,
-                             std::shared_ptr<PyLocalClient> client)
+                             std::shared_ptr<PyLocalClient> client,
+                             std::shared_ptr<Device> device)
     : client_(std::move(client)),
       on_host_shape_(std::move(on_host_shape)),
-      device_ordinal_(device_buffer->device_ordinal()),
+      device_(std::move(device)),
       device_buffer_(std::move(device_buffer)) {}
 
 void PyLocalBuffer::Delete() {
@@ -499,8 +505,7 @@ Status PyLocalBuffer::CopyToHostAsync() {
     }
     host_value = host_value_ = std::make_shared<HostValue>();
   }
-  se::Stream* stream =
-      client_->device_state(device_ordinal_).device_to_host_stream();
+  se::Stream* stream = device_->local_device_state()->device_to_host_stream();
   WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
   host_value->value = std::make_shared<Literal>(on_host_shape_);
   TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, AsShapedBuffer());
@@ -564,36 +569,38 @@ PyLocalBuffer::DestructureTuple() {
   for (int64 i = 0; i < num_children; ++i) {
     results.push_back(absl::make_unique<PyLocalBuffer>(
         on_host_shape_.tuple_shapes(i), device_buffer_->children().at(i),
-        client_));
+        client_, device_));
   }
   return results;
 }
 
 StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
-    int dst_device_ordinal) {
+    std::shared_ptr<Device> dst_device) {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
   std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer();
-  if (dst_device_ordinal == device_ordinal_) {
-    return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
-                                            client_);
-  }
-  int transfer_device_ordinal = client_->EnqueueD2DTransfersOnSrcStream()
-                                    ? device_ordinal_
-                                    : dst_device_ordinal;
-  DeviceState& transfer_device = client_->device_state(transfer_device_ordinal);
-  const DeviceState& dst_device = client_->device_state(dst_device_ordinal);
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
+                      dst_device->GetLocalDeviceState());
 
-  se::Stream* transfer_stream = transfer_device.GetDeviceToDeviceStream();
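+  // Copying a buffer to its own device just aliases the underlying device
+  // buffer; no data movement takes place.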
+  if (dst_device.get() == device_.get()) {
+    return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
+                                            client_, device_);
+  }
+  LocalDeviceState* transfer_local_device =
+      client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
+                                                : dst_local_device;
+
+  se::Stream* transfer_stream =
+      transfer_local_device->GetDeviceToDeviceStream();
 
   TransferManager* transfer_manager =
       client_->client()->backend().transfer_manager();
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer dst_buffer,
-      transfer_manager->AllocateScopedShapedBuffer(
-          on_host_shape_, client_->allocator(), dst_device_ordinal));
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
+                      transfer_manager->AllocateScopedShapedBuffer(
+                          on_host_shape_, client_->allocator(),
+                          dst_local_device->device_ordinal()));
   if (!transfer_manager->CanShapedBufferBeAccessedNow(
-          dst_device.compute_stream()->parent(), dst_buffer)) {
-    transfer_stream->ThenWaitFor(dst_device.compute_stream());
+          dst_local_device->compute_stream()->parent(), dst_buffer)) {
+    transfer_stream->ThenWaitFor(dst_local_device->compute_stream());
   }
   TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
 
@@ -607,37 +614,39 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
     TF_RET_CHECK(input_buffer.size() == output_buffer.size())
         << "input: " << input_buffer.size()
         << " output: " << output_buffer.size();
-    TF_RETURN_IF_ERROR(transfer_device.ThenMemcpyDeviceToDevice(
-        transfer_stream, dst_device.compute_stream(), input_buffer,
+    TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
+        transfer_stream, dst_local_device->compute_stream(), input_buffer,
         output_buffer));
   }
 
   // We hold on to the `src_device_buffer` until the transfer is finished.
-  transfer_device.ThenRelease(transfer_stream, std::move(src_device_buffer));
+  transfer_local_device->ThenRelease(transfer_stream,
+                                     std::move(src_device_buffer));
 
   // Write new tuple buffers. The destination buffers have different addresses,
   // so we must construct tuple buffers from scratch instead of copying them.
   if (dst_buffer.on_device_shape().IsTuple()) {
     TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
-        dst_device.host_to_device_stream(), dst_buffer));
+        dst_local_device->host_to_device_stream(), dst_buffer));
 
     // We need a single definition event, so make the device to device stream
     // wait for the stream that wrote the tuple index tables on the destination
     // device.
-    transfer_stream->ThenWaitFor(dst_device.host_to_device_stream());
+    transfer_stream->ThenWaitFor(dst_local_device->host_to_device_stream());
   }
 
   auto definition_event = std::make_shared<BufferDefinitionEvent>();
   TF_ASSIGN_OR_RETURN(
       EventPool::Handle event,
-      transfer_device.event_pool().ThenAllocateAndRecordEvent(transfer_stream));
+      transfer_local_device->event_pool().ThenAllocateAndRecordEvent(
+          transfer_stream));
   definition_event->SetDefinitionEvent(std::move(event), transfer_stream);
 
   std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
       SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
                                                  definition_event);
   return absl::make_unique<PyLocalBuffer>(
-      on_host_shape_, std::move(dst_device_buffer), client_);
+      on_host_shape_, std::move(dst_device_buffer), client_, dst_device);
 }
 
 Status PyLocalBuffer::BlockHostUntilReady() {
@@ -694,7 +703,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
   const int device_id = (*device_assignment_)(replica, 0);
   std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
   CHECK_EQ(device->host_id(), client_->host_id());
-  int device_ordinal = device->local_device_ordinal();
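+  // The host_id CHECK above implies the device is local, so
+  // local_device_state() is expected to be non-null here.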
+  int device_ordinal = device->local_device_state()->device_ordinal();
   tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
   VLOG(3) << "Replica " << replica
           << " mapped to device ordinal for execution: " << device_ordinal;
@@ -729,7 +738,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
             << " buffer: " << argument_buffers.back().ToString();
   }
 
-  DeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
   // The choice of where we wait is arbitrary; the reason for the wait is pacing
   // to avoid problems such as memory fragmentation and running ahead too far,
   // not for correctness. Placing it before the executable launch allows the
@@ -782,7 +791,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
       device_state->compute_stream(),
       std::make_tuple(executable_, compute_reservation, device_assignment_));
   return absl::make_unique<PyLocalBuffer>(on_host_shape, std::move(out_buffer),
-                                          client_);
+                                          client_, device);
 }
 
 StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::Execute(
@@ -833,8 +842,7 @@ PyLocalExecutable::ExecutePerReplica(
     for (int i = 0; i < num_local_replicas; ++i) {
       const int replica = local_replicas_[i];
       std::shared_ptr<Device> device = local_devices_[i];
-      const DeviceState& device_state =
-          client_->device_state(device->local_device_ordinal());
+      const LocalDeviceState& device_state = *device->local_device_state();
       device_state.execute_thread()->Schedule([&, replica, i] {
         results[i] = ExecuteHelper(argument_handles[i], replica, run_id);
 
diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h
index 3f13f62241f..e0a21ad6f1e 100644
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -27,7 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
-#include "tensorflow/compiler/xla/python/device_state.h"
+#include "tensorflow/compiler/xla/python/local_device_state.h"
 #include "tensorflow/compiler/xla/python/shared_device_buffer.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -43,10 +43,10 @@ class PyLocalExecutable;
 
 class Device {
  public:
-  explicit Device(int id, int local_device_ordinal,
+  explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
                   absl::string_view platform_name, int host_id = 0)
       : id_(id),
-        local_device_ordinal_(local_device_ordinal),
+        local_device_state_(std::move(local_device_state)),
         host_id_(host_id),
         platform_name_(platform_name) {}
   virtual ~Device() {}
@@ -56,13 +56,17 @@ class Device {
   // hosts' devices.  This is the ID that should be used in a DeviceAssignment.
   int id() const { return id_; }
 
-  // If this is a device local to this host, the local index of this device as
-  // according to the underlying backend. Unlike id(), this will always be in
-  // the range [0, num_local_devices), and can be used with the xla::LocalClient
-  // and xla::Backend APIs.
-  //
-  // -1 if this device is not local to this host.
-  int local_device_ordinal() const { return local_device_ordinal_; }
+  // If this is a device local to this host, returns a LocalDeviceState object
+  // that can be used to manipulate the device. Returns nullptr if the device is
+  // not local to this host.
+  LocalDeviceState* local_device_state() const {
+    return local_device_state_.get();
+  }
+
+  // If this is a device local to this host, returns a LocalDeviceState object
+  // that can be used to manipulate the device. Returns an error if the device
+  // is not local to this host.
+  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
 
   // The ID of this device's host. This is always 0 on single-host platforms.
   int host_id() const { return host_id_; }
@@ -73,7 +77,7 @@ class Device {
 
  private:
   const int id_;
-  const int local_device_ordinal_;
+  const std::unique_ptr<LocalDeviceState> local_device_state_;
   const int host_id_;
   const std::string platform_name_;
 };
@@ -123,13 +127,14 @@ class PyLocalClient {
   explicit PyLocalClient(
       std::string platform_name, LocalClient* client,
       std::vector<std::shared_ptr<Device>> devices, int host_id,
-      std::vector<std::unique_ptr<DeviceState>> device_states,
       std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator);
   virtual ~PyLocalClient() = default;
 
-  Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
-  StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
+  Status TransferToInfeed(const LiteralSlice& literal,
+                          std::shared_ptr<Device> device);
+  StatusOr<Literal> TransferFromOutfeed(const Shape& shape,
+                                        std::shared_ptr<Device> device);
 
   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas) const;
@@ -146,8 +151,8 @@ class PyLocalClient {
   int host_id() const { return host_id_; }
   const std::string& platform_name() const { return platform_name_; }
 
-  DeviceState& device_state(int device_ordinal) const {
-    return *device_states_.at(device_ordinal);
+  LocalDeviceState& device_state(int device_ordinal) const {
+    return *local_devices_.at(device_ordinal)->local_device_state();
   }
 
   LocalClient* client() const { return client_; }
@@ -178,10 +183,6 @@ class PyLocalClient {
       const std::string& serialized,
       std::shared_ptr<PyLocalClient> this_shared) const;
 
-  // Returns a bad status containing `caller_name` if `device_ordinal` doesn't
-  // correspond to a local device.
-  Status CheckDeviceOrdinal(int device_ordinal, absl::string_view caller_name);
-
  protected:
   std::string platform_name_;
   LocalClient* client_;
@@ -194,8 +195,6 @@ class PyLocalClient {
   std::vector<std::shared_ptr<Device>> local_devices_;
   int host_id_;
 
-  // Device states local to this host. Indexed by local device ordinal.
-  std::vector<std::unique_ptr<DeviceState>> device_states_;
   se::DeviceMemoryAllocator* allocator_;
   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
 
@@ -219,16 +218,16 @@ class PyLocalBuffer {
   static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
       std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
       std::shared_ptr<void> leaves_reference,
-      std::shared_ptr<PyLocalClient> client, int device_ordinal);
+      std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
 
   static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
       const std::vector<PyLocalBuffer*> buffers,
-      std::shared_ptr<PyLocalClient> client, int device_ordinal);
+      std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
 
-  PyLocalBuffer() = default;
   PyLocalBuffer(Shape on_host_shape,
                 std::shared_ptr<SharedDeviceBuffer> device_buffer,
-                std::shared_ptr<PyLocalClient> client);
+                std::shared_ptr<PyLocalClient> client,
+                std::shared_ptr<Device> device);
 
   PyLocalBuffer(const PyLocalBuffer&) = delete;
   PyLocalBuffer(PyLocalBuffer&&) = delete;
@@ -236,7 +235,7 @@ class PyLocalBuffer {
   PyLocalBuffer& operator=(PyLocalBuffer&&) = delete;
 
   const Shape& on_host_shape() const { return on_host_shape_; }
-  int device_ordinal() const { return device_ordinal_; }
+  std::shared_ptr<Device> device() const { return device_; }
   const std::string& platform_name() const { return client_->platform_name(); }
   std::shared_ptr<PyLocalClient> client() const { return client_; }
 
@@ -266,8 +265,9 @@ class PyLocalBuffer {
   // Destructures a tuple-valued PyLocalBuffer into its constituent elements.
   StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> DestructureTuple();
 
-  // Copies the buffer to device `dst_device_ordinal`.
-  StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(int dst_device_ordinal);
+  // Copies the buffer to device `dst_device`.
+  StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(
+      std::shared_ptr<Device> dst_device);
 
   // Blocks the host until the buffer's value has been computed and is ready for
   // immediate use on the device. Useful in particular for timing benchmarks.
@@ -276,7 +276,7 @@ class PyLocalBuffer {
  private:
   const std::shared_ptr<PyLocalClient> client_;
   const Shape on_host_shape_;
-  const int device_ordinal_;
+  const std::shared_ptr<Device> device_;
   mutable absl::Mutex mu_;
   std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_);
 
diff --git a/tensorflow/compiler/xla/python/device_state.cc b/tensorflow/compiler/xla/python/local_device_state.cc
similarity index 81%
rename from tensorflow/compiler/xla/python/device_state.cc
rename to tensorflow/compiler/xla/python/local_device_state.cc
index 3403d882e92..6b8d09d4ffa 100644
--- a/tensorflow/compiler/xla/python/device_state.cc
+++ b/tensorflow/compiler/xla/python/local_device_state.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/python/device_state.h"
+#include "tensorflow/compiler/xla/python/local_device_state.h"
 
 #include <memory>
 #include <vector>
@@ -24,12 +24,13 @@ limitations under the License.
 
 namespace xla {
 
-DeviceState::DeviceState(se::StreamExecutor* executor,
-                         bool synchronous_deallocation, bool asynchronous,
-                         bool allow_event_reuse)
+LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
+                                   bool synchronous_deallocation,
+                                   bool asynchronous, bool allow_event_reuse)
     : synchronous_deallocation_(synchronous_deallocation),
       event_pool_(allow_event_reuse),
-      compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1) {
+      compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1),
+      executor_(executor) {
   compute_stream_ = absl::make_unique<se::Stream>(executor);
   host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
   device_to_host_stream_ = absl::make_unique<se::Stream>(executor);
@@ -50,14 +51,14 @@ DeviceState::DeviceState(se::StreamExecutor* executor,
                                                      "py_xla_callback");
 }
 
-DeviceState::~DeviceState() {
+LocalDeviceState::~LocalDeviceState() {
   Status status = SynchronizeAllActivity();
   if (!status.ok()) {
     LOG(ERROR) << "Error when closing device: " << status;
   }
 }
 
-Status DeviceState::SynchronizeAllActivity() {
+Status LocalDeviceState::SynchronizeAllActivity() {
   Status status;
   // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
   // suffice. However on the Host platform SynchronizeAllActivity is a dummy
@@ -73,10 +74,9 @@ Status DeviceState::SynchronizeAllActivity() {
   return status;
 }
 
-Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
-                                             se::Stream* dst_stream,
-                                             se::DeviceMemoryBase src_buffer,
-                                             se::DeviceMemoryBase dst_buffer) {
+Status LocalDeviceState::ThenMemcpyDeviceToDevice(
+    se::Stream* transfer_stream, se::Stream* dst_stream,
+    se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
   // The default implementation simply calls ThenMemcpyD2D, and assumes that
   // the buffer addresses identify the devices. This does not work
   // on all platforms; this method is virtual so it can be overridden.
@@ -84,14 +84,14 @@ Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
   return Status::OK();
 }
 
-void DeviceState::ThenExecuteOnCallbackThread(
+void LocalDeviceState::ThenExecuteOnCallbackThread(
     se::Stream* stream, std::function<void()> callback) const {
   stream->ThenDoHostCallback([this, callback]() mutable {
     callback_thread_->Schedule(std::move(callback));
   });
 }
 
-se::Stream* DeviceState::GetDeviceToDeviceStream() {
+se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
   absl::MutexLock lock(&mu_);
   int i = next_device_to_device_stream_;
   next_device_to_device_stream_ =
diff --git a/tensorflow/compiler/xla/python/device_state.h b/tensorflow/compiler/xla/python/local_device_state.h
similarity index 88%
rename from tensorflow/compiler/xla/python/device_state.h
rename to tensorflow/compiler/xla/python/local_device_state.h
index 3772c03fc59..fe9b9bd61b3 100644
--- a/tensorflow/compiler/xla/python/device_state.h
+++ b/tensorflow/compiler/xla/python/local_device_state.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
-#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
 
 #include <memory>
 #include <vector>
@@ -29,9 +29,9 @@ limitations under the License.
 namespace xla {
 
 // Class that encapsulates state relating to a device (e.g., a GPU) on which we
-// can perform computation and transfers. DeviceState objects only exist for
-// devices local to this host.
-class DeviceState {
+// can perform computation and transfers. LocalDeviceState objects only exist
+// for devices local to this host.
+class LocalDeviceState {
  public:
   // If synchronous_deallocation is true, the host must not free buffers until
   // compute/transfers that use those buffers have completed. For example, this
@@ -40,9 +40,12 @@ class DeviceState {
   //
   // If asynchronous is false, the host will synchronize to the device after
   // each execution or transfer. This is intended for debugging only.
-  DeviceState(se::StreamExecutor* executor, bool synchronous_deallocation,
-              bool asynchronous, bool allow_event_reuse);
-  virtual ~DeviceState();
+  LocalDeviceState(se::StreamExecutor* executor, bool synchronous_deallocation,
+                   bool asynchronous, bool allow_event_reuse);
+  virtual ~LocalDeviceState();
+
+  // StreamExecutor (local) device ordinal.
+  int device_ordinal() const { return executor_->device_ordinal(); }
 
   bool synchronous_deallocation() const { return synchronous_deallocation_; }
 
@@ -104,6 +107,7 @@ class DeviceState {
   // stream by the host ahead of the device.
   Semaphore compute_semaphore_;
 
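+  // Kept so that device_ordinal() can be derived from the executor.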
+  se::StreamExecutor* executor_;
   std::unique_ptr<se::Stream> compute_stream_;
   std::unique_ptr<se::Stream> host_to_device_stream_;
   std::unique_ptr<se::Stream> device_to_host_stream_;
@@ -132,4 +136,4 @@ class DeviceState {
 
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
index d5d492de054..13e0d147e86 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
@@ -19,7 +19,6 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:executable_build_options",
-        "//tensorflow/compiler/xla/python:device_state",
         "//tensorflow/compiler/xla/python:local_client",
         "//tensorflow/compiler/xla/python:semaphore",
         "//tensorflow/compiler/xla/python/tpu_driver",
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
index b9ca2a7e1a7..f0c93772ffe 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
@@ -39,10 +39,9 @@ std::string TpuDevice::DebugString() const {
 }
 
 static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
-                                          int id, int local_device_ordinal) {
+                                          int id) {
   CHECK_EQ(platform_name, "tpu");
-  CHECK_EQ(id, local_device_ordinal);  // Every device must be local for now.
-  return std::make_shared<TpuDevice>(id, local_device_ordinal, "tpu");
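+  // TPU devices do not use a LocalDeviceState, so none is attached.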
+  return std::make_shared<TpuDevice>(id, /*local_device_state=*/nullptr, "tpu");
 }
 
 StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
@@ -67,7 +66,7 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
   LOG(INFO) << "Creating " << num_cores << " TPU device(s).";
   devices.reserve(num_cores);
   for (int i = 0; i < num_cores; ++i) {
-    devices.push_back(MakeDevice("tpu", i, i));
+    devices.push_back(MakeDevice("tpu", i));
   }
 
   return std::make_shared<PyTpuClient>("tpu", std::move(client),
@@ -87,8 +86,8 @@ PyTpuClient::PyTpuClient(std::string platform_name,
     CHECK(id_to_device_.insert({device->id(), device}).second)
         << "Duplicate device id: " << device->id();
 
-    if (device->local_device_ordinal() != -1) {
-      int idx = device->local_device_ordinal();
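+    // Every TPU device is local for now, so the device id doubles as the
+    // local device index.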
+    if (device->id() != -1) {
+      int idx = device->id();
       CHECK(local_devices_[idx] == nullptr) << idx;
       CHECK_LT(idx, local_devices_.size());
       local_devices_[idx] = device;
@@ -509,7 +508,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
   const int device_id = device_assignment_(replica, 0);
   std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
   CHECK_EQ(device->host_id(), client_->host_id());
-  int device_ordinal = device->local_device_ordinal();
+  int device_ordinal = device->id();
   tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
   VLOG(3) << "Replica " << replica
           << " mapped to device ordinal for execution: " << device_ordinal;
@@ -742,7 +741,7 @@ PyTpuExecutable::ExecutePerReplica(
     const int device_id = (*device_assignment)(replica, 0);
     std::shared_ptr<Device> device = LookupDevice(*client, device_id);
     CHECK_EQ(device->host_id(), client->host_id());
-    int device_ordinal = device->local_device_ordinal();
+    int device_ordinal = device->id();
     loaded_programs[replica] = client->driver()->LoadProgram(
         device_ordinal, compiled_program.get(), {});
   }
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
index 7624a14943f..49d4182b719 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
@@ -24,7 +24,6 @@ limitations under the License.
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
-#include "tensorflow/compiler/xla/python/device_state.h"
 #include "tensorflow/compiler/xla/python/local_client.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
index 60886416a62..2b7082d40c9 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
@@ -96,9 +96,9 @@ PYBIND11_MODULE(tpu_client_extension, m) {
                           std::make_move_iterator(tree.leaves.end()));
 
             py::gil_scoped_release gil_release;
-            return PyTpuBuffer::FromLiterals(
-                std::move(leaves), tree.shape, std::move(py_buffer_ref),
-                std::move(client), device->local_device_ordinal());
+            return PyTpuBuffer::FromLiterals(std::move(leaves), tree.shape,
+                                             std::move(py_buffer_ref),
+                                             std::move(client), device->id());
           })
       .def_static(
           "from_python",
@@ -135,8 +135,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
                           "Cannot make tuple on device '%s' with '%s' backend",
                           device->DebugString(), client->platform_name());
                     }
-                    return PyTpuBuffer::MakeTuple(
-                        buffers, client, device->local_device_ordinal());
+                    return PyTpuBuffer::MakeTuple(buffers, client,
+                                                  device->id());
                   })
       .def_static("make_tuple", &PyTpuBuffer::MakeTuple)
       .def("copy_to_device",
@@ -144,7 +144,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
              CHECK(dst_device != nullptr);
              GlobalPyRefManager()->CollectGarbage();
              py::gil_scoped_release gil_release;
-             return buffer->CopyToDevice(dst_device->local_device_ordinal());
+             return buffer->CopyToDevice(dst_device->id());
            })
       .def("copy_to_device",
            [](PyTpuBuffer* buffer, int dst_device_ordinal) {
@@ -193,7 +193,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
            [](const PyTpuExecutable& executable) {
              std::vector<int> device_ordinals;
              for (std::shared_ptr<Device> device : executable.local_devices()) {
-               device_ordinals.push_back(device->local_device_ordinal());
+               device_ordinals.push_back(device->id());
              }
              return device_ordinals;
            })
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index f1776763796..b5eb6fa47da 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -142,6 +142,16 @@ Status PyRegisterCustomCallTarget(const std::string& fn_name,
   return Status::OK();
 }
 
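+// Resolves a deprecated device_ordinal argument into a Device. Used only by
+// the transitional bindings that still accept ordinals.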
+StatusOr<std::shared_ptr<Device>> LookupDeviceOrdinal(
+    PyLocalClient* client, int device_ordinal, absl::string_view caller_name) {
+  if (device_ordinal < 0 || device_ordinal >= client->local_device_count()) {
+    return InvalidArgument(
+        "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name,
+        device_ordinal, client->local_device_count());
+  }
+  return client->local_devices()[device_ordinal];
+}
+
 }  // namespace
 
 PYBIND11_MODULE(xla_extension, m) {
@@ -381,13 +391,27 @@ PYBIND11_MODULE(xla_extension, m) {
              }
              return result;
            })
+      // TODO(phawkins): delete overload that accepts a device_ordinal after
+      // all callers have been updated to pass a Device.
       .def("TransferToInfeed",
            [](PyLocalClient* client, const LiteralSlice& literal,
               int device_ordinal) {
              GlobalPyRefManager()->CollectGarbage();
              py::gil_scoped_release gil_release;
-             return client->TransferToInfeed(literal, device_ordinal);
+             TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
+                                 LookupDeviceOrdinal(client, device_ordinal,
+                                                     "TransferToInfeed"));
+             return client->TransferToInfeed(literal, device);
            })
+      .def("TransferToInfeed",
+           [](PyLocalClient* client, const LiteralSlice& literal,
+              std::shared_ptr<Device> device) {
+             GlobalPyRefManager()->CollectGarbage();
+             py::gil_scoped_release gil_release;
+             return client->TransferToInfeed(literal, device);
+           })
+      // TODO(phawkins): delete overload that accepts a device_ordinal after
+      // all callers have been updated to pass a Device.
       .def("TransferFromOutfeed",
            [](PyLocalClient* client, const Shape& shape,
               int device_ordinal) -> StatusOr<py::object> {
@@ -395,8 +419,24 @@ PYBIND11_MODULE(xla_extension, m) {
              std::shared_ptr<Literal> literal_shared;
              {
                py::gil_scoped_release gil_release;
-               TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed(
-                                                        shape, device_ordinal));
+               TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
+                                   LookupDeviceOrdinal(client, device_ordinal,
+                                                       "TransferFromOutfeed"));
+               TF_ASSIGN_OR_RETURN(Literal literal,
+                                   client->TransferFromOutfeed(shape, device));
+               literal_shared = std::make_shared<Literal>(std::move(literal));
+             }
+             return LiteralToPython(std::move(literal_shared));
+           })
+      .def("TransferFromOutfeed",
+           [](PyLocalClient* client, const Shape& shape,
+              std::shared_ptr<Device> device) -> StatusOr<py::object> {
+             GlobalPyRefManager()->CollectGarbage();
+             std::shared_ptr<Literal> literal_shared;
+             {
+               py::gil_scoped_release gil_release;
+               TF_ASSIGN_OR_RETURN(Literal literal,
+                                   client->TransferFromOutfeed(shape, device));
                literal_shared = std::make_shared<Literal>(std::move(literal));
              }
              return LiteralToPython(std::move(literal_shared));
@@ -440,7 +480,7 @@ PYBIND11_MODULE(xla_extension, m) {
             py::gil_scoped_release gil_release;
             return PyLocalBuffer::FromLiterals(
                 std::move(leaves), tree.shape, std::move(py_buffer_ref),
-                std::move(client), device->local_device_ordinal());
+                std::move(client), std::move(device));
           })
       .def_static("make_tuple",
                   [](const std::vector<PyLocalBuffer*> buffers,
@@ -454,15 +494,15 @@ PYBIND11_MODULE(xla_extension, m) {
                           "Cannot make tuple on device '%s' with '%s' backend",
                           device->DebugString(), client->platform_name());
                     }
-                    return PyLocalBuffer::MakeTuple(
-                        buffers, client, device->local_device_ordinal());
+                    return PyLocalBuffer::MakeTuple(buffers, std::move(client),
+                                                    std::move(device));
                   })
       .def("copy_to_device",
            [](PyLocalBuffer* buffer, std::shared_ptr<Device> dst_device) {
              CHECK(dst_device != nullptr);
              GlobalPyRefManager()->CollectGarbage();
              py::gil_scoped_release gil_release;
-             return buffer->CopyToDevice(dst_device->local_device_ordinal());
+             return buffer->CopyToDevice(std::move(dst_device));
            })
       .def("delete", &PyLocalBuffer::Delete)
       .def("destructure", &PyLocalBuffer::DestructureTuple)
@@ -485,10 +525,7 @@ PYBIND11_MODULE(xla_extension, m) {
              return LiteralToPython(std::move(literal));
            })
       .def("shape", &PyLocalBuffer::on_host_shape)
-      .def("device",
-           [](PyLocalBuffer* buffer) -> std::shared_ptr<Device> {
-             return buffer->client()->local_devices()[buffer->device_ordinal()];
-           })
+      .def("device", &PyLocalBuffer::device)
       .def("platform", &PyLocalBuffer::platform_name)
       .def("is_deleted",
            [](const PyLocalBuffer& buffer) {
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index c7f36a56912..82cab92443c 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -444,7 +444,7 @@ def shape_from_pyval(pyval):
   return convert(pyval)
 
 
-def transfer_to_infeed(value, device_ordinal=0):
+def transfer_to_infeed(value, device=None):
   """Transfers the given value into the XLA infeed queue.
 
   XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
@@ -454,29 +454,31 @@ def transfer_to_infeed(value, device_ordinal=0):
   Args:
     value: the value that the caller would like to enqueue into the XLA infeed
       queue
-    device_ordinal: the device to infeed the value to. Each device has a
+    device: the device to infeed the value to. Each device has a
       distinct infeed queue.
   """
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
-  backend.client.TransferToInfeed(value, device_ordinal)
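+  # Defaulting to the first local device preserves the old behavior of
+  # device_ordinal=0.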
+  device = device or backend.local_devices()[0]
+  backend.client.TransferToInfeed(value, device)
 
 
-def transfer_from_outfeed(shape, device_ordinal=0):
-  """Transfers a literal of the given shape from `device_ordinal`'s outfeed.
+def transfer_from_outfeed(shape, device=None):
+  """Transfers a literal of the given shape from `device`'s outfeed.
 
   Args:
     shape: The shape of the value to transfer from outfeed.
-    device_ordinal: The device ordinal to transfer the outfeed value from. Each
-      device has a distinct outfeed queue..
+    device: The device from which to transfer the outfeed value. Each device has
+      a distinct outfeed queue.
 
   Returns:
     The literal value that is produced from the outfeed queue.
   """
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
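+  # As above, default to the first local device (the old device_ordinal=0).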
+  device = device or backend.local_devices()[0]
   return backend.client.TransferFromOutfeed(
-      shape.with_major_to_minor_layout_if_absent(), device_ordinal)
+      shape.with_major_to_minor_layout_if_absent(), device)
 
 
 DeviceAssignment = _xla.DeviceAssignment