From 55691acd3680182169a61adfd6006a04e639591d Mon Sep 17 00:00:00 2001 From: Qiao Zhang Date: Mon, 7 Dec 2020 21:59:57 -0800 Subject: [PATCH] Extract a PjRtDevice interface. - Extract a pure interface - Current implementation is renamed to PjRtStreamExecutorDevice. - Rename local_device to addressable_device. TODO: split into a pjrt_stream_executor.{h,cc} PiperOrigin-RevId: 346249293 Change-Id: Icf3cf5fe876a71e172b8ba0fced5fc4c8ca1cc93 --- tensorflow/compiler/xla/pjrt/cpu_device.cc | 4 +- tensorflow/compiler/xla/pjrt/cpu_device.h | 2 +- tensorflow/compiler/xla/pjrt/gpu_device.cc | 4 +- tensorflow/compiler/xla/pjrt/gpu_device.h | 2 +- .../compiler/xla/pjrt/interpreter_device.cc | 4 +- .../compiler/xla/pjrt/interpreter_device.h | 2 +- tensorflow/compiler/xla/pjrt/pjrt_client.cc | 102 ++++++++++++------ tensorflow/compiler/xla/pjrt/pjrt_client.h | 102 ++++++++++++------ tensorflow/compiler/xla/pjrt/tpu_client.h | 6 +- tensorflow/compiler/xla/python/dlpack.cc | 14 ++- .../compiler/xla/python/outfeed_receiver.cc | 6 +- tensorflow/compiler/xla/python/py_buffer.cc | 4 +- .../xla/python/tpu_driver/client/BUILD | 1 + .../python/tpu_driver/client/tpu_client.cc | 10 +- .../xla/python/tpu_driver/client/tpu_client.h | 24 +++-- .../tpu_driver/client/tpu_client_extension.cc | 2 +- tensorflow/compiler/xla/python/xla.cc | 11 +- 17 files changed, 187 insertions(+), 113 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index 9b0f060f392..4bc89986c8c 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -26,8 +26,8 @@ static const char kCpuPlatformName[] = "cpu"; CpuDevice::CpuDevice(int id, std::unique_ptr local_device_state) - : PjRtDevice(id, std::move(local_device_state), - /*device_kind=*/kCpuPlatformName) {} + : PjRtStreamExecutorDevice(id, std::move(local_device_state), + /*device_kind=*/kCpuPlatformName) {} StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(se::Platform * platform, diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h index 1036d8fedbb..0aab55e6493 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -class CpuDevice : public PjRtDevice { +class CpuDevice : public PjRtStreamExecutorDevice { public: CpuDevice(int id, std::unique_ptr local_device_state); }; diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.cc b/tensorflow/compiler/xla/pjrt/gpu_device.cc index 26f38c2cbc4..6d1c0849ea1 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_device.cc @@ -306,8 +306,8 @@ Status BuildDistributedDevices( GpuDevice::GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id) - : PjRtDevice(id, std::move(local_device_state), std::move(device_kind), - node_id) {} + : PjRtStreamExecutorDevice(id, std::move(local_device_state), + std::move(device_kind), node_id) {} StatusOr> GetGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.h b/tensorflow/compiler/xla/pjrt/gpu_device.h index 7ea85db0401..142a263d959 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/gpu_device.h @@ -25,7 +25,7 @@ limitations under the License. namespace xla { -class GpuDevice : public PjRtDevice { +class GpuDevice : public PjRtStreamExecutorDevice { public: GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id); diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index 2819cabf258..3858aceba5e 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -26,8 +26,8 @@ static const char kInterpreterPlatformName[] = "interpreter"; InterpreterDevice::InterpreterDevice( int id, std::unique_ptr local_device_state) - : PjRtDevice(id, std::move(local_device_state), - /*device_kind=*/kInterpreterPlatformName) {} + : PjRtStreamExecutorDevice(id, std::move(local_device_state), + /*device_kind=*/kInterpreterPlatformName) {} StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(se::Platform * platform, diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h b/tensorflow/compiler/xla/pjrt/interpreter_device.h index 4038d8dbf11..a23ddcb5bb9 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.h +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -class InterpreterDevice : public PjRtDevice { +class InterpreterDevice : public PjRtStreamExecutorDevice { public: InterpreterDevice(int id, std::unique_ptr local_device_state); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 191b3467de6..d230e5cadc3 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -114,21 +114,22 @@ limitations under the License. namespace xla { -PjRtPlatformId PjRtDevice::platform_id() const { +PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const { return client_->platform_id(); } -const std::string& PjRtDevice::platform_name() const { +const std::string& PjRtStreamExecutorDevice::platform_name() const { return client_->platform_name(); } -StatusOr PjRtDevice::GetLocalDeviceState() const { +StatusOr PjRtStreamExecutorDevice::GetLocalDeviceState() + const { if (local_device_state_) { return local_device_state_.get(); } return InvalidArgument("Device %s is not a local device.", DebugString()); } -std::string PjRtDevice::DebugString() const { +std::string PjRtStreamExecutorDevice::DebugString() const { return absl::StrCat(platform_name(), ":", id()); } @@ -153,14 +154,15 @@ StatusOr DevicesToDeviceAssignment( devices[replica].size(), replica, devices[0].size()); } for (int partition = 0; partition < devices[replica].size(); ++partition) { - if (devices[0][0]->platform_id() != - devices[replica][partition]->platform_id()) { + if (devices[0][0]->client()->platform_id() != + devices[replica][partition]->client()->platform_id()) { return InvalidArgument( "Device assignment passed to Compile() must have devices of a " "single kind, got %s for replica 0 partition 0 and %s for replica " "%d partition %d.", - devices[0][0]->platform_name(), - devices[replica][partition]->platform_name(), replica, partition); + devices[0][0]->client()->platform_name(), + devices[replica][partition]->client()->platform_name(), replica, + partition); } xla_assignment(replica, partition) = devices[replica][partition]->id(); } @@ -215,15 +217,16 @@ PjRtClient::PjRtClient( CHECK(id_to_device_.insert({device->id(), device.get()}).second) << "Duplicate device id: " << device->id(); - if (device->IsLocalDevice()) { - int idx = device->local_device_id(); + if (device->IsAddressable()) { + int idx = device->local_hardware_id(); if (idx >= local_devices_.size()) { local_devices_.resize(idx + 1); } CHECK(local_devices_[idx] == nullptr) << idx; local_devices_[idx] = device.get(); } - device->SetClient(this); + tensorflow::down_cast(device.get()) + ->SetClient(this); } for (int idx = 0; idx < local_devices_.size(); ++idx) { CHECK(local_devices_[idx] != nullptr) << idx; @@ -554,7 +557,8 @@ StatusOr> PjRtClient::BufferFromHostBuffer( return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple"); } TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); + tensorflow::down_cast(device) + ->GetLocalDeviceState()); int64 size = ShapeUtil::ByteSizeOf(shape); TransferManager* transfer_manager = client()->backend().transfer_manager(); @@ -721,7 +725,8 @@ StatusOr> PjRtClient::CreateUninitializedBuffer( VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: " << shape.ToString() << " device: " << device->DebugString(); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); + tensorflow::down_cast(device) + ->GetLocalDeviceState()); TransferManager* transfer_manager = client()->backend().transfer_manager(); TF_ASSIGN_OR_RETURN(Shape compact_shape, @@ -739,7 +744,8 @@ StatusOr> PjRtClient::BufferFromHostLiteral( VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: " << literal.shape().ToString() << " device: " << device->DebugString(); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); + tensorflow::down_cast(device) + ->GetLocalDeviceState()); TransferManager* transfer_manager = client()->backend().transfer_manager(); TF_ASSIGN_OR_RETURN( @@ -801,7 +807,9 @@ void PjRtClient::MakeCrossHostReceiveBuffers( return; } - auto local_device_or = device->GetLocalDeviceState(); + auto local_device_or = + tensorflow::down_cast(device) + ->GetLocalDeviceState(); if (!local_device_or.ok()) { notifier(local_device_or.status()); return; @@ -828,27 +836,29 @@ void PjRtClient::MakeCrossHostReceiveBuffers( } // Transfer the given literal to the infeed queue of the given local device. -Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const { +Status PjRtStreamExecutorDevice::TransferToInfeed( + const LiteralSlice& literal) const { // Only support infeed to local device. TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferToInfeedLocal( literal, local_device->device_ordinal()); } -StatusOr PjRtDevice::TransferFromOutfeed(const Shape& shape) const { +StatusOr PjRtStreamExecutorDevice::TransferFromOutfeed( + const Shape& shape) const { TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferFromOutfeedLocal( shape, local_device->device_ordinal()); } -StatusOr PjRtClient::LookupLocalDevice(int local_device_id) const { +StatusOr PjRtClient::LookupAddressableDevice(int device_id) const { for (auto* device : local_devices_) { - if (local_device_id == device->local_device_id()) { + if (device_id == device->local_hardware_id()) { return device; } } - return InvalidArgument("No matching device found for local_device_id %d", - local_device_id); + return InvalidArgument("No matching device found for device_id %d", + device_id); } PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, @@ -919,7 +929,9 @@ StatusOr> PjRtBuffer::Release( // the final set of usage events. events = device_buffer->LockUseAndTransferUsageEvents(); } - LocalDeviceState* local_device_state = device_->local_device_state(); + LocalDeviceState* local_device_state = + tensorflow::down_cast(device_) + ->local_device_state(); if (wait_for_operations_to_complete) { // Block the host until all usage events have completed. Usage events // dominate definition events, so this also waits for the buffer to be @@ -1080,7 +1092,9 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy, } ScopedHold device_buffer(this, ScopedHold::kUsage); std::shared_ptr host_value; - LocalDeviceState* local_device = device_->local_device_state(); + LocalDeviceState* local_device = + tensorflow::down_cast(device_) + ->local_device_state(); se::Stream* stream = local_device->GetDeviceToHostStream(); const xla::Layout& host_layout = layout.has_value() ? layout.value() : on_host_shape_.layout(); @@ -1241,8 +1255,9 @@ PjRtBuffer::CopyToDeviceHelper( // StallStreamOnError only makes sure the destination device is ok, so // make sure that the src buffer remains valid until after any transfers // have completed. - device_->local_device_state()->ThenRelease(transfer_stream, - src_device_buffer); + tensorflow::down_cast(device_) + ->local_device_state() + ->ThenRelease(transfer_stream, src_device_buffer); } return copy_event_or.status(); } @@ -1268,11 +1283,15 @@ StatusOr> PjRtBuffer::CopyToDevice( PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device); } - TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, - dst_device->GetLocalDeviceState()); + TF_ASSIGN_OR_RETURN( + LocalDeviceState * dst_local_device, + tensorflow::down_cast(dst_device) + ->GetLocalDeviceState()); LocalDeviceState* transfer_local_device = - client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state() - : dst_local_device; + client_->EnqueueD2DTransfersOnSrcStream() + ? tensorflow::down_cast(device_) + ->local_device_state() + : dst_local_device; CHECK_EQ(dst_local_device->allocation_model(), transfer_local_device->allocation_model()); @@ -1310,7 +1329,9 @@ StatusOr> PjRtBuffer::CopyToDevice( // alternative is to ensure, before freeing the buffer, that the compute // stream is synchronized past the transfer, but it seems better to hold onto // the buffer too long than to stall the compute stream. - RecordUsage(std::move(src_device_buffer), device_->local_device_state(), + RecordUsage(std::move(src_device_buffer), + tensorflow::down_cast(device_) + ->local_device_state(), transfer_local_device, event, transfer_stream, /*prefer_to_retain_reference=*/true); @@ -1332,7 +1353,9 @@ Status PjRtBuffer::BlockHostUntilReady() { } device_buffer = device_buffer_; } - LocalDeviceState* local_device_state = device_->local_device_state(); + LocalDeviceState* local_device_state = + tensorflow::down_cast(device_) + ->local_device_state(); std::unique_ptr stream; for (auto& event : device_buffer->definition_events()) { if (!event->IsComplete()) { @@ -1628,7 +1651,9 @@ StatusOr PjRtStreamExecutorExecutable::EnqueueExecution( int executable_idx, const RunId& run_id, const ExecuteOptions& options, PjRtDevice* device, std::vector* device_buffers, std::shared_ptr device_assignment) const { - int device_ordinal = device->local_device_state()->device_ordinal(); + int device_ordinal = tensorflow::down_cast(device) + ->local_device_state() + ->device_ordinal(); LocalDeviceState* device_state = &client_->device_state(device_ordinal); tensorflow::profiler::TraceMeConsumer activity( "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt, @@ -1814,7 +1839,9 @@ PjRtStreamExecutorExecutable::ExecuteHelper( } CHECK_EQ(device->host_id(), client_->host_id()); - int device_ordinal = device->local_device_state()->device_ordinal(); + int device_ordinal = tensorflow::down_cast(device) + ->local_device_state() + ->device_ordinal(); tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); VLOG(3) << "Replica " << replica << ", partition " << partition << " mapped to device ordinal for execution: " << device_ordinal; @@ -1922,7 +1949,9 @@ PjRtStreamExecutorExecutable::Execute( const int replica = addressable_device_logical_ids_[i].replica; const int partition = addressable_device_logical_ids_[i].partition; PjRtDevice* device = addressable_devices_[i]; - const LocalDeviceState& device_state = *device->local_device_state(); + const LocalDeviceState& device_state = + *tensorflow::down_cast(device) + ->local_device_state(); device_state.execute_thread()->Schedule([&, replica, partition, i] { results[i] = ExecuteHelper(argument_handles[i], replica, partition, run_id, options); @@ -2254,7 +2283,10 @@ StatusOr> PjRtClient::Compile( if (build_options.device_ordinal() < 0) { build_options.set_device_ordinal( - addressable_devices.front()->local_device_state()->device_ordinal()); + tensorflow::down_cast( + addressable_devices.front()) + ->local_device_state() + ->device_ordinal()); } } diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index b32d2889fc8..703999956a6 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" @@ -67,16 +68,56 @@ class PjRtClient; class PjRtDevice { public: - explicit PjRtDevice(int id, - std::unique_ptr local_device_state, - std::string device_kind, int host_id = 0) + virtual ~PjRtDevice() {} + + // Return the client that owns this device. + virtual PjRtClient* client() const = 0; + + // Whether client can issue command to this device. + virtual bool IsAddressable() const = 0; + + // The ID of this device. IDs are unique among devices of this type + // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all + // hosts' devices. This is the ID that should be used in a DeviceAssignment. + virtual int id() const = 0; + + // The task ID of this device according to TpuTopology. This is not the same + // as PjRtClient::host_id() in a multi-task setting, where each client can see + // devices from all tasks, but only a subset of them are addressable and have + // the same task_id as the client. + virtual int host_id() const = 0; + + // Opaque hardware ID, e.g., the CUDA device number, useful for identifying + // which GPU when interacting with non-JAX code. In general, not guaranteed to + // be dense, and -1 if undefined. + virtual int local_hardware_id() const = 0; + + // A vendor-dependent string that uniquely identifies the kind of device, + // e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are + // compatible compilation. + virtual const std::string& device_kind() const = 0; + + virtual std::string DebugString() const = 0; + + // Transfer the given literal to the infeed queue. + virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0; + + // Transfer and return a value of the given shape from the outfeed queue. + virtual StatusOr TransferFromOutfeed(const Shape& shape) const = 0; +}; + +class PjRtStreamExecutorDevice : public PjRtDevice { + public: + explicit PjRtStreamExecutorDevice( + int id, std::unique_ptr local_device_state, + std::string device_kind, int host_id = 0) : id_(id), - local_device_id_( + device_ordinal_( local_device_state ? local_device_state->device_ordinal() : -1), local_device_state_(std::move(local_device_state)), host_id_(host_id), device_kind_(std::move(device_kind)) {} - virtual ~PjRtDevice() {} + ~PjRtStreamExecutorDevice() override {} // Must set client exactly once. void SetClient(PjRtClient* client) { @@ -84,14 +125,25 @@ class PjRtDevice { client_ = client; } + // Task ID. This is always 0 on single-task setup. + int host_id() const override { return host_id_; } + + // Return `platform_id` from client. + PjRtPlatformId platform_id() const; + + // Return `platform_name` from client. + const std::string& platform_name() const; + + PjRtClient* client() const override { return client_; } + // The ID of this device. IDs are unique among devices of this type // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all // hosts' devices. This is the ID that should be used in a DeviceAssignment. - int id() const { return id_; } + int id() const override { return id_; } - bool IsLocalDevice() const { return local_device_id_ != -1; } + bool IsAddressable() const override { return device_ordinal_ != -1; } - int local_device_id() const { return local_device_id_; } + int local_hardware_id() const override { return device_ordinal_; } // If this is a device local to this host, returns a LocalDeviceState object // that can be used to manipulate the device. Returns nullptr if the device is @@ -105,32 +157,21 @@ class PjRtDevice { // is not local to this host. StatusOr GetLocalDeviceState() const; - // The ID of this device's host. This is always 0 on single-host platforms. - int host_id() const { return host_id_; } - - // Return `platform_id` from client. - PjRtPlatformId platform_id() const; - - // Return `platform_name` from client. - const std::string& platform_name() const; - // A vendor-dependent string that uniquely identifies the kind of device. - const std::string& device_kind() const { return device_kind_; } + const std::string& device_kind() const override { return device_kind_; } - virtual std::string DebugString() const; - - PjRtClient* client() const { return client_; } + std::string DebugString() const override; // Transfer the given literal to the infeed queue of the given localdevice. - virtual Status TransferToInfeed(const LiteralSlice& literal) const; + Status TransferToInfeed(const LiteralSlice& literal) const override; // Transfer and return a value of the given shape from the outfeed of the // given device. - virtual StatusOr TransferFromOutfeed(const Shape& shape) const; + StatusOr TransferFromOutfeed(const Shape& shape) const override; private: const int id_; - const int local_device_id_; // -1 means not local. + const int device_ordinal_; // -1 means not local. const std::unique_ptr local_device_state_; const int host_id_; const std::string device_kind_; @@ -196,9 +237,7 @@ class PjRtClient { const std::vector>& devices() const { return devices_; } - const std::vector& local_devices() const { - return local_devices_; - } + absl::Span local_devices() const { return local_devices_; } const std::map& id_to_device() const { return id_to_device_; } @@ -207,11 +246,13 @@ class PjRtClient { const std::string& platform_name() const { return platform_name_; } LocalDeviceState& device_state(int device_ordinal) const { - return *local_devices_.at(device_ordinal)->local_device_state(); + return *tensorflow::down_cast( + local_devices_.at(device_ordinal)) + ->local_device_state(); } - // Return a local PjRtDevice for a given `local_device_id`. - virtual StatusOr LookupLocalDevice(int local_device_id) const; + // Return an addressable PjRtDevice for a given `device_id`. + virtual StatusOr LookupAddressableDevice(int device_id) const; LocalClient* client() const { return client_; } se::DeviceMemoryAllocator* allocator() const { return allocator_; } @@ -791,6 +832,7 @@ class PjRtExecutable { virtual absl::Span addressable_device_logical_ids() const = 0; + // An addressable_device is one which the client can issue commands to. // addressable_devices()[i] is the Device to which // addressable_device_logical_ids()[i] is assigned. virtual absl::Span addressable_devices() const = 0; diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h index cdc68bc9606..f17d82a270e 100644 --- a/tensorflow/compiler/xla/pjrt/tpu_client.h +++ b/tensorflow/compiler/xla/pjrt/tpu_client.h @@ -26,14 +26,14 @@ limitations under the License. namespace xla { -class PjRtTpuDevice : public PjRtDevice { +class PjRtTpuDevice : public PjRtStreamExecutorDevice { public: PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core, std::unique_ptr local_device_state, int host_id, const std::array& coords, std::string device_kind) - : PjRtDevice(core.Id(), std::move(local_device_state), - std::move(device_kind), host_id), + : PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state), + std::move(device_kind), host_id), core_(core), coords_(coords) {} diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 85252256657..fd603930d6c 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -214,11 +214,9 @@ StatusOr> StridesToLayout(absl::Span dims, } StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { - const se::Platform* platform = - device.local_device_state()->executor()->platform(); - if (platform->id() == se::host::kHostPlatformId) { + if (device.client()->platform_id() == kCpuId) { return kDLCPU; - } else if (platform->id() == se::cuda::kCudaPlatformId) { + } else if (device.client()->platform_id() == kGpuId) { return kDLGPU; } return InvalidArgument("Device %s cannot be used as a DLPack device.", @@ -228,7 +226,7 @@ StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { StatusOr DLContextForDevice(const PjRtDevice& device) { DLContext context; TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); - context.device_id = device.local_device_id(); + context.device_id = device.local_hardware_id(); return context; } @@ -241,14 +239,14 @@ StatusOr DeviceForDLContext(const PjRtClient& client, "DLPack CPU device type mismatch with PjRtClient platform %s", client.platform_name()); } - return client.LookupLocalDevice(context.device_id); + return client.LookupAddressableDevice(context.device_id); case kDLGPU: if (client.platform_id() != kGpuId) { return InvalidArgument( "DLPack GPU device type mismatch with PjRtClient platform %s", client.platform_name()); } - return client.LookupLocalDevice(context.device_id); + return client.LookupAddressableDevice(context.device_id); default: return InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type); @@ -297,7 +295,7 @@ StatusOr BufferToDLPackManagedTensor(py::handle py_buffer, pack->tensor.manager_ctx = pack.get(); pack->tensor.deleter = DLPackTensorDeleter; TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device())); - dt.ctx.device_id = buffer->buffer()->device()->local_device_id(); + dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id(); dt.ndim = buffer->buffer()->on_host_shape().dimensions_size(); TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType( diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc index 3c0f9750f7f..69df3c6f230 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -342,11 +342,7 @@ StatusOr> OutfeedReceiverImpl::ReceiveRawFromOutfeed( const PjRtDevice* device, const Shape& shape) { std::shared_ptr literal_shared; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device->GetLocalDeviceState()); - TF_ASSIGN_OR_RETURN(Literal literal, - local_device->client()->TransferFromOutfeedLocal( - shape, local_device->device_ordinal())); + TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape)); return absl::make_unique(std::move(literal)); } diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index 1f39266a989..dfa312c4592 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -86,8 +86,8 @@ StatusOr PyBuffer::UnsafeBufferPointer() const { } StatusOr PyBuffer::CudaArrayInterface() const { - if (buffer_->device()->local_device_state()->executor()->platform_kind() != - se::PlatformKind::kCuda) { + // TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs. + if (buffer_->client()->platform_id() != kGpuId) { return InvalidArgument( "__cuda_array_interface__ is only defined for NVidia GPU buffers."); } diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index 9d98d0cf654..28a491c0326 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -37,6 +37,7 @@ cc_library( "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core/framework:allocator", + "//tensorflow/core/platform:casts", "//tensorflow/core/platform:env", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index c6a748068d4..a9aa218ca6f 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -37,8 +37,8 @@ namespace xla { TpuDevice::TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip) - : xla::PjRtDevice(id, /*local_device_state=*/nullptr, - /*device_kind=*/"Cloud TPU", host_id), + : xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr, + /*device_kind=*/"Cloud TPU", host_id), coords_(coords), core_on_chip_(core_on_chip) {} @@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable( << "Inserting duplicate replica:" << replica; executables_[replica] = client_->driver()->LoadProgram(device_id, compiled_program.get(), {}); - addressable_device_logical_ids_.emplace_back(replica, partition); + local_logical_device_ids_.emplace_back(replica, partition); local_devices_.push_back(device); } } @@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices( // long time and we want all cores to be scheduled in parallel. thread_pool->Schedule([this, i, argument_handles, &results, &results_lock, &execute_semaphore]() { - const int replica = addressable_device_logical_ids_[i].first; - const int partition = addressable_device_logical_ids_[i].second; + const int replica = local_logical_device_ids_[i].first; + const int partition = local_logical_device_ids_[i].second; RunId run_id; auto result = ExecuteHelper(argument_handles, argument_handles[i], replica, partition, run_id); diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 20c2f749a75..89dca53bbb6 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -32,13 +32,14 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/threadpool.h" namespace xla { constexpr char kTpuPlatform[] = "tpu"; -class TpuDevice : public PjRtDevice { +class TpuDevice : public PjRtStreamExecutorDevice { public: TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip); @@ -298,9 +299,8 @@ class PyTpuExecutable { return device_assignment_; } - const std::vector>& addressable_device_logical_ids() - const { - return addressable_device_logical_ids_; + const std::vector>& local_logical_device_ids() const { + return local_logical_device_ids_; } const std::vector>& local_devices() const { @@ -341,14 +341,16 @@ class PyTpuExecutable { // The replica and partition indices of device_assignment_ to be run by this // client. On single-host platforms without partitioning, this is all replicas - // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the - // case on multi-host platforms. If there are 4 replicas and 2 partitions on a - // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8. - std::vector> addressable_device_logical_ids_; + // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case + // on multi-host platforms. + // If there are 4 replicas and 2 partitions on a single host platform, size of + // local_logical_device_ids_ is 4*2 = 8. + std::vector> local_logical_device_ids_; - // local_devices_[i] is the Device to which addressable_device_logical_ids_[i] - // is assigned. shared_ptrs instead of unique_ptrs to play well with the - // Python bindings (see xla.cc). + // local_devices_[i] is the Device to which local_logical_device_ids_[i] is + // assigned. + // shared_ptrs instead of unique_ptrs to play well with the Python bindings + // (see xla.cc). std::vector> local_devices_; xla::Shape result_shape_; diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 0562ff230e1..a9fd70b6475 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { py::class_(m, "TpuExecutable") .def("local_logical_device_ids", - &PyTpuExecutable::addressable_device_logical_ids) + &PyTpuExecutable::local_logical_device_ids) .def("local_devices", &PyTpuExecutable::local_devices) .def_property_readonly("client", &PyTpuExecutable::client) .def("size_of_generated_code_in_bytes", diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index dee1b14b90f..2b7068036d9 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -149,7 +149,10 @@ PYBIND11_MODULE(xla_extension, m) { .def_property_readonly("host_id", &PjRtDevice::host_id, "Integer ID of this device's host.\n\n" "This is always 0 except on multi-host platforms.") - .def_property_readonly("platform", &PjRtDevice::platform_name) + .def_property_readonly("platform", + [](const PjRtDevice& device) { + return device.client()->platform_name(); + }) .def_property_readonly("device_kind", &PjRtDevice::device_kind) .def_property_readonly( "client", @@ -381,10 +384,10 @@ PYBIND11_MODULE(xla_extension, m) { [](PyExecutable* exec) { auto span = exec->addressable_device_logical_ids(); // Not on dispatch critical path, so ok to have heap allocation. - std::vector> addressable_device_logical_ids; - addressable_device_logical_ids.reserve(span.size()); + std::vector> addressable_device_logic_ids; + addressable_device_logic_ids.reserve(span.size()); for (const auto& logical_device_id : span) { - addressable_device_logical_ids.push_back(std::make_pair( + addressable_device_logic_ids.push_back(std::make_pair( logical_device_id.replica, logical_device_id.partition)); } })