Refactor PJRT client.
Add a few methods to ensure third_party/tensorflow/compiler/xla/python only
depends on PJRT interfaces rather than implementation details of PJRT.

PiperOrigin-RevId: 338546086
Change-Id: I859af094d9c56d00b180dbbf367132630d59510c
This commit is contained in:
parent 7f9ce6eae5
commit 3ee3eea626
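The heart of the change, sketched as an illustrative (non-committed) snippet: callers compare the new PjRtPlatformId enum instead of matching platform name strings, so the names stay an implementation detail of PJRT. SupportsBufferProtocol below is a hypothetical caller, not part of this commit.

    #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

    namespace xla {

    // Hypothetical caller: branch on the platform without string matching.
    bool SupportsBufferProtocol(const PjRtClient& client) {
      // Before this change a caller would write
      //   client.platform_name() == "cpu"
      // which couples it to PJRT's internal naming scheme.
      return client.platform_id() == PjRtPlatformId::kCpu;
    }

    }  // namespace xla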
@@ -17,6 +17,7 @@ limitations under the License.

 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"

 namespace xla {
@@ -25,7 +26,7 @@ static const char kCpuPlatformName[] = "cpu";

 CpuDevice::CpuDevice(int id,
                      std::unique_ptr<LocalDeviceState> local_device_state)
-    : PjRtDevice(id, std::move(local_device_state), kCpuPlatformName,
+    : PjRtDevice(id, std::move(local_device_state),
                  /*device_kind=*/kCpuPlatformName) {}

 StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
@@ -57,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   }

   return std::make_unique<PjRtClient>(
-      kCpuPlatformName, client, std::move(devices), /*host_id=*/0,
+      PjRtPlatformId::kCpu, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
@@ -17,6 +17,7 @@ limitations under the License.

 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"

 namespace xla {
@@ -25,7 +26,7 @@ static const char kInterpreterPlatformName[] = "interpreter";

 InterpreterDevice::InterpreterDevice(
     int id, std::unique_ptr<LocalDeviceState> local_device_state)
-    : PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName,
+    : PjRtDevice(id, std::move(local_device_state),
                  /*device_kind=*/kInterpreterPlatformName) {}

 StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
@@ -51,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   devices.push_back(std::move(device));

   return std::make_unique<PjRtClient>(
-      kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
+      PjRtPlatformId::kInterpreter, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
@@ -30,8 +30,6 @@ limitations under the License.
 namespace xla {
 namespace {

-static const char kGpuPlatformName[] = "gpu";
-
 // A custom PjRtClient that overrides the device assignment method.
 class GpuClient : public xla::PjRtClient {
  public:
@@ -298,8 +296,8 @@ Status BuildDistributedDevices(
 GpuDevice::GpuDevice(int id,
                      std::unique_ptr<LocalDeviceState> local_device_state,
                      std::string device_kind, int node_id)
-    : PjRtDevice(id, std::move(local_device_state), kGpuPlatformName,
-                 std::move(device_kind), node_id) {}
+    : PjRtDevice(id, std::move(local_device_state), std::move(device_kind),
+                 node_id) {}

 StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
     bool asynchronous, const GpuAllocatorConfig& allocator_config,
@@ -325,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
   }

   return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>(
-      "gpu", xla_client, std::move(devices),
+      PjRtPlatformId::kNvidiaGpu, xla_client, std::move(devices),
       /*node_id=*/node_id, std::move(allocator),
       std::move(host_memory_allocator),
      /*should_stage_host_to_device_transfers=*/true,
@@ -113,6 +113,13 @@ limitations under the License.

 namespace xla {

+PjRtPlatformId PjRtDevice::platform_id() const {
+  return client_->platform_id();
+}
+const std::string& PjRtDevice::platform_name() const {
+  return client_->platform_name();
+}
+
 StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
   if (local_device_state_) {
     return local_device_state_.get();
@@ -145,8 +152,8 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
           devices[replica].size(), replica, devices[0].size());
     }
     for (int partition = 0; partition < devices[replica].size(); ++partition) {
-      if (devices[0][0]->platform_name() !=
-          devices[replica][partition]->platform_name()) {
+      if (devices[0][0]->platform_id() !=
+          devices[replica][partition]->platform_id()) {
         return InvalidArgument(
             "Device assignment passed to Compile() must have devices of a "
             "single kind, got %s for replica 0 partition 0 and %s for replica "
@@ -175,13 +182,14 @@ class CpuAllocator : public tensorflow::Allocator {
 };

 PjRtClient::PjRtClient(
-    std::string platform_name, LocalClient* client,
+    PjRtPlatformId platform_id, LocalClient* client,
     std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
     std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
-    : platform_name_(std::move(platform_name)),
+    : platform_id_(platform_id),
+      platform_name_(Name(platform_id)),
       client_(client),
       host_memory_allocator_(std::move(host_memory_allocator)),
       devices_(std::move(devices)),
@@ -206,15 +214,15 @@ PjRtClient::PjRtClient(
     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
         << "Duplicate device id: " << device->id();

-    if (device->local_device_state()) {
-      int idx = device->local_device_state()->device_ordinal();
+    if (device->IsLocalDevice()) {
+      int idx = device->local_device_id();
       if (idx >= local_devices_.size()) {
         local_devices_.resize(idx + 1);
       }
       CHECK(local_devices_[idx] == nullptr) << idx;
       local_devices_[idx] = device.get();
     }
-    device->client_ = this;
+    device->SetClient(this);
   }
   for (int idx = 0; idx < local_devices_.size(); ++idx) {
     CHECK(local_devices_[idx] != nullptr) << idx;
@@ -576,6 +584,10 @@ void PjRtBuffer::ScopedHold::AddToInput(
   }
 }

+bool PjRtBuffer::IsOnCpu() const {
+  return client()->platform_id() == PjRtPlatformId::kCpu;
+}
+
 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
@@ -865,6 +877,16 @@ StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
       shape, local_device->device_ordinal());
 }

+StatusOr<PjRtDevice*> PjRtClient::LookupLocalDevice(int local_device_id) const {
+  for (auto* device : local_devices_) {
+    if (local_device_id == device->local_device_id()) {
+      return device;
+    }
+  }
+  return InvalidArgument("No matching device found for local_device_id %d",
+                         local_device_id);
+}
+
 PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
                        std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                        PjRtClient* client, PjRtDevice* device)
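A sketch of how a caller might use the new LookupLocalDevice accessor; DeviceFromOrdinal is a hypothetical helper, not part of this commit:

    #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

    namespace xla {

    // Hypothetical helper: map an ordinal coming from an external API
    // (e.g. DLPack) back to a PjRtDevice without touching LocalDeviceState.
    StatusOr<PjRtDevice*> DeviceFromOrdinal(const PjRtClient& client,
                                            int ordinal) {
      // Returns InvalidArgument if no local device carries this id.
      return client.LookupLocalDevice(ordinal);
    }

    }  // namespace xla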
@@ -1985,6 +2007,19 @@ PjRtExecutable::ExecuteOnLocalDevices(
   return wrapped_results;
 }

+StatusOr<std::vector<std::shared_ptr<HloModule>>>
+PjRtExecutable::GetHloModules() {
+  std::vector<std::shared_ptr<HloModule>> modules;
+  modules.reserve(executables().size());
+  for (const auto& local_exec : executables()) {
+    if (!local_exec->executable()->has_module()) {
+      return InvalidArgument("Executable does not have HLO modules.");
+    }
+    modules.push_back(local_exec->executable()->shared_module());
+  }
+  return std::move(modules);
+}
+
 namespace {

 StatusOr<Shape> GetShardedShape(const Shape& shape,
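For illustration, a minimal sketch of a GetHloModules caller (DumpModuleNames is an assumption, not part of the commit; PyExecutable::HloModules later in this diff is the real in-tree caller):

    #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
    #include "tensorflow/compiler/xla/status_macros.h"
    #include "tensorflow/core/platform/logging.h"

    // Hypothetical caller: log the name of each per-partition HLO module.
    xla::Status DumpModuleNames(xla::PjRtExecutable* executable) {
      TF_ASSIGN_OR_RETURN(auto modules, executable->GetHloModules());
      for (const auto& module : modules) {
        LOG(INFO) << module->name();
      }
      return xla::Status::OK();
    }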
@@ -36,11 +36,13 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/thread_annotations.h"
@@ -50,26 +52,63 @@ limitations under the License.

 namespace xla {

+// TODO(zhangqiaorjc): Add a registration mechanism to add new platforms.
+enum class PjRtPlatformId : int {
+  kCpu = 0,
+  kNvidiaGpu = 1,
+  kAmdGpu = 2,
+  kTpu = 3,
+  kEdgeTpu = 4,
+  kInterpreter = 5
+};
+constexpr const char* Name(PjRtPlatformId platform_id) {
+  switch (platform_id) {
+    case PjRtPlatformId::kCpu:
+      return "cpu";
+    case PjRtPlatformId::kNvidiaGpu:
+      // TODO(zhangqiaorjc): Rename to nvidia_gpu when we add AMD support.
+      return "gpu";
+    case PjRtPlatformId::kAmdGpu:
+      return "amd_gpu";
+    case PjRtPlatformId::kTpu:
+      return "tpu";
+    case PjRtPlatformId::kEdgeTpu:
+      return "edge_tpu";
+    case PjRtPlatformId::kInterpreter:
+      return "interpreter";
+  }
+}
+
 class PjRtClient;

 class PjRtDevice {
  public:
   explicit PjRtDevice(int id,
                       std::unique_ptr<LocalDeviceState> local_device_state,
-                      std::string platform_name, std::string device_kind,
-                      int host_id = 0)
+                      std::string device_kind, int host_id = 0)
       : id_(id),
+        local_device_id_(
+            local_device_state ? local_device_state->device_ordinal() : -1),
         local_device_state_(std::move(local_device_state)),
         host_id_(host_id),
-        platform_name_(std::move(platform_name)),
         device_kind_(std::move(device_kind)) {}
   virtual ~PjRtDevice() {}

+  // Must set client exactly once.
+  void SetClient(PjRtClient* client) {
+    CHECK(client_ == nullptr);
+    client_ = client;
+  }
+
   // The ID of this device. IDs are unique among devices of this type
   // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
   // hosts' devices. This is the ID that should be used in a DeviceAssignment.
   int id() const { return id_; }

+  bool IsLocalDevice() const { return local_device_id_ != -1; }
+
+  int local_device_id() const { return local_device_id_; }
+
   // If this is a device local to this host, returns a LocalDeviceState object
   // that can be used to manipulate the device. Returns nullptr if the device is
   // not local to this host.
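Because Name() is constexpr, the platform string is derivable at compile time from the enum; a small illustrative use (DescribePlatform and the assertion are assumptions, not in the commit):

    #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

    // The enum is the stable key; the printable name is derived, not stored.
    static_assert(xla::Name(xla::PjRtPlatformId::kTpu)[0] == 't',
                  "Name() is usable in constant expressions.");

    // Hypothetical helper, e.g. PjRtPlatformId::kCpu -> "cpu".
    const char* DescribePlatform(xla::PjRtPlatformId id) {
      return xla::Name(id);
    }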
@@ -85,7 +124,11 @@ class PjRtDevice {
   // The ID of this device's host. This is always 0 on single-host platforms.
   int host_id() const { return host_id_; }

-  const std::string& platform_name() const { return platform_name_; }
+  // Return `platform_id` from client.
+  PjRtPlatformId platform_id() const;
+
+  // Return `platform_name` from client.
+  const std::string& platform_name() const;

   // A vendor-dependent string that uniquely identifies the kind of device.
   const std::string& device_kind() const { return device_kind_; }
@@ -102,12 +145,10 @@ class PjRtDevice {
   virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;

  private:
-  friend class PjRtClient;
-
   const int id_;
+  const int local_device_id_;  // -1 means not local.
   const std::unique_ptr<LocalDeviceState> local_device_state_;
   const int host_id_;
-  const std::string platform_name_;
   const std::string device_kind_;
   PjRtClient* client_ = nullptr;
 };
@@ -155,7 +196,7 @@ class PjRtClient {
  public:
   // `allocator` may null, in which case the platform default allocator is used.
   explicit PjRtClient(
-      std::string platform_name, LocalClient* client,
+      PjRtPlatformId platform_id, LocalClient* client,
       std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
       std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
@@ -178,12 +219,16 @@ class PjRtClient {
     return id_to_device_;
   }
   int host_id() const { return host_id_; }
+  PjRtPlatformId platform_id() const { return platform_id_; }
   const std::string& platform_name() const { return platform_name_; }

   LocalDeviceState& device_state(int device_ordinal) const {
     return *local_devices_.at(device_ordinal)->local_device_state();
   }

+  // Return a local PjRtDevice for a given `local_device_id`.
+  virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const;
+
   LocalClient* client() const { return client_; }
   se::DeviceMemoryAllocator* allocator() const { return allocator_; }
   tensorflow::Allocator* host_memory_allocator() const {
@@ -280,6 +325,16 @@ class PjRtClient {
       absl::Span<const Shape> shapes, PjRtDevice* device,
       PjRtCrossHostRecvNotifier&& notifier);

+  virtual StatusOr<ChannelHandle> CreateChannelHandle() {
+    return client()->CreateChannelHandle();
+  }
+  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
+    return client()->CreateDeviceToHostChannelHandle();
+  }
+  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
+    return client()->CreateHostToDeviceChannelHandle();
+  }
+
  protected:
   friend class PjRtBuffer;
   virtual void EnqueueCrossHostReceive(
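Making the channel-handle factories virtual lets a backend without a LocalClient substitute its own behavior; a sketch of a hypothetical override (MyBackendClient is illustrative only, not part of the commit):

    #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

    // Hypothetical backend client that does not support channels.
    class MyBackendClient : public xla::PjRtClient {
     public:
      using xla::PjRtClient::PjRtClient;  // inherit the base constructor

      xla::StatusOr<xla::ChannelHandle> CreateChannelHandle() override {
        return xla::Unimplemented("Channels not supported on this backend.");
      }
    };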
@@ -293,7 +348,8 @@ class PjRtClient {
     return Unimplemented("Cross host sends not implemented.");
   }

-  std::string platform_name_;
+  const PjRtPlatformId platform_id_;
+  const std::string platform_name_;
   LocalClient* client_;

   // Allocator to be used for staging memory transfers to devices.
@@ -509,7 +565,7 @@ class PjRtBuffer {
   PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
              std::shared_ptr<TrackedDeviceBuffer> device_buffer,
              PjRtClient* client, PjRtDevice* device);
-  ~PjRtBuffer();
+  virtual ~PjRtBuffer();

   PjRtBuffer(const PjRtBuffer&) = delete;
   PjRtBuffer(PjRtBuffer&&) = delete;
@@ -519,6 +575,7 @@ class PjRtBuffer {
   const Shape& on_host_shape() const { return on_host_shape_; }
   const Shape& on_device_shape() const { return on_device_shape_; }
   PjRtDevice* device() const { return device_; }
+  PjRtPlatformId platform_id() const { return client_->platform_id(); }
   const std::string& platform_name() const { return client_->platform_name(); }
   PjRtClient* client() const { return client_; }
   bool IsEmptyTuple() const {
@@ -611,6 +668,9 @@ class PjRtBuffer {
   // immediate use on the device. Useful in particular for timing benchmarks.
   Status BlockHostUntilReady();

+  // Whether this buffer is on CPU and thus allows for certain optimizations.
+  bool IsOnCpu() const;
+
  private:
   friend class PjRtClient;
   // The cached value of the buffer on the host, produced either from a call to
@@ -782,6 +842,9 @@ class PjRtExecutable {

   const string& name() const;

+  // Return an HloModule per partition.
+  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules();
+
  protected:
   bool parameter_is_tupled_arguments() const {
     return parameter_is_tupled_arguments_;
@@ -118,7 +118,7 @@ PjRtTpuClient::PjRtTpuClient(LocalClient* client,
                              std::vector<std::unique_ptr<PjRtDevice>> devices,
                              int host_id,
                              tf_tpu::TpuPlatformInterface* tpu_platform)
-    : PjRtClient("tpu", client, std::move(devices), host_id,
+    : PjRtClient(PjRtPlatformId::kTpu, client, std::move(devices), host_id,
                  /*allocator=*/nullptr,
                  /*host_memory_allocator=*/nullptr,
                  /*should_stage_host_to_device_transfers=*/false,
@@ -145,7 +145,7 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
     return InvalidArgument(
         "Passed executable from different client (platform '%s') to "
         "PjRtTpuClient::ExecutableFingerprint",
-        executable.client()->platform_name());
+        Name(executable.client()->platform_id()));
   }
   if (executable.executables().size() > 1) {
     LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
@@ -33,7 +33,7 @@ class PjRtTpuDevice : public PjRtDevice {
                 int host_id, const std::array<int, 3>& coords,
                 std::string device_kind)
       : PjRtDevice(core.Id(), std::move(local_device_state),
-                   /*platform_name=*/"tpu", std::move(device_kind), host_id),
+                   std::move(device_kind), host_id),
        core_(core),
        coords_(coords) {}
@@ -228,35 +228,31 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
 StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
   DLContext context;
   TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
-  context.device_id = device.local_device_state()->device_ordinal();
+  context.device_id = device.local_device_id();
   return context;
 }

 StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
                                          const DLContext& context) {
-  se::Platform::Id platform_id;
   switch (context.device_type) {
     case kDLCPU:
-      platform_id = se::host::kHostPlatformId;
-      break;
+      if (client.platform_id() != PjRtPlatformId::kCpu) {
+        return InvalidArgument(
+            "DLPack CPU device type mismatch with PjRtClient platform %s",
+            client.platform_name());
+      }
+      return client.LookupLocalDevice(context.device_id);
     case kDLGPU:
-      platform_id = se::cuda::kCudaPlatformId;
-      break;
+      if (client.platform_id() != PjRtPlatformId::kNvidiaGpu) {
+        return InvalidArgument(
+            "DLPack GPU device type mismatch with PjRtClient platform %s",
+            client.platform_name());
+      }
+      return client.LookupLocalDevice(context.device_id);
     default:
       return InvalidArgument("Unknown/unsupported DLPack device type %d",
                              context.device_type);
   }
-  auto it = absl::c_find_if(client.local_devices(), [&](PjRtDevice* device) {
-    return device->local_device_state()->executor()->platform()->id() ==
-               platform_id &&
-           device->local_device_state()->device_ordinal() == context.device_id;
-  });
-  if (it == client.local_devices().end()) {
-    return InvalidArgument(
-        "No matching device found for DLPack device_type %d device_id %d",
-        context.device_type, context.device_id);
-  }
-  return *it;
 }

 }  // namespace
@@ -301,8 +297,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
   pack->tensor.manager_ctx = pack.get();
   pack->tensor.deleter = DLPackTensorDeleter;
   TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
-  dt.ctx.device_id =
-      buffer->buffer()->device()->local_device_state()->device_ordinal();
+  dt.ctx.device_id = buffer->buffer()->device()->local_device_id();
   dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
   TF_ASSIGN_OR_RETURN(dt.dtype,
                       PrimitiveTypeToDLDataType(
@@ -144,7 +144,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
   // Additionally we call BlockHostUntilReady() below, which may block.
   py::gil_scoped_release gil_release;

-  if (buffer.device()->platform_name() != "cpu") {
+  if (!buffer.IsOnCpu()) {
     return InvalidArgument(
         "Python buffer protocol is only defined for CPU buffers.");
   }
@@ -112,13 +112,13 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
       int num_replicas);

   StatusOr<ChannelHandle> CreateChannelHandle() {
-    return pjrt_client_->client()->CreateChannelHandle();
+    return pjrt_client_->CreateChannelHandle();
   }
   StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
-    return pjrt_client_->client()->CreateDeviceToHostChannelHandle();
+    return pjrt_client_->CreateDeviceToHostChannelHandle();
   }
   StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
-    return pjrt_client_->client()->CreateHostToDeviceChannelHandle();
+    return pjrt_client_->CreateHostToDeviceChannelHandle();
   }

   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
@@ -135,15 +135,7 @@ PyExecutable::ExecuteOnLocalDevices(

 StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
     const {
-  std::vector<std::shared_ptr<HloModule>> modules;
-  modules.reserve(executable_->executables().size());
-  for (const auto& local_exec : executable_->executables()) {
-    if (!local_exec->executable()->has_module()) {
-      return InvalidArgument("Executable does not have HLO modules.");
-    }
-    modules.push_back(local_exec->executable()->shared_module());
-  }
-  return std::move(modules);
+  return executable_->GetHloModules();
 }

 }  // namespace xla
@@ -37,7 +37,7 @@ namespace xla {

 TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
                      int core_on_chip)
-    : xla::PjRtDevice(id, /*local_device_state=*/nullptr, kTpuPlatform,
+    : xla::PjRtDevice(id, /*local_device_state=*/nullptr,
                       /*device_kind=*/"Cloud TPU", host_id),
       coords_(coords),
       core_on_chip_(core_on_chip) {}
@@ -641,9 +641,7 @@ PYBIND11_MODULE(xla_extension, m) {
       [](py::object buffer_obj) -> StatusOr<py::object> {
         GlobalPyRefManager()->CollectGarbage();
         PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
-        LocalDeviceState* state =
-            buffer->buffer()->device()->local_device_state();
-        if (state->executor()->platform_kind() == se::PlatformKind::kHost &&
+        if (buffer->buffer()->IsOnCpu() &&
             buffer->buffer()->on_device_shape().IsArray() &&
             buffer->buffer()->on_device_shape().element_type() != BF16) {
           py::object out = py::reinterpret_steal<py::object>(