[PJRT] Rename Device to PjRtDevice for consistency with the other PjRt classes.

Rename only; no functional changes.

PiperOrigin-RevId: 327476904
Change-Id: I85a410bd68ad9192d5a43107b94cad4e1aeb83f0
Peter Hawkins 2020-08-19 11:51:47 -07:00 committed by TensorFlower Gardener
parent 800b502f00
commit 39e459f513
25 changed files with 217 additions and 205 deletions
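
Because this is a pure rename, the C++ changes below are mechanical: only the spelling of the type changes, never the behavior. As a minimal before/after sketch of caller code (the helper function is hypothetical and the include path is assumed from the tree at this commit; local_devices() and DebugString() are the accessors shown in the diff below):

#include <string>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// Before this commit a caller wrote `xla::Device*`; afterwards only the type
// name changes. Accessors such as id(), host_id(), and DebugString() are
// untouched.
std::string FirstLocalDeviceName(xla::PjRtClient* client) {
  xla::PjRtDevice* device = client->local_devices().at(0);
  return device->DebugString();
}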


@@ -25,8 +25,8 @@ static const char kCpuPlatformName[] = "cpu";
CpuDevice::CpuDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state)
: Device(id, std::move(local_device_state), kCpuPlatformName,
/*device_kind=*/kCpuPlatformName) {}
: PjRtDevice(id, std::move(local_device_state), kCpuPlatformName,
/*device_kind=*/kCpuPlatformName) {}
StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
@@ -39,7 +39,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<Device>> devices;
std::vector<std::unique_ptr<PjRtDevice>> devices;
for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutorConfig config;
config.ordinal = i;


@@ -23,7 +23,7 @@ limitations under the License.
namespace xla {
class CpuDevice : public Device {
class CpuDevice : public PjRtDevice {
public:
CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
};


@@ -32,7 +32,7 @@ TEST(GpuMultiStream, Basics) {
GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(),
/*distributed_client=*/nullptr, /*node_id=*/0));
Device* device = client->local_devices().at(0);
PjRtDevice* device = client->local_devices().at(0);
int n = 1024;
Shape shape = ShapeUtil::MakeShape(S32, {n});


@@ -25,8 +25,8 @@ static const char kInterpreterPlatformName[] = "interpreter";
InterpreterDevice::InterpreterDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state)
: Device(id, std::move(local_device_state), kInterpreterPlatformName,
/*device_kind=*/kInterpreterPlatformName) {}
: PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName,
/*device_kind=*/kInterpreterPlatformName) {}
StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
@@ -40,7 +40,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<Device>> devices;
std::vector<std::unique_ptr<PjRtDevice>> devices;
se::StreamExecutor* executor =
client->backend().stream_executor(0).ValueOrDie();
auto device_state = absl::make_unique<LocalDeviceState>(


@@ -23,7 +23,7 @@ limitations under the License.
namespace xla {
class InterpreterDevice : public Device {
class InterpreterDevice : public PjRtDevice {
public:
InterpreterDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state);


@@ -207,9 +207,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) {
return cache_.emplace(key_string, result.ValueOrDie()).first->second;
}
std::vector<std::unique_ptr<Device>> BuildLocalDevices(
std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
std::vector<std::unique_ptr<Device>> devices;
std::vector<std::unique_ptr<PjRtDevice>> devices;
for (auto& local_device : local_device_states) {
int device_ordinal = local_device->device_ordinal();
const se::DeviceDescription& description =
@@ -225,7 +225,7 @@ std::vector<std::unique_ptr<Device>> BuildLocalDevices(
Status BuildDistributedDevices(
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
std::vector<std::unique_ptr<Device>>* devices,
std::vector<std::unique_ptr<PjRtDevice>>* devices,
GpuExecutableRunOptions* gpu_executable_run_options) {
LocalTopologyProto local_topology;
local_topology.set_node_id(node_id);
@@ -286,8 +286,8 @@ Status BuildDistributedDevices(
GpuDevice::GpuDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int node_id)
: Device(id, std::move(local_device_state), kGpuPlatformName,
std::move(device_kind), node_id) {}
: PjRtDevice(id, std::move(local_device_state), kGpuPlatformName,
std::move(device_kind), node_id) {}
StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
bool asynchronous, const GpuAllocatorConfig& allocator_config,
@@ -302,7 +302,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
auto host_memory_allocator =
GetGpuHostAllocator(local_device_states.front()->executor());
std::vector<std::unique_ptr<Device>> devices;
std::vector<std::unique_ptr<PjRtDevice>> devices;
auto gpu_run_options = absl::make_unique<GpuExecutableRunOptions>();
if (distributed_client) {
TF_RETURN_IF_ERROR(BuildDistributedDevices(


@@ -25,7 +25,7 @@ limitations under the License.
namespace xla {
class GpuDevice : public Device {
class GpuDevice : public PjRtDevice {
public:
GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int node_id);


@@ -112,19 +112,19 @@ limitations under the License.
namespace xla {
StatusOr<LocalDeviceState*> Device::GetLocalDeviceState() const {
StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
if (local_device_state_) {
return local_device_state_.get();
}
return InvalidArgument("Device %s is not a local device.", DebugString());
}
std::string Device::DebugString() const {
std::string PjRtDevice::DebugString() const {
return absl::StrCat(platform_name(), ":", id());
}
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
absl::Span<const std::vector<Device*>> devices) {
absl::Span<const std::vector<PjRtDevice*>> devices) {
if (devices.empty()) {
return InvalidArgument(
"Device assignment passed to Compile() must be non-empty.");
@@ -175,7 +175,7 @@ class CpuAllocator : public tensorflow::Allocator {
PjRtClient::PjRtClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<Device>> devices, int host_id,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
@@ -201,7 +201,7 @@ PjRtClient::PjRtClient(
host_memory_allocator_ = std::make_unique<CpuAllocator>();
}
for (const std::unique_ptr<Device>& device : devices_) {
for (const std::unique_ptr<PjRtDevice>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
<< "Duplicate device id: " << device->id();
@@ -376,8 +376,9 @@ void RecordUsage(PjRtBuffer::ScopedHold device_buffer,
// It is safe to delete the returned PjRtBuffer without further
// synchronization if an error occurs before the buffer is used.
StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
const Shape& on_host_shape, Device* device, LocalDeviceState* local_device,
se::Stream* copy_stream, bool is_uninitialized_create, PjRtClient* client) {
const Shape& on_host_shape, PjRtDevice* device,
LocalDeviceState* local_device, se::Stream* copy_stream,
bool is_uninitialized_create, PjRtClient* client) {
if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
return InvalidArgument("Can't make a buffer from an empty tuple");
}
@@ -574,7 +575,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
Device* device) {
PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString()
<< " device: " << device->DebugString();
@@ -736,7 +737,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
/* static */
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
const Shape& shape, PjRtClient* client, Device* device) {
const Shape& shape, PjRtClient* client, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized");
VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString()
<< " device: " << device->DebugString();
@@ -755,7 +756,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
/* static */
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
const LiteralSlice& literal, PjRtClient* client, Device* device) {
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral");
VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: "
<< literal.shape().ToString() << " device: " << device->DebugString();
@@ -815,7 +816,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
}
/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtClient* client, Device* device,
absl::Span<const Shape> shapes, PjRtClient* client, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) {
if (shapes.empty()) {
notifier(InvalidArgument(
@@ -849,7 +850,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, Device* device)
PjRtClient* client, PjRtDevice* device)
: client_(client),
on_host_shape_(std::move(on_host_shape)),
on_device_shape_(std::move(on_device_shape)),
@@ -1189,7 +1190,7 @@ PjRtBuffer::ScopedHold PjRtBuffer::GetBufferWithHold(ScopedHold::Type type) {
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
std::shared_ptr<BufferSequencingEvent>>>
PjRtBuffer::CopyToDeviceHelper(
Device* dst_device, LocalDeviceState* dst_local_device,
PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
TF_ASSIGN_OR_RETURN(
@@ -1249,7 +1250,7 @@ PjRtBuffer::CopyToDeviceHelper(
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
Device* dst_device) {
PjRtDevice* dst_device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::CopyToDevice");
if (dst_device == device_) {
return InvalidArgument(
@@ -1420,7 +1421,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
std::unique_ptr<PjRtBuffer> OutputBufferHelper(
ScopedShapedBuffer* result_buffer,
std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
Device* device, LocalDeviceState* local_device) {
PjRtDevice* device, LocalDeviceState* local_device) {
std::shared_ptr<TrackedDeviceBuffer> out_buffer =
TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
{definition_event});
@@ -1433,7 +1434,7 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
return pjrt_buffer;
}
static Device* LookupDevice(const PjRtClient& client, int device_id) {
static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
auto it = client.id_to_device().find(device_id);
CHECK(it != client.id_to_device().end())
<< "Unknown device id: " << device_id;
@@ -1447,7 +1448,7 @@ PjRtExecutable::PjRtExecutable(
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<std::pair<int, int>> local_logical_device_ids,
std::vector<Device*> local_devices, PjRtClient* client)
std::vector<PjRtDevice*> local_devices, PjRtClient* client)
: client_(client),
device_assignment_(std::move(device_assignment)),
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@@ -1559,7 +1560,7 @@ PjRtExecutable::MakeExecutionInputsAndWaitForEvents(
StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment) const {
int device_ordinal = device->local_device_state()->device_ordinal();
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
@@ -1695,7 +1696,7 @@ std::vector<std::unique_ptr<PjRtBuffer>> PjRtExecutable::MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options,
ScopedShapedBuffer result_buffer,
std::shared_ptr<BufferSequencingEvent> definition_event,
Device* device) const {
PjRtDevice* device) const {
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
@@ -1729,7 +1730,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtExecutable::ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
int replica, int partition, const RunId& run_id,
const ExecuteOptions& options,
Device* device) const {
PjRtDevice* device) const {
std::shared_ptr<DeviceAssignment> device_assignment;
if (device == nullptr) {
CHECK(device_assignment_ != nullptr);
@@ -1828,7 +1829,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> PjRtExecutable::Execute(
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtExecutable::ExecuteOnLocalDevice(
absl::Span<PjRtBuffer* const> argument_handles, Device* device,
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const {
if (device_assignment_ == nullptr) {
VLOG(1) << "Executing portable single-core program on "
@@ -1894,7 +1895,7 @@ PjRtExecutable::ExecuteOnLocalDevices(
for (int i = 0; i < num_local_devices; ++i) {
const int replica = local_logical_device_ids_[i].first;
const int partition = local_logical_device_ids_[i].second;
Device* device = local_devices_[i];
PjRtDevice* device = local_devices_[i];
const LocalDeviceState& device_state = *device->local_device_state();
device_state.execute_thread()->Schedule([&, replica, partition, i] {
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
@@ -2141,12 +2142,12 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
build_options.set_result_layout(result_layout);
std::vector<std::pair<int, int>> local_logical_device_ids;
std::vector<Device*> local_devices;
std::vector<PjRtDevice*> local_devices;
if (device_assignment != nullptr) {
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition);
Device* device = LookupDevice(*client, device_id);
PjRtDevice* device = LookupDevice(*client, device_id);
if (device->host_id() != client->host_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;


@@ -52,17 +52,18 @@ namespace xla {
class PjRtClient;
class Device {
class PjRtDevice {
public:
explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string platform_name, std::string device_kind,
int host_id = 0)
explicit PjRtDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state,
std::string platform_name, std::string device_kind,
int host_id = 0)
: id_(id),
local_device_state_(std::move(local_device_state)),
host_id_(host_id),
platform_name_(std::move(platform_name)),
device_kind_(std::move(device_kind)) {}
virtual ~Device() {}
virtual ~PjRtDevice() {}
// The ID of this device. IDs are unique among devices of this type
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
@@ -130,7 +131,7 @@ class PjRtClient {
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<Device>> devices, int host_id,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
@@ -142,11 +143,15 @@ class PjRtClient {
int device_count() const { return devices_.size(); }
int local_device_count() const { return local_devices_.size(); }
const std::vector<std::unique_ptr<Device>>& devices() const {
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
return devices_;
}
const std::vector<Device*>& local_devices() const { return local_devices_; }
const std::map<int, Device*>& id_to_device() const { return id_to_device_; }
const std::vector<PjRtDevice*>& local_devices() const {
return local_devices_;
}
const std::map<int, PjRtDevice*>& id_to_device() const {
return id_to_device_;
}
int host_id() const { return host_id_; }
const std::string& platform_name() const { return platform_name_; }
@@ -210,11 +215,11 @@ class PjRtClient {
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::unique_ptr<Device>> devices_;
std::vector<std::unique_ptr<PjRtDevice>> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
std::map<int, Device*> id_to_device_;
std::map<int, PjRtDevice*> id_to_device_;
// Local devices indexed by local device ordinal.
std::vector<Device*> local_devices_;
std::vector<PjRtDevice*> local_devices_;
int host_id_;
se::DeviceMemoryAllocator* allocator_;
@@ -233,7 +238,7 @@ class PjRtClient {
// Converts a 2D set of Device objects indexed by [replica][partition] into an
// xla::DeviceAssignment.
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
absl::Span<const std::vector<Device*>> devices);
absl::Span<const std::vector<PjRtDevice*>> devices);
// Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
// can be either valid or invalid. An invalid buffer is one that has never been
@@ -417,7 +422,7 @@ class PjRtBuffer {
// Returns a buffer with uninitialized contents.
static StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitialized(
const Shape& shape, PjRtClient* client, Device* device);
const Shape& shape, PjRtClient* client, PjRtDevice* device);
// Describes the semantics the caller of FromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
@@ -449,13 +454,13 @@ class PjRtBuffer {
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
Device* device);
PjRtDevice* device);
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait until BlockHostUntilReady() completes
// on the return value before letting literal go out of scope.
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostLiteral(
const LiteralSlice& literal, PjRtClient* client, Device* device);
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device);
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device`. `shapes` must be the exact
@@ -467,12 +472,13 @@ class PjRtBuffer {
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
static void MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
PjRtClient* client, Device* device,
PjRtClient* client,
PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, Device* device);
PjRtClient* client, PjRtDevice* device);
~PjRtBuffer();
PjRtBuffer(const PjRtBuffer&) = delete;
@@ -482,7 +488,7 @@ class PjRtBuffer {
const Shape& on_host_shape() const { return on_host_shape_; }
const Shape& on_device_shape() const { return on_device_shape_; }
Device* device() const { return device_; }
PjRtDevice* device() const { return device_; }
const std::string& platform_name() const { return client_->platform_name(); }
PjRtClient* client() const { return client_; }
bool IsEmptyTuple() const {
@@ -556,7 +562,7 @@ class PjRtBuffer {
// Copies the buffer to device `dst_device`. Returns an error if the buffer is
// already on dst_device.
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(Device* dst_device);
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(PjRtDevice* dst_device);
// Copies the buffer to the remote device encoded in serialized_descriptor.
// This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
@@ -629,7 +635,7 @@ class PjRtBuffer {
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
std::shared_ptr<BufferSequencingEvent>>>
CopyToDeviceHelper(Device* dst_device, LocalDeviceState* dst_local_device,
CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
LocalDeviceState* transfer_local_device,
se::Stream* transfer_stream,
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
@@ -637,7 +643,7 @@ class PjRtBuffer {
PjRtClient* const client_;
const Shape on_host_shape_;
const Shape on_device_shape_;
Device* const device_;
PjRtDevice* const device_;
mutable absl::Mutex mu_;
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
@@ -707,7 +713,7 @@ class PjRtExecutable {
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<std::pair<int, int>> local_logical_device_ids,
std::vector<Device*> local_devices, PjRtClient* client);
std::vector<PjRtDevice*> local_devices, PjRtClient* client);
virtual ~PjRtExecutable() = default;
@@ -741,14 +747,16 @@ class PjRtExecutable {
return local_logical_device_ids_;
}
const std::vector<Device*>& local_devices() const { return local_devices_; }
const std::vector<PjRtDevice*>& local_devices() const {
return local_devices_;
}
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> Execute(
absl::Span<PjRtBuffer* const> argument_handles,
const ExecuteOptions& options) const;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteOnLocalDevice(
absl::Span<PjRtBuffer* const> argument_handles, Device* device,
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const;
// Execute on local devices. Takes a sequence of argument lists (one argument
@@ -786,7 +794,7 @@ class PjRtExecutable {
StatusOr<ScopedShapedBuffer> EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id,
const ExecuteOptions& options, Device* device,
const ExecuteOptions& options, PjRtDevice* device,
std::vector<PjRtBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment) const;
@@ -794,12 +802,12 @@ class PjRtExecutable {
int device_ordinal, const ExecuteOptions& options,
ScopedShapedBuffer result_buffer,
std::shared_ptr<BufferSequencingEvent> definition_event,
Device* device) const;
PjRtDevice* device) const;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, const RunId& run_id, const ExecuteOptions& options,
Device* device = nullptr) const;
PjRtDevice* device = nullptr) const;
// Create shared pointers so we can free them after the execution: with
// asynchronous execution, the process being executed can outlive the
@@ -828,7 +836,7 @@ class PjRtExecutable {
// assigned.
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
// (see xla.cc).
std::vector<Device*> local_devices_;
std::vector<PjRtDevice*> local_devices_;
};
} // namespace xla
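
Each backend device in this commit (CpuDevice, InterpreterDevice, GpuDevice, TpuDevice) derives from the renamed base, and the subclassing pattern is otherwise unchanged. A sketch of a hypothetical backend following the same pattern (MyDevice and kMyPlatformName are illustrative, not part of this commit):

#include <memory>
#include <utility>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

namespace xla {

static const char kMyPlatformName[] = "my_platform";

// The rename touches only the base-class name: `public PjRtDevice` was
// formerly `public Device`, and the delegating constructor call changes to
// match. Everything else stays as before; compare CpuDevice above.
class MyDevice : public PjRtDevice {
 public:
  MyDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state)
      : PjRtDevice(id, std::move(local_device_state), kMyPlatformName,
                   /*device_kind=*/kMyPlatformName) {}
};

}  // namespace xla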


@@ -193,7 +193,7 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
return minor_to_major;
}
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
const se::Platform* platform =
device.local_device_state()->executor()->platform();
if (platform->id() == se::host::kHostPlatformId) {
@@ -205,15 +205,15 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
device.DebugString());
}
StatusOr<DLContext> DLContextForDevice(const Device& device) {
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
DLContext context;
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
context.device_id = device.local_device_state()->device_ordinal();
return context;
}
StatusOr<Device*> DeviceForDLContext(const PjRtClient& client,
const DLContext& context) {
StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
const DLContext& context) {
se::Platform::Id platform_id;
switch (context.device_type) {
case kDLCPU:
@@ -226,7 +226,7 @@ StatusOr<Device*> DeviceForDLContext(const PjRtClient& client,
return InvalidArgument("Unknown/unsupported DLPack device type %d",
context.device_type);
}
auto it = absl::c_find_if(client.local_devices(), [&](Device* device) {
auto it = absl::c_find_if(client.local_devices(), [&](PjRtDevice* device) {
return device->local_device_state()->executor()->platform()->id() ==
platform_id &&
device->local_device_state()->device_ordinal() == context.device_id;
@@ -313,7 +313,7 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
dlmt->dl_tensor.ndim);
}
TF_ASSIGN_OR_RETURN(
Device * device,
PjRtDevice * device,
DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx));
absl::Span<int64 const> dimensions(
reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);


@@ -217,7 +217,7 @@ std::string CallSignature::DebugString() const {
struct CacheEntry {
std::shared_ptr<xla::PyExecutable> executable;
xla::Device* device;
xla::PjRtDevice* device;
PyTreeDef out_pytree_def;
// These are the objects required to create a `DeviceArray` object.
// We use Python types within the vector because this is what we will be
@@ -235,7 +235,7 @@ class CompiledFunction {
CompiledFunction(py::function cache_miss_fun, py::function python_f_jitted,
bool jax_enable_x64, std::vector<int> static_argnums,
std::shared_ptr<xla::PyClient> pyclient,
xla::Device* device);
xla::PjRtDevice* device);
~CompiledFunction();
// This function will:
@@ -268,7 +268,7 @@ class CompiledFunction {
absl::flat_hash_map<CallSignature, std::unique_ptr<CacheEntry>> executables_;
const std::shared_ptr<xla::PyClient> pyclient_;
xla::Device* const default_device_;
xla::PjRtDevice* const default_device_;
};
CompiledFunction::CompiledFunction(py::function cache_miss_fun,
@@ -276,7 +276,7 @@ CompiledFunction::CompiledFunction(py::function cache_miss_fun,
bool jax_enable_x64,
std::vector<int> static_argnums,
std::shared_ptr<xla::PyClient> pyclient,
xla::Device* device)
xla::PjRtDevice* device)
: cache_miss_fun_(std::move(cache_miss_fun)),
python_f_jitted_(std::move(python_f_jitted)),
jax_enable_x64_(jax_enable_x64),
@@ -374,9 +374,9 @@ void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs,
}
template <typename CppType, typename Pybind11Type>
std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(const py::handle& scalar,
xla::PjRtClient* client,
xla::Device* device) {
std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
const py::handle& scalar, xla::PjRtClient* client,
xla::PjRtDevice* device) {
CppType data = py::cast<Pybind11Type>(scalar);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
@@ -389,7 +389,7 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(const py::handle& scalar,
// not convertible (thus, this must be called after other checks).
StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
py::handle scalar, bool jax_enable_x64, xla::PjRtClient* client,
xla::Device* device) {
xla::PjRtDevice* device) {
// Important: In Python, isinstance(True, int) returns True. Thus, we have
// to check for bool before int.
if (py::isinstance<py::bool_>(scalar)) {
@@ -467,7 +467,7 @@ const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) {
//
// Returns `OkStatus()` on success.
Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
xla::Device* default_device,
xla::PjRtDevice* default_device,
ParsedArgumentsAsBuffers& arguments) {
std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
auto& keep_alive = arguments.keep_alive;
@@ -490,12 +490,12 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
// https://github.com/google/jax/pull/1916 for the rationale why the
// computation follows the data locality.
// It's also similar to PyTorch's behavior.
xla::Device* data_device = nullptr;
xla::PjRtDevice* data_device = nullptr;
for (py::handle arg : arguments.flat_dynamic_args) {
if (py::isinstance(arg, device_array)) {
xla::PyBuffer* buffer =
py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
xla::Device* device = buffer->buffer()->device();
xla::PjRtDevice* device = buffer->buffer()->device();
if (data_device && (device != data_device)) {
return InvalidArgument(
"%s",
@@ -682,7 +682,7 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
[](py::function cache_miss_fun,
py::function fallback_on_unsupported_argument,
bool jax_enable_x64, std::vector<int> static_argnums,
xla::ClientAndPtr<xla::Device> client_and_device)
xla::ClientAndPtr<xla::PjRtDevice> client_and_device)
-> std::unique_ptr<CompiledFunction> {
return std::make_unique<CompiledFunction>(
std::move(cache_miss_fun),


@@ -101,14 +101,14 @@ uint32_t constexpr kOutfeedCidShutdown = 0;
// Encapsulates data received from a device outfeed.
class OutfeedData {
public:
OutfeedData(Device* device, uint32_t consumer_id, Shape shape)
OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape)
: device_(device),
consumer_id_(consumer_id),
shape_(shape),
literal_(nullptr),
literal_size_bytes_(0) {}
Device* device() { return device_; }
PjRtDevice* device() { return device_; }
uint32_t consumer_id() const { return consumer_id_; }
Shape shape() const { return shape_; }
std::unique_ptr<Literal> literal() {
@@ -123,7 +123,7 @@ class OutfeedData {
std::string DebugString() const;
private:
Device* device_;
PjRtDevice* device_;
uint32_t consumer_id_;
Shape shape_;
std::unique_ptr<Literal> literal_;
@@ -187,8 +187,8 @@ class OutfeedReceiverImpl {
Status SendShutdownOutfeedHeader(int device_idx);
// Receives a raw Literal from a device outfeed.
StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(const Device* device,
const Shape& shape);
StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(
const PjRtDevice* device, const Shape& shape);
// Enqueues received data in the callback queue.
void EnqueueReceivedData(std::unique_ptr<OutfeedData> received)
@@ -200,7 +200,7 @@ class OutfeedReceiverImpl {
OutfeedReceiver::Callback callback_;
// The devices on which we are listening.
std::vector<Device*> devices_;
std::vector<PjRtDevice*> devices_;
// Maximum bytes capacity of the callback queue.
uint64_t max_callback_queue_size_bytes_;
@@ -283,7 +283,7 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
absl::MutexLock lock(&mu_);
++num_listening_threads_;
}
Device* device = devices_[device_idx];
PjRtDevice* device = devices_[device_idx];
while (true) {
Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
std::unique_ptr<Literal> header =
@@ -339,7 +339,7 @@ void OutfeedReceiverImpl::EnqueueReceivedData(
}
StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
const Device* device, const Shape& shape) {
const PjRtDevice* device, const Shape& shape) {
std::shared_ptr<Literal> literal_shared;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
@@ -390,7 +390,7 @@ void OutfeedReceiverImpl::CallbackThreadLoop() {
}
Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
const Device* device = devices_[device_idx];
const PjRtDevice* device = devices_[device_idx];
constexpr int consumer_id = kOutfeedCidShutdown;
VLOG(2) << "[" << device->DebugString()
<< "] SendSpecialHeader cons=" << consumer_id;


@@ -33,7 +33,7 @@ class OutfeedReceiver {
public:
// A callback takes: device, consumer id, received.
using Callback =
std::function<void(Device*, uint32_t, std::shared_ptr<Literal>)>;
std::function<void(PjRtDevice*, uint32_t, std::shared_ptr<Literal>)>;
// Constructs the receiver for the given clients and callback function.
//


@@ -40,7 +40,7 @@ class OutfeedReceiverForPython {
public:
// A callback to Python takes: consumer id, received literal.
using CallbackToPython =
std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>;
std::function<void(ClientAndPtr<PjRtDevice>, uint32_t, pybind11::object)>;
OutfeedReceiverForPython(CallbackToPython callback_python,
std::vector<std::shared_ptr<PyClient>> clients,
@@ -48,7 +48,7 @@ class OutfeedReceiverForPython {
: callback_python_(std::move(callback_python)),
clients_(std::move(clients)) {
OutfeedReceiver::Callback callback =
[this](Device* device, uint32_t consumer_id,
[this](PjRtDevice* device, uint32_t consumer_id,
std::shared_ptr<Literal> literal) {
this->Callback(device, consumer_id, std::move(literal));
};
@@ -86,7 +86,7 @@ class OutfeedReceiverForPython {
arrays);
}
void Callback(Device* device, uint32_t consumer_id,
void Callback(PjRtDevice* device, uint32_t consumer_id,
std::shared_ptr<Literal> literal) {
{
absl::MutexLock lock(&mu_);
@@ -106,7 +106,7 @@ class OutfeedReceiverForPython {
LiteralToPython(std::move(literal)).ValueOrDie();
// The callback_ should handle all exceptions in user-code. If we get
// an exception here, it is a bug in the callback and we should stop.
callback_python_(WrapWithClient<Device>(*it, device), consumer_id,
callback_python_(WrapWithClient<PjRtDevice>(*it, device), consumer_id,
std::move(literal_python));
}


@@ -78,11 +78,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback =
[&receiver](PjRtDevice* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@@ -111,11 +111,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) {
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback =
[&receiver](PjRtDevice* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@@ -156,11 +156,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) {
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback =
[&receiver](PjRtDevice* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@@ -199,11 +199,11 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) {
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback =
[&receiver](PjRtDevice* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@@ -233,11 +233,11 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback =
[&receiver](PjRtDevice* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();


@@ -51,12 +51,12 @@ PyBuffer::~PyBuffer() {
}
}
ClientAndPtr<Device> PyBuffer::device() const {
ClientAndPtr<PjRtDevice> PyBuffer::device() const {
return WrapWithClient(client_, buffer_->device());
}
StatusOr<std::unique_ptr<PyBuffer>> PyBuffer::CopyToDevice(
const ClientAndPtr<Device>& dst_device) const {
const ClientAndPtr<PjRtDevice>& dst_device) const {
CHECK(dst_device.get() != nullptr);
GlobalPyRefManager()->CollectGarbage();
std::unique_ptr<PjRtBuffer> out;


@@ -38,12 +38,12 @@ class PyBuffer {
std::shared_ptr<PyClient> client() const { return client_; }
PjRtBuffer* buffer() const { return buffer_.get(); }
ClientAndPtr<Device> device() const;
ClientAndPtr<PjRtDevice> device() const;
const std::string& platform_name() const { return buffer_->platform_name(); }
bool is_deleted() const { return buffer_->IsDeleted(); }
StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice(
const ClientAndPtr<Device>& dst_device) const;
const ClientAndPtr<PjRtDevice>& dst_device) const;
void Delete() { return buffer_->Delete(); }


@@ -33,8 +33,8 @@ namespace pprof = tensorflow::tfprof::pprof;
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
: pjrt_client_(std::move(pjrt_client)) {}
std::vector<ClientAndPtr<Device>> PyClient::Devices() {
std::vector<ClientAndPtr<Device>> devices;
std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
std::vector<ClientAndPtr<PjRtDevice>> devices;
devices.reserve(pjrt_client_->devices().size());
for (const auto& device : pjrt_client_->devices()) {
devices.push_back(WrapWithClient(shared_from_this(), device.get()));
@@ -42,21 +42,21 @@ std::vector<ClientAndPtr<Device>> PyClient::Devices() {
return devices;
}
std::vector<ClientAndPtr<Device>> PyClient::LocalDevices() {
std::vector<ClientAndPtr<Device>> devices;
std::vector<ClientAndPtr<PjRtDevice>> PyClient::LocalDevices() {
std::vector<ClientAndPtr<PjRtDevice>> devices;
devices.reserve(pjrt_client_->local_devices().size());
for (Device* device : pjrt_client_->local_devices()) {
for (PjRtDevice* device : pjrt_client_->local_devices()) {
devices.push_back(WrapWithClient(shared_from_this(), device));
}
return devices;
}
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
std::vector<std::vector<ClientAndPtr<Device>>> result;
std::vector<std::vector<ClientAndPtr<PjRtDevice>>> result;
result.resize(num_replicas);
for (int r = 0; r < num_replicas; ++r) {
result[r].resize(num_partitions);
@@ -70,12 +70,12 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
return result;
}
StatusOr<std::vector<ClientAndPtr<Device>>>
StatusOr<std::vector<ClientAndPtr<PjRtDevice>>>
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
pjrt_client_->GetDefaultDeviceAssignment(
num_replicas, /*num_partitions=*/1));
std::vector<ClientAndPtr<Device>> result;
std::vector<ClientAndPtr<PjRtDevice>> result;
for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0);
auto iter = pjrt_client_->id_to_device().find(device_id);
@@ -86,7 +86,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
}
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
const pybind11::object& argument, Device* device, bool force_copy,
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
@@ -206,7 +206,7 @@ namespace {
struct HeapProfileKey {
Traceback* traceback;
int64 size;
Device* device;
PjRtDevice* device;
bool operator==(const HeapProfileKey& other) const;
};


@@ -100,14 +100,14 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
int device_count() const { return pjrt_client_->device_count(); }
int host_id() const { return pjrt_client_->host_id(); }
std::vector<ClientAndPtr<Device>> Devices();
std::vector<ClientAndPtr<Device>> LocalDevices();
std::vector<ClientAndPtr<PjRtDevice>> Devices();
std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
// TODO(skye): delete after all callers can handle 2D output
StatusOr<std::vector<ClientAndPtr<Device>>> GetDefaultDeviceAssignment1D(
StatusOr<std::vector<ClientAndPtr<PjRtDevice>>> GetDefaultDeviceAssignment1D(
int num_replicas);
StatusOr<ChannelHandle> CreateChannelHandle() {
@@ -121,7 +121,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
}
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
const pybind11::object& argument, Device* device, bool force_copy,
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
StatusOr<std::shared_ptr<PyExecutable>> Compile(


@@ -58,10 +58,10 @@ PyExecutable::~PyExecutable() {
}
}
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
std::vector<ClientAndPtr<Device>> devices;
std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::LocalDevices() const {
std::vector<ClientAndPtr<PjRtDevice>> devices;
devices.reserve(executable_->local_devices().size());
for (Device* device : executable_->local_devices()) {
for (PjRtDevice* device : executable_->local_devices()) {
devices.push_back(WrapWithClient(client_, device));
}
return devices;


@@ -47,7 +47,7 @@ class PyExecutable {
return executable_->local_logical_device_ids();
}
std::vector<ClientAndPtr<Device>> LocalDevices() const;
std::vector<ClientAndPtr<PjRtDevice>> LocalDevices() const;
int64 SizeOfGeneratedCodeInBytes() const {
return executable_->SizeOfGeneratedCodeInBytes();


@@ -37,8 +37,8 @@ namespace xla {
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip)
: xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform,
/*device_kind=*/"Cloud TPU", host_id),
: xla::PjRtDevice(id, /*local_device_state=*/nullptr, kTpuPlatform,
/*device_kind=*/"Cloud TPU", host_id),
coords_(coords),
core_on_chip_(core_on_chip) {}
@@ -47,9 +47,9 @@ std::string TpuDevice::DebugString() const {
coords_[0], coords_[1], coords_[2], core_on_chip_);
}
xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>>
xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
std::vector<std::shared_ptr<Device>> devices;
std::vector<std::shared_ptr<PjRtDevice>> devices;
for (const auto& chip : system_info.tpu_chip()) {
auto& coord = chip.chip_coord();
std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
@@ -78,7 +78,7 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
tpu_driver::SystemInfo system_info;
client->QuerySystemInfo(&system_info);
TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<Device>> devices,
TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<PjRtDevice>> devices,
TpuDevice::GetTpuDevices(system_info));
return std::make_shared<PyTpuClient>(kTpuPlatform, std::move(client),
@@ -88,13 +88,13 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
PyTpuClient::PyTpuClient(std::string platform_name,
std::unique_ptr<tpu_driver::TpuDriver> driver,
std::vector<std::shared_ptr<Device>> devices,
std::vector<std::shared_ptr<PjRtDevice>> devices,
int host_id)
: platform_name_(std::move(platform_name)),
driver_(std::move(driver)),
devices_(std::move(devices)),
host_id_(host_id) {
for (const std::shared_ptr<Device>& device : devices_) {
for (const std::shared_ptr<PjRtDevice>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device}).second)
<< "Duplicate device id: " << device->id();
@@ -173,7 +173,7 @@ static Status CheckDataType(xla::PrimitiveType dtype) {
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
std::vector<BorrowingLiteral> leaves, const Shape& tuple_shape,
std::shared_ptr<void> leaves_references,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device) {
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals");
VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString()
<< " device: " << device->DebugString();
@@ -229,7 +229,7 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
std::shared_ptr<Device> device) {
std::shared_ptr<PjRtDevice> device) {
std::vector<Shape> child_shapes;
std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
@@ -388,7 +388,7 @@ PyTpuBuffer::DestructureTuple() {
}
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CopyToDevice(
std::shared_ptr<Device> dst_device) {
std::shared_ptr<PjRtDevice> dst_device) {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice");
if (on_host_shape_.IsTuple()) {
return Unimplemented("CopyToDevice for tuples is not supported.");
@@ -433,7 +433,7 @@ Status PyTpuBuffer::BlockHostUntilReady() {
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
const Shape& shape, std::shared_ptr<PyTpuClient> client,
std::shared_ptr<Device> device) {
std::shared_ptr<PjRtDevice> device) {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer");
VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString()
<< " device: " << device->DebugString();
@@ -465,7 +465,7 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
/*static*/
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
const Shape& non_tuple_shape, absl::optional<BufferInitializer> initializer,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device) {
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer");
VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: "
<< non_tuple_shape.DebugString()
@@ -493,8 +493,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
std::vector<std::shared_ptr<TpuSharedBuffer>>(), client);
}
static std::shared_ptr<Device> LookupDevice(const PyTpuClient& client,
int device_id) {
static std::shared_ptr<PjRtDevice> LookupDevice(const PyTpuClient& client,
int device_id) {
auto it = client.id_to_device().find(device_id);
CHECK(it != client.id_to_device().end())
<< "Unknown device id: " << device_id;
@@ -516,7 +516,7 @@ PyTpuExecutable::PyTpuExecutable(
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = device_assignment_(replica, partition);
std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
if (device->host_id() != client_->host_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
@@ -541,7 +541,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
int partition, const RunId& run_id) {
const int device_id = device_assignment_(replica, partition);
std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
CHECK_EQ(device->host_id(), client_->host_id());
tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
VLOG(3) << "Replica " << replica << ", partition " << partition


@@ -38,7 +38,7 @@ namespace xla {
constexpr char kTpuPlatform[] = "tpu";
class TpuDevice : public Device {
class TpuDevice : public PjRtDevice {
public:
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip);
@@ -48,8 +48,8 @@ class TpuDevice : public Device {
std::string DebugString() const override;
static xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>> GetTpuDevices(
const tpu_driver::SystemInfo& system_info);
static xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
GetTpuDevices(const tpu_driver::SystemInfo& system_info);
private:
const std::array<int, 3> coords_;
@@ -66,7 +66,7 @@ class PyTpuClient {
explicit PyTpuClient(std::string platform_name,
std::unique_ptr<tpu_driver::TpuDriver> driver,
std::vector<std::shared_ptr<Device>> devices,
std::vector<std::shared_ptr<PjRtDevice>> devices,
int host_id);
virtual ~PyTpuClient() = default;
@@ -83,11 +83,11 @@ class PyTpuClient {
int device_count() const { return devices_.size(); }
int local_device_count() const { return local_devices_.size(); }
const std::vector<std::shared_ptr<Device>>& devices() { return devices_; }
const std::vector<std::shared_ptr<Device>>& local_devices() {
const std::vector<std::shared_ptr<PjRtDevice>>& devices() { return devices_; }
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() {
return local_devices_;
}
const std::map<int, std::shared_ptr<Device>>& id_to_device() const {
const std::map<int, std::shared_ptr<PjRtDevice>>& id_to_device() const {
return id_to_device_;
}
int host_id() const { return host_id_; }
@@ -110,11 +110,11 @@ class PyTpuClient {
std::unique_ptr<tpu_driver::TpuDriver> driver_;
// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::shared_ptr<Device>> devices_;
std::vector<std::shared_ptr<PjRtDevice>> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
std::map<int, std::shared_ptr<Device>> id_to_device_;
std::map<int, std::shared_ptr<PjRtDevice>> id_to_device_;
// Local devices indexed by local device ordinal.
std::vector<std::shared_ptr<Device>> local_devices_;
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
int host_id_;
// A thread pool for scheduling core executions in parallel.
@@ -128,7 +128,7 @@ struct TpuSharedBuffer final {
TpuSharedBuffer(tpu_driver::TpuDriver* driver,
std::unique_ptr<tpu_driver::BufferHandle> handle,
std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use,
std::shared_ptr<Device> src_device)
std::shared_ptr<PjRtDevice> src_device)
: driver(driver),
device(std::move(src_device)),
handle(std::move(handle)),
@@ -143,7 +143,7 @@ struct TpuSharedBuffer final {
}
tpu_driver::TpuDriver* const driver;
const std::shared_ptr<Device> device;
const std::shared_ptr<PjRtDevice> device;
std::unique_ptr<tpu_driver::BufferHandle> handle;
std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
@@ -162,12 +162,12 @@ class PyTpuBuffer {
static StatusOr<std::unique_ptr<PyTpuBuffer>> FromLiterals(
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
std::shared_ptr<void> leaves_reference,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);
// Supports nested tuple creation.
static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
absl::Span<PyTpuBuffer* const> buffers,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);
PyTpuBuffer() = delete;
PyTpuBuffer(Shape on_host_shape,
@@ -181,7 +181,7 @@ class PyTpuBuffer {
PyTpuBuffer& operator=(PyTpuBuffer&&) = delete;
const Shape& on_host_shape() const { return on_host_shape_; }
std::shared_ptr<Device> device() const { return device_; }
std::shared_ptr<PjRtDevice> device() const { return device_; }
const std::string& platform_name() const { return client_->platform_name(); }
std::shared_ptr<PyTpuClient> client() const { return client_; }
@@ -210,7 +210,7 @@ class PyTpuBuffer {
// Copies the buffer to target device `dst_device` and returns a PyTpuBuffer
// object holding the context to the target device buffer.
StatusOr<std::unique_ptr<PyTpuBuffer>> CopyToDevice(
std::shared_ptr<Device> dst_device);
std::shared_ptr<PjRtDevice> dst_device);
// Blocks the host until the buffer's value has been computed and is ready for
// immediate use on the device. Useful in particular for timing benchmarks.
@@ -220,7 +220,7 @@ class PyTpuBuffer {
// tuple, the returned buffer corresponds to the root tuple buffer.
static StatusOr<std::unique_ptr<PyTpuBuffer>> AllocateBuffer(
const Shape& shape, std::shared_ptr<PyTpuClient> client,
std::shared_ptr<Device> device);
std::shared_ptr<PjRtDevice> device);
private:
// Initializes a just allocated device buffer. The returned event will be
@@ -231,11 +231,11 @@ class PyTpuBuffer {
static StatusOr<std::unique_ptr<PyTpuBuffer>> CreateBuffer(
const Shape& non_tuple_shape,
absl::optional<BufferInitializer> initializer,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);
const std::shared_ptr<PyTpuClient> client_;
const Shape on_host_shape_;
const std::shared_ptr<Device> device_;
const std::shared_ptr<PjRtDevice> device_;
// If this is a tuple, `device_buffer_` stores the tuple buffer and
// `child_buffers_` stores the child buffers; else, `device_buffer_` stores
@@ -302,7 +302,7 @@ class PyTpuExecutable {
return local_logical_device_ids_;
}
const std::vector<std::shared_ptr<Device>>& local_devices() const {
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
return local_devices_;
}
@@ -350,7 +350,7 @@ class PyTpuExecutable {
// assigned.
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
// (see xla.cc).
std::vector<std::shared_ptr<Device>> local_devices_;
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
xla::Shape result_shape_;
};


@@ -40,11 +40,12 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def("host_id", &PyTpuClient::host_id)
.def("get_default_device_assignment",
[](PyTpuClient* client, int num_replicas, int num_partitions)
-> StatusOr<std::vector<std::vector<std::shared_ptr<Device>>>> {
-> StatusOr<
std::vector<std::vector<std::shared_ptr<PjRtDevice>>>> {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->GetDefaultDeviceAssignment(
num_replicas, num_partitions));
std::vector<std::vector<std::shared_ptr<Device>>> result;
std::vector<std::vector<std::shared_ptr<PjRtDevice>>> result;
result.resize(num_replicas);
for (int r = 0; r < num_replicas; ++r) {
result[r].resize(num_partitions);
@@ -60,11 +61,11 @@ PYBIND11_MODULE(tpu_client_extension, m) {
// TODO(skye): delete after all callers can handle 2D output
.def("get_default_device_assignment",
[](PyTpuClient* client, int num_replicas)
-> StatusOr<std::vector<std::shared_ptr<Device>>> {
-> StatusOr<std::vector<std::shared_ptr<PjRtDevice>>> {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->GetDefaultDeviceAssignment(
num_replicas, /*num_partitions=*/1));
std::vector<std::shared_ptr<Device>> result;
std::vector<std::shared_ptr<PjRtDevice>> result;
for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0);
auto iter = client->id_to_device().find(device_id);
@@ -96,7 +97,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def(
"buffer_from_pyval",
[](std::shared_ptr<PyTpuClient> client,
const pybind11::object& argument, std::shared_ptr<Device> device,
const pybind11::object& argument,
std::shared_ptr<PjRtDevice> device,
bool force_copy) -> StatusOr<std::unique_ptr<PyTpuBuffer>> {
if (device == nullptr) {
TF_RET_CHECK(!client->local_devices().empty());
@@ -145,7 +147,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::class_<PyTpuBuffer>(m, "PyTpuBuffer")
.def_property_readonly("client", &PyTpuBuffer::client)
.def("copy_to_device",
[](PyTpuBuffer* buffer, std::shared_ptr<Device> dst_device) {
[](PyTpuBuffer* buffer, std::shared_ptr<PjRtDevice> dst_device) {
CHECK(dst_device != nullptr);
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
@@ -202,7 +204,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def_property_readonly("traceback",
[](PyTpuExecutable*) { return py::none(); });
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
py::class_<TpuDevice, PjRtDevice, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
.def_property_readonly("coords", &TpuDevice::coords)
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
.def("__repr__", [](const TpuDevice& device) {


@@ -439,26 +439,26 @@ PYBIND11_MODULE(xla_extension, m) {
device_assignment);
});
py::class_<Device, ClientAndPtr<Device>>(
py::class_<PjRtDevice, ClientAndPtr<PjRtDevice>>(
m, "Device",
"A descriptor of an available device.\n\nSubclasses are used to "
"represent specific types of devices, e.g. CPUs, GPUs. Subclasses may "
"have additional properties specific to that device type.")
.def_property_readonly(
"id", &Device::id,
"id", &PjRtDevice::id,
"Integer ID of this device.\n\nUnique across all available devices "
"of this type, including remote devices on multi-host platforms.")
.def_property_readonly("host_id", &Device::host_id,
.def_property_readonly("host_id", &PjRtDevice::host_id,
"Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.")
.def_property_readonly("platform", &Device::platform_name)
.def_property_readonly("device_kind", &Device::device_kind)
.def_property_readonly("platform", &PjRtDevice::platform_name)
.def_property_readonly("device_kind", &PjRtDevice::device_kind)
.def_property_readonly(
"client",
[](const ClientAndPtr<Device>& device) { return device.client; })
.def("__str__", &Device::DebugString)
[](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
.def("__str__", &PjRtDevice::DebugString)
.def("transfer_to_infeed",
[](const Device& device, const LiteralSlice& literal) {
[](const PjRtDevice& device, const LiteralSlice& literal) {
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
@@ -468,7 +468,8 @@ PYBIND11_MODULE(xla_extension, m) {
})
.def(
"transfer_from_outfeed",
[](const Device& device, const Shape& shape) -> StatusOr<py::object> {
[](const PjRtDevice& device,
const Shape& shape) -> StatusOr<py::object> {
GlobalPyRefManager()->CollectGarbage();
std::shared_ptr<Literal> literal_shared;
{
@@ -492,12 +493,12 @@ PYBIND11_MODULE(xla_extension, m) {
return LiteralToPython(std::move(literal_shared));
});
py::class_<CpuDevice, Device, ClientAndPtr<CpuDevice>>(m, "CpuDevice")
py::class_<CpuDevice, PjRtDevice, ClientAndPtr<CpuDevice>>(m, "CpuDevice")
.def("__repr__", [](const CpuDevice& device) {
return absl::StrFormat("CpuDevice(id=%i)", device.id());
});
py::class_<GpuDevice, Device, ClientAndPtr<GpuDevice>>(m, "GpuDevice")
py::class_<GpuDevice, PjRtDevice, ClientAndPtr<GpuDevice>>(m, "GpuDevice")
.def("__repr__", [](const GpuDevice& device) {
return absl::StrFormat("GpuDevice(id=%i)", device.id());
});
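
Note that the Python-visible API is unaffected by this rename: the pybind11 bindings above still register the classes under their old names ("Device", "CpuDevice", "GpuDevice", and "TpuDevice" in tpu_client_extension), so Python callers see no change; only C++ code that spells out the type needs updating.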