[PJRT] Rename Device to PjRtDevice for consistency with the other PjRt classes.
Rename only; no functional changes.

PiperOrigin-RevId: 327476904
Change-Id: I85a410bd68ad9192d5a43107b94cad4e1aeb83f0
parent 800b502f00
commit 39e459f513
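The rename is purely mechanical: every spelling of xla::Device becomes xla::PjRtDevice, in declarations, in containers such as std::vector<std::unique_ptr<Device>>, and at call sites, with no change in behavior. A minimal caller-side sketch of the new spelling follows; GetCpuClient(), local_devices(), and DebugString() are taken from the diff below, while the wrapper function and the include path are illustrative assumptions, not part of this commit:

    // Hypothetical usage sketch; header path assumed.
    #include "tensorflow/compiler/xla/pjrt/cpu_device.h"

    namespace xla {

    void ExampleAfterRename() {
      std::shared_ptr<PjRtClient> client =
          GetCpuClient(/*asynchronous=*/true).ValueOrDie();
      // Before this commit the next line was spelled "Device* device = ...".
      PjRtDevice* device = client->local_devices().at(0);
      VLOG(1) << "default device: " << device->DebugString();
    }

    }  // namespace xla

The full diff follows.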
@ -25,8 +25,8 @@ static const char kCpuPlatformName[] = "cpu";
|
||||
|
||||
CpuDevice::CpuDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state)
|
||||
: Device(id, std::move(local_device_state), kCpuPlatformName,
|
||||
/*device_kind=*/kCpuPlatformName) {}
|
||||
: PjRtDevice(id, std::move(local_device_state), kCpuPlatformName,
|
||||
/*device_kind=*/kCpuPlatformName) {}
|
||||
|
||||
StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
@ -39,7 +39,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
for (int i = 0; i < client->device_count(); ++i) {
|
||||
se::StreamExecutorConfig config;
|
||||
config.ordinal = i;
|
||||
|
||||
@ -23,7 +23,7 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class CpuDevice : public Device {
|
||||
class CpuDevice : public PjRtDevice {
|
||||
public:
|
||||
CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
|
||||
};
|
||||
|
||||
@ -32,7 +32,7 @@ TEST(GpuMultiStream, Basics) {
|
||||
GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(),
|
||||
/*distributed_client=*/nullptr, /*node_id=*/0));
|
||||
|
||||
Device* device = client->local_devices().at(0);
|
||||
PjRtDevice* device = client->local_devices().at(0);
|
||||
|
||||
int n = 1024;
|
||||
Shape shape = ShapeUtil::MakeShape(S32, {n});
|
||||
|
||||
@ -25,8 +25,8 @@ static const char kInterpreterPlatformName[] = "interpreter";
|
||||
|
||||
InterpreterDevice::InterpreterDevice(
|
||||
int id, std::unique_ptr<LocalDeviceState> local_device_state)
|
||||
: Device(id, std::move(local_device_state), kInterpreterPlatformName,
|
||||
/*device_kind=*/kInterpreterPlatformName) {}
|
||||
: PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName,
|
||||
/*device_kind=*/kInterpreterPlatformName) {}
|
||||
|
||||
StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
@ -40,7 +40,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
se::StreamExecutor* executor =
|
||||
client->backend().stream_executor(0).ValueOrDie();
|
||||
auto device_state = absl::make_unique<LocalDeviceState>(
|
||||
|
||||
@ -23,7 +23,7 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class InterpreterDevice : public Device {
|
||||
class InterpreterDevice : public PjRtDevice {
|
||||
public:
|
||||
InterpreterDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state);
|
||||
|
||||
@ -207,9 +207,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) {
|
||||
return cache_.emplace(key_string, result.ValueOrDie()).first->second;
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<Device>> BuildLocalDevices(
|
||||
std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
|
||||
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
for (auto& local_device : local_device_states) {
|
||||
int device_ordinal = local_device->device_ordinal();
|
||||
const se::DeviceDescription& description =
|
||||
@ -225,7 +225,7 @@ std::vector<std::unique_ptr<Device>> BuildLocalDevices(
|
||||
Status BuildDistributedDevices(
|
||||
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
|
||||
std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
|
||||
std::vector<std::unique_ptr<Device>>* devices,
|
||||
std::vector<std::unique_ptr<PjRtDevice>>* devices,
|
||||
GpuExecutableRunOptions* gpu_executable_run_options) {
|
||||
LocalTopologyProto local_topology;
|
||||
local_topology.set_node_id(node_id);
|
||||
@ -286,8 +286,8 @@ Status BuildDistributedDevices(
|
||||
GpuDevice::GpuDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
std::string device_kind, int node_id)
|
||||
: Device(id, std::move(local_device_state), kGpuPlatformName,
|
||||
std::move(device_kind), node_id) {}
|
||||
: PjRtDevice(id, std::move(local_device_state), kGpuPlatformName,
|
||||
std::move(device_kind), node_id) {}
|
||||
|
||||
StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
|
||||
bool asynchronous, const GpuAllocatorConfig& allocator_config,
|
||||
@ -302,7 +302,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
|
||||
auto host_memory_allocator =
|
||||
GetGpuHostAllocator(local_device_states.front()->executor());
|
||||
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices;
|
||||
auto gpu_run_options = absl::make_unique<GpuExecutableRunOptions>();
|
||||
if (distributed_client) {
|
||||
TF_RETURN_IF_ERROR(BuildDistributedDevices(
|
||||
|
||||
@ -25,7 +25,7 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class GpuDevice : public Device {
|
||||
class GpuDevice : public PjRtDevice {
|
||||
public:
|
||||
GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
std::string device_kind, int node_id);
|
||||
|
||||
@ -112,19 +112,19 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<LocalDeviceState*> Device::GetLocalDeviceState() const {
|
||||
StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
|
||||
if (local_device_state_) {
|
||||
return local_device_state_.get();
|
||||
}
|
||||
return InvalidArgument("Device %s is not a local device.", DebugString());
|
||||
}
|
||||
|
||||
std::string Device::DebugString() const {
|
||||
std::string PjRtDevice::DebugString() const {
|
||||
return absl::StrCat(platform_name(), ":", id());
|
||||
}
|
||||
|
||||
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
|
||||
absl::Span<const std::vector<Device*>> devices) {
|
||||
absl::Span<const std::vector<PjRtDevice*>> devices) {
|
||||
if (devices.empty()) {
|
||||
return InvalidArgument(
|
||||
"Device assignment passed to Compile() must be non-empty.");
|
||||
@ -175,7 +175,7 @@ class CpuAllocator : public tensorflow::Allocator {
|
||||
|
||||
PjRtClient::PjRtClient(
|
||||
std::string platform_name, LocalClient* client,
|
||||
std::vector<std::unique_ptr<Device>> devices, int host_id,
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
@ -201,7 +201,7 @@ PjRtClient::PjRtClient(
|
||||
host_memory_allocator_ = std::make_unique<CpuAllocator>();
|
||||
}
|
||||
|
||||
for (const std::unique_ptr<Device>& device : devices_) {
|
||||
for (const std::unique_ptr<PjRtDevice>& device : devices_) {
|
||||
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
|
||||
<< "Duplicate device id: " << device->id();
|
||||
|
||||
@ -376,8 +376,9 @@ void RecordUsage(PjRtBuffer::ScopedHold device_buffer,
|
||||
// It is safe to delete the returned PjRtBuffer without further
|
||||
// synchronization if an error occurs before the buffer is used.
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
|
||||
const Shape& on_host_shape, Device* device, LocalDeviceState* local_device,
|
||||
se::Stream* copy_stream, bool is_uninitialized_create, PjRtClient* client) {
|
||||
const Shape& on_host_shape, PjRtDevice* device,
|
||||
LocalDeviceState* local_device, se::Stream* copy_stream,
|
||||
bool is_uninitialized_create, PjRtClient* client) {
|
||||
if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
|
||||
return InvalidArgument("Can't make a buffer from an empty tuple");
|
||||
}
|
||||
@ -574,7 +575,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtClient* client,
|
||||
Device* device) {
|
||||
PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
|
||||
VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString()
|
||||
<< " device: " << device->DebugString();
|
||||
@ -736,7 +737,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
|
||||
const Shape& shape, PjRtClient* client, Device* device) {
|
||||
const Shape& shape, PjRtClient* client, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized");
|
||||
VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString()
|
||||
<< " device: " << device->DebugString();
|
||||
@ -755,7 +756,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
const LiteralSlice& literal, PjRtClient* client, Device* device) {
|
||||
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral");
|
||||
VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: "
|
||||
<< literal.shape().ToString() << " device: " << device->DebugString();
|
||||
@ -815,7 +816,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
}
|
||||
|
||||
/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers(
|
||||
absl::Span<const Shape> shapes, PjRtClient* client, Device* device,
|
||||
absl::Span<const Shape> shapes, PjRtClient* client, PjRtDevice* device,
|
||||
PjRtCrossHostRecvNotifier&& notifier) {
|
||||
if (shapes.empty()) {
|
||||
notifier(InvalidArgument(
|
||||
@ -849,7 +850,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
|
||||
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
|
||||
PjRtClient* client, Device* device)
|
||||
PjRtClient* client, PjRtDevice* device)
|
||||
: client_(client),
|
||||
on_host_shape_(std::move(on_host_shape)),
|
||||
on_device_shape_(std::move(on_device_shape)),
|
||||
@ -1189,7 +1190,7 @@ PjRtBuffer::ScopedHold PjRtBuffer::GetBufferWithHold(ScopedHold::Type type) {
|
||||
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
|
||||
std::shared_ptr<BufferSequencingEvent>>>
|
||||
PjRtBuffer::CopyToDeviceHelper(
|
||||
Device* dst_device, LocalDeviceState* dst_local_device,
|
||||
PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
|
||||
LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
|
||||
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -1249,7 +1250,7 @@ PjRtBuffer::CopyToDeviceHelper(
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
|
||||
Device* dst_device) {
|
||||
PjRtDevice* dst_device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::CopyToDevice");
|
||||
if (dst_device == device_) {
|
||||
return InvalidArgument(
|
||||
@ -1420,7 +1421,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
|
||||
std::unique_ptr<PjRtBuffer> OutputBufferHelper(
|
||||
ScopedShapedBuffer* result_buffer,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client,
|
||||
Device* device, LocalDeviceState* local_device) {
|
||||
PjRtDevice* device, LocalDeviceState* local_device) {
|
||||
std::shared_ptr<TrackedDeviceBuffer> out_buffer =
|
||||
TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
|
||||
{definition_event});
|
||||
@ -1433,7 +1434,7 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
|
||||
return pjrt_buffer;
|
||||
}
|
||||
|
||||
static Device* LookupDevice(const PjRtClient& client, int device_id) {
|
||||
static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
|
||||
auto it = client.id_to_device().find(device_id);
|
||||
CHECK(it != client.id_to_device().end())
|
||||
<< "Unknown device id: " << device_id;
|
||||
@ -1447,7 +1448,7 @@ PjRtExecutable::PjRtExecutable(
|
||||
bool parameter_is_tupled_arguments,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||
std::vector<std::pair<int, int>> local_logical_device_ids,
|
||||
std::vector<Device*> local_devices, PjRtClient* client)
|
||||
std::vector<PjRtDevice*> local_devices, PjRtClient* client)
|
||||
: client_(client),
|
||||
device_assignment_(std::move(device_assignment)),
|
||||
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
|
||||
@ -1559,7 +1560,7 @@ PjRtExecutable::MakeExecutionInputsAndWaitForEvents(
|
||||
StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
|
||||
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
||||
Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
|
||||
PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment) const {
|
||||
int device_ordinal = device->local_device_state()->device_ordinal();
|
||||
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
|
||||
@ -1695,7 +1696,7 @@ std::vector<std::unique_ptr<PjRtBuffer>> PjRtExecutable::MakeOutputBuffers(
|
||||
int device_ordinal, const ExecuteOptions& options,
|
||||
ScopedShapedBuffer result_buffer,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event,
|
||||
Device* device) const {
|
||||
PjRtDevice* device) const {
|
||||
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
|
||||
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
|
||||
if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
|
||||
@ -1729,7 +1730,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||
PjRtExecutable::ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
|
||||
int replica, int partition, const RunId& run_id,
|
||||
const ExecuteOptions& options,
|
||||
Device* device) const {
|
||||
PjRtDevice* device) const {
|
||||
std::shared_ptr<DeviceAssignment> device_assignment;
|
||||
if (device == nullptr) {
|
||||
CHECK(device_assignment_ != nullptr);
|
||||
@ -1828,7 +1829,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> PjRtExecutable::Execute(
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||
PjRtExecutable::ExecuteOnLocalDevice(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, Device* device,
|
||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||
const ExecuteOptions& options) const {
|
||||
if (device_assignment_ == nullptr) {
|
||||
VLOG(1) << "Executing portable single-core program on "
|
||||
@ -1894,7 +1895,7 @@ PjRtExecutable::ExecuteOnLocalDevices(
|
||||
for (int i = 0; i < num_local_devices; ++i) {
|
||||
const int replica = local_logical_device_ids_[i].first;
|
||||
const int partition = local_logical_device_ids_[i].second;
|
||||
Device* device = local_devices_[i];
|
||||
PjRtDevice* device = local_devices_[i];
|
||||
const LocalDeviceState& device_state = *device->local_device_state();
|
||||
device_state.execute_thread()->Schedule([&, replica, partition, i] {
|
||||
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
|
||||
@ -2141,12 +2142,12 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
build_options.set_result_layout(result_layout);
|
||||
|
||||
std::vector<std::pair<int, int>> local_logical_device_ids;
|
||||
std::vector<Device*> local_devices;
|
||||
std::vector<PjRtDevice*> local_devices;
|
||||
if (device_assignment != nullptr) {
|
||||
for (int replica = 0; replica < num_replicas; ++replica) {
|
||||
for (int partition = 0; partition < num_partitions; ++partition) {
|
||||
int device_id = (*device_assignment)(replica, partition);
|
||||
Device* device = LookupDevice(*client, device_id);
|
||||
PjRtDevice* device = LookupDevice(*client, device_id);
|
||||
if (device->host_id() != client->host_id()) {
|
||||
VLOG(3) << "Non-local device: " << device_id;
|
||||
continue;
|
||||
|
||||
@ -52,17 +52,18 @@ namespace xla {
|
||||
|
||||
class PjRtClient;
|
||||
|
||||
class Device {
|
||||
class PjRtDevice {
|
||||
public:
|
||||
explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
std::string platform_name, std::string device_kind,
|
||||
int host_id = 0)
|
||||
explicit PjRtDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
std::string platform_name, std::string device_kind,
|
||||
int host_id = 0)
|
||||
: id_(id),
|
||||
local_device_state_(std::move(local_device_state)),
|
||||
host_id_(host_id),
|
||||
platform_name_(std::move(platform_name)),
|
||||
device_kind_(std::move(device_kind)) {}
|
||||
virtual ~Device() {}
|
||||
virtual ~PjRtDevice() {}
|
||||
|
||||
// The ID of this device. IDs are unique among devices of this type
|
||||
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
|
||||
@ -130,7 +131,7 @@ class PjRtClient {
|
||||
// `allocator` may null, in which case the platform default allocator is used.
|
||||
explicit PjRtClient(
|
||||
std::string platform_name, LocalClient* client,
|
||||
std::vector<std::unique_ptr<Device>> devices, int host_id,
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
@ -142,11 +143,15 @@ class PjRtClient {
|
||||
|
||||
int device_count() const { return devices_.size(); }
|
||||
int local_device_count() const { return local_devices_.size(); }
|
||||
const std::vector<std::unique_ptr<Device>>& devices() const {
|
||||
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
|
||||
return devices_;
|
||||
}
|
||||
const std::vector<Device*>& local_devices() const { return local_devices_; }
|
||||
const std::map<int, Device*>& id_to_device() const { return id_to_device_; }
|
||||
const std::vector<PjRtDevice*>& local_devices() const {
|
||||
return local_devices_;
|
||||
}
|
||||
const std::map<int, PjRtDevice*>& id_to_device() const {
|
||||
return id_to_device_;
|
||||
}
|
||||
int host_id() const { return host_id_; }
|
||||
const std::string& platform_name() const { return platform_name_; }
|
||||
|
||||
@ -210,11 +215,11 @@ class PjRtClient {
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
|
||||
|
||||
// Includes all devices, including non-local devices on multi-host platforms.
|
||||
std::vector<std::unique_ptr<Device>> devices_;
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices_;
|
||||
// Maps Device::id() to the corresponding Device. Includes all devices.
|
||||
std::map<int, Device*> id_to_device_;
|
||||
std::map<int, PjRtDevice*> id_to_device_;
|
||||
// Local devices indexed by local device ordinal.
|
||||
std::vector<Device*> local_devices_;
|
||||
std::vector<PjRtDevice*> local_devices_;
|
||||
int host_id_;
|
||||
|
||||
se::DeviceMemoryAllocator* allocator_;
|
||||
@ -233,7 +238,7 @@ class PjRtClient {
|
||||
// Converts a 2D set of Device objects indexed by [replica][partition] into an
|
||||
// xla::DeviceAssignment.
|
||||
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
|
||||
absl::Span<const std::vector<Device*>> devices);
|
||||
absl::Span<const std::vector<PjRtDevice*>> devices);
|
||||
|
||||
// Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
|
||||
// can be either valid or invalid. An invalid buffer is one that has never been
|
||||
@ -417,7 +422,7 @@ class PjRtBuffer {
|
||||
|
||||
// Returns a buffer with uninitialized contents.
|
||||
static StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitialized(
|
||||
const Shape& shape, PjRtClient* client, Device* device);
|
||||
const Shape& shape, PjRtClient* client, PjRtDevice* device);
|
||||
|
||||
// Describes the semantics the caller to FromHostBuffer expects from the
|
||||
// runtime, in a total order from most restrictive to least restrictive.
|
||||
@ -449,13 +454,13 @@ class PjRtBuffer {
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtClient* client,
|
||||
Device* device);
|
||||
PjRtDevice* device);
|
||||
|
||||
// Note that literal must remain in scope until the transfer has completed, so
|
||||
// the caller should, for example, wait for BlockHostUntilReady() completes on
|
||||
// the return value before letting literal go out of scope.
|
||||
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostLiteral(
|
||||
const LiteralSlice& literal, PjRtClient* client, Device* device);
|
||||
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device);
|
||||
|
||||
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
|
||||
// cross host transfers using `client` on `device'. `shapes` must be the exact
|
||||
@ -467,12 +472,13 @@ class PjRtBuffer {
|
||||
// sending host and used in a call to CopyToRemoteDevice. None of the recv
|
||||
// buffers will become ready until *all* of the sends have completed.
|
||||
static void MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
|
||||
PjRtClient* client, Device* device,
|
||||
PjRtClient* client,
|
||||
PjRtDevice* device,
|
||||
PjRtCrossHostRecvNotifier&& notifier);
|
||||
|
||||
PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
|
||||
PjRtClient* client, Device* device);
|
||||
PjRtClient* client, PjRtDevice* device);
|
||||
~PjRtBuffer();
|
||||
|
||||
PjRtBuffer(const PjRtBuffer&) = delete;
|
||||
@ -482,7 +488,7 @@ class PjRtBuffer {
|
||||
|
||||
const Shape& on_host_shape() const { return on_host_shape_; }
|
||||
const Shape& on_device_shape() const { return on_device_shape_; }
|
||||
Device* device() const { return device_; }
|
||||
PjRtDevice* device() const { return device_; }
|
||||
const std::string& platform_name() const { return client_->platform_name(); }
|
||||
PjRtClient* client() const { return client_; }
|
||||
bool IsEmptyTuple() const {
|
||||
@ -556,7 +562,7 @@ class PjRtBuffer {
|
||||
|
||||
// Copies the buffer to device `dst_device`. Returns an error if the buffer is
|
||||
// already on dst_device.
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(Device* dst_device);
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(PjRtDevice* dst_device);
|
||||
|
||||
// Copies the buffer to the remote device encoded in serialized_descriptor.
|
||||
// This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
|
||||
@ -629,7 +635,7 @@ class PjRtBuffer {
|
||||
|
||||
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
|
||||
std::shared_ptr<BufferSequencingEvent>>>
|
||||
CopyToDeviceHelper(Device* dst_device, LocalDeviceState* dst_local_device,
|
||||
CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
|
||||
LocalDeviceState* transfer_local_device,
|
||||
se::Stream* transfer_stream,
|
||||
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
|
||||
@ -637,7 +643,7 @@ class PjRtBuffer {
|
||||
PjRtClient* const client_;
|
||||
const Shape on_host_shape_;
|
||||
const Shape on_device_shape_;
|
||||
Device* const device_;
|
||||
PjRtDevice* const device_;
|
||||
|
||||
mutable absl::Mutex mu_;
|
||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
|
||||
@ -707,7 +713,7 @@ class PjRtExecutable {
|
||||
bool parameter_is_tupled_arguments,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||
std::vector<std::pair<int, int>> local_logical_device_ids,
|
||||
std::vector<Device*> local_devices, PjRtClient* client);
|
||||
std::vector<PjRtDevice*> local_devices, PjRtClient* client);
|
||||
|
||||
virtual ~PjRtExecutable() = default;
|
||||
|
||||
@ -741,14 +747,16 @@ class PjRtExecutable {
|
||||
return local_logical_device_ids_;
|
||||
}
|
||||
|
||||
const std::vector<Device*>& local_devices() const { return local_devices_; }
|
||||
const std::vector<PjRtDevice*>& local_devices() const {
|
||||
return local_devices_;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> Execute(
|
||||
absl::Span<PjRtBuffer* const> argument_handles,
|
||||
const ExecuteOptions& options) const;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteOnLocalDevice(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, Device* device,
|
||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||
const ExecuteOptions& options) const;
|
||||
|
||||
// Execute on local devices. Takes a sequence of argument lists (one argument
|
||||
@ -786,7 +794,7 @@ class PjRtExecutable {
|
||||
StatusOr<ScopedShapedBuffer> EnqueueExecution(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||
int partition, int executable_idx, const RunId& run_id,
|
||||
const ExecuteOptions& options, Device* device,
|
||||
const ExecuteOptions& options, PjRtDevice* device,
|
||||
std::vector<PjRtBuffer::ScopedHold>* device_buffers,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment) const;
|
||||
|
||||
@ -794,12 +802,12 @@ class PjRtExecutable {
|
||||
int device_ordinal, const ExecuteOptions& options,
|
||||
ScopedShapedBuffer result_buffer,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event,
|
||||
Device* device) const;
|
||||
PjRtDevice* device) const;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||
int partition, const RunId& run_id, const ExecuteOptions& options,
|
||||
Device* device = nullptr) const;
|
||||
PjRtDevice* device = nullptr) const;
|
||||
|
||||
// Create shared pointers so we can free them after the execution: with
|
||||
// asynchronous execution, the process being executed can outlive the
|
||||
@ -828,7 +836,7 @@ class PjRtExecutable {
|
||||
// assigned.
|
||||
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
|
||||
// (see xla.cc).
|
||||
std::vector<Device*> local_devices_;
|
||||
std::vector<PjRtDevice*> local_devices_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
@ -193,7 +193,7 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
|
||||
return minor_to_major;
|
||||
}
|
||||
|
||||
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
|
||||
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
|
||||
const se::Platform* platform =
|
||||
device.local_device_state()->executor()->platform();
|
||||
if (platform->id() == se::host::kHostPlatformId) {
|
||||
@ -205,15 +205,15 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const Device& device) {
|
||||
device.DebugString());
|
||||
}
|
||||
|
||||
StatusOr<DLContext> DLContextForDevice(const Device& device) {
|
||||
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
|
||||
DLContext context;
|
||||
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
|
||||
context.device_id = device.local_device_state()->device_ordinal();
|
||||
return context;
|
||||
}
|
||||
|
||||
StatusOr<Device*> DeviceForDLContext(const PjRtClient& client,
|
||||
const DLContext& context) {
|
||||
StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
|
||||
const DLContext& context) {
|
||||
se::Platform::Id platform_id;
|
||||
switch (context.device_type) {
|
||||
case kDLCPU:
|
||||
@ -226,7 +226,7 @@ StatusOr<Device*> DeviceForDLContext(const PjRtClient& client,
|
||||
return InvalidArgument("Unknown/unsupported DLPack device type %d",
|
||||
context.device_type);
|
||||
}
|
||||
auto it = absl::c_find_if(client.local_devices(), [&](Device* device) {
|
||||
auto it = absl::c_find_if(client.local_devices(), [&](PjRtDevice* device) {
|
||||
return device->local_device_state()->executor()->platform()->id() ==
|
||||
platform_id &&
|
||||
device->local_device_state()->device_ordinal() == context.device_id;
|
||||
@ -313,7 +313,7 @@ StatusOr<std::unique_ptr<PyBuffer>> DLPackManagedTensorToBuffer(
|
||||
dlmt->dl_tensor.ndim);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Device * device,
|
||||
PjRtDevice * device,
|
||||
DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx));
|
||||
absl::Span<int64 const> dimensions(
|
||||
reinterpret_cast<int64*>(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim);
|
||||
|
||||
@ -217,7 +217,7 @@ std::string CallSignature::DebugString() const {
|
||||
|
||||
struct CacheEntry {
|
||||
std::shared_ptr<xla::PyExecutable> executable;
|
||||
xla::Device* device;
|
||||
xla::PjRtDevice* device;
|
||||
PyTreeDef out_pytree_def;
|
||||
// These are the objects required to create a `DeviceArray` object.
|
||||
// We use Python types within the vector because this is what we will be
|
||||
@ -235,7 +235,7 @@ class CompiledFunction {
|
||||
CompiledFunction(py::function cache_miss_fun, py::function python_f_jitted,
|
||||
bool jax_enable_x64, std::vector<int> static_argnums,
|
||||
std::shared_ptr<xla::PyClient> pyclient,
|
||||
xla::Device* device);
|
||||
xla::PjRtDevice* device);
|
||||
~CompiledFunction();
|
||||
|
||||
// This function will:
|
||||
@ -268,7 +268,7 @@ class CompiledFunction {
|
||||
absl::flat_hash_map<CallSignature, std::unique_ptr<CacheEntry>> executables_;
|
||||
|
||||
const std::shared_ptr<xla::PyClient> pyclient_;
|
||||
xla::Device* const default_device_;
|
||||
xla::PjRtDevice* const default_device_;
|
||||
};
|
||||
|
||||
CompiledFunction::CompiledFunction(py::function cache_miss_fun,
|
||||
@ -276,7 +276,7 @@ CompiledFunction::CompiledFunction(py::function cache_miss_fun,
|
||||
bool jax_enable_x64,
|
||||
std::vector<int> static_argnums,
|
||||
std::shared_ptr<xla::PyClient> pyclient,
|
||||
xla::Device* device)
|
||||
xla::PjRtDevice* device)
|
||||
: cache_miss_fun_(std::move(cache_miss_fun)),
|
||||
python_f_jitted_(std::move(python_f_jitted)),
|
||||
jax_enable_x64_(jax_enable_x64),
|
||||
@ -374,9 +374,9 @@ void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs,
|
||||
}
|
||||
|
||||
template <typename CppType, typename Pybind11Type>
|
||||
std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(const py::handle& scalar,
|
||||
xla::PjRtClient* client,
|
||||
xla::Device* device) {
|
||||
std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
|
||||
const py::handle& scalar, xla::PjRtClient* client,
|
||||
xla::PjRtDevice* device) {
|
||||
CppType data = py::cast<Pybind11Type>(scalar);
|
||||
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
|
||||
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
|
||||
@ -389,7 +389,7 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(const py::handle& scalar,
|
||||
// not convertible (thus, this must be called after other checks).
|
||||
StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
|
||||
py::handle scalar, bool jax_enable_x64, xla::PjRtClient* client,
|
||||
xla::Device* device) {
|
||||
xla::PjRtDevice* device) {
|
||||
// Important: In Python, isinstance(True, int) returns True. Thus, we have
|
||||
// to check for bool before int.
|
||||
if (py::isinstance<py::bool_>(scalar)) {
|
||||
@ -467,7 +467,7 @@ const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) {
|
||||
//
|
||||
// Returns `OkStatus()` on success.
|
||||
Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
xla::Device* default_device,
|
||||
xla::PjRtDevice* default_device,
|
||||
ParsedArgumentsAsBuffers& arguments) {
|
||||
std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
|
||||
auto& keep_alive = arguments.keep_alive;
|
||||
@ -490,12 +490,12 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
// https://github.com/google/jax/pull/1916 for the rationale why the
|
||||
// computation follows the data locality.
|
||||
// It's also similar to PyTorch's behavior.
|
||||
xla::Device* data_device = nullptr;
|
||||
xla::PjRtDevice* data_device = nullptr;
|
||||
for (py::handle arg : arguments.flat_dynamic_args) {
|
||||
if (py::isinstance(arg, device_array)) {
|
||||
xla::PyBuffer* buffer =
|
||||
py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
|
||||
xla::Device* device = buffer->buffer()->device();
|
||||
xla::PjRtDevice* device = buffer->buffer()->device();
|
||||
if (data_device && (device != data_device)) {
|
||||
return InvalidArgument(
|
||||
"%s",
|
||||
@ -682,7 +682,7 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
|
||||
[](py::function cache_miss_fun,
|
||||
py::function fallback_on_unsupported_argument,
|
||||
bool jax_enable_x64, std::vector<int> static_argnums,
|
||||
xla::ClientAndPtr<xla::Device> client_and_device)
|
||||
xla::ClientAndPtr<xla::PjRtDevice> client_and_device)
|
||||
-> std::unique_ptr<CompiledFunction> {
|
||||
return std::make_unique<CompiledFunction>(
|
||||
std::move(cache_miss_fun),
|
||||
|
||||
@ -101,14 +101,14 @@ uint32_t constexpr kOutfeedCidShutdown = 0;
|
||||
// Encapsulates data received from a device outfeed.
|
||||
class OutfeedData {
|
||||
public:
|
||||
OutfeedData(Device* device, uint32_t consumer_id, Shape shape)
|
||||
OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape)
|
||||
: device_(device),
|
||||
consumer_id_(consumer_id),
|
||||
shape_(shape),
|
||||
literal_(nullptr),
|
||||
literal_size_bytes_(0) {}
|
||||
|
||||
Device* device() { return device_; }
|
||||
PjRtDevice* device() { return device_; }
|
||||
uint32_t consumer_id() const { return consumer_id_; }
|
||||
Shape shape() const { return shape_; }
|
||||
std::unique_ptr<Literal> literal() {
|
||||
@ -123,7 +123,7 @@ class OutfeedData {
|
||||
std::string DebugString() const;
|
||||
|
||||
private:
|
||||
Device* device_;
|
||||
PjRtDevice* device_;
|
||||
uint32_t consumer_id_;
|
||||
Shape shape_;
|
||||
std::unique_ptr<Literal> literal_;
|
||||
@ -187,8 +187,8 @@ class OutfeedReceiverImpl {
|
||||
Status SendShutdownOutfeedHeader(int device_idx);
|
||||
|
||||
// Receives a raw Literal from a device outfeed.
|
||||
StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(const Device* device,
|
||||
const Shape& shape);
|
||||
StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(
|
||||
const PjRtDevice* device, const Shape& shape);
|
||||
|
||||
// Enqueues received data in the callbaback queue.
|
||||
void EnqueueReceivedData(std::unique_ptr<OutfeedData> received)
|
||||
@ -200,7 +200,7 @@ class OutfeedReceiverImpl {
|
||||
|
||||
OutfeedReceiver::Callback callback_;
|
||||
// The devices on which we are listening.
|
||||
std::vector<Device*> devices_;
|
||||
std::vector<PjRtDevice*> devices_;
|
||||
// Maximum bytes capacity of the callback queue.
|
||||
uint64_t max_callback_queue_size_bytes_;
|
||||
|
||||
@ -283,7 +283,7 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
++num_listening_threads_;
|
||||
}
|
||||
Device* device = devices_[device_idx];
|
||||
PjRtDevice* device = devices_[device_idx];
|
||||
while (true) {
|
||||
Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
|
||||
std::unique_ptr<Literal> header =
|
||||
@ -339,7 +339,7 @@ void OutfeedReceiverImpl::EnqueueReceivedData(
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
|
||||
const Device* device, const Shape& shape) {
|
||||
const PjRtDevice* device, const Shape& shape) {
|
||||
std::shared_ptr<Literal> literal_shared;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
@ -390,7 +390,7 @@ void OutfeedReceiverImpl::CallbackThreadLoop() {
|
||||
}
|
||||
|
||||
Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
|
||||
const Device* device = devices_[device_idx];
|
||||
const PjRtDevice* device = devices_[device_idx];
|
||||
constexpr int consumer_id = kOutfeedCidShutdown;
|
||||
VLOG(2) << "[" << device->DebugString()
|
||||
<< "] SendSpecialHeader cons=" << consumer_id;
|
||||
|
||||
@ -33,7 +33,7 @@ class OutfeedReceiver {
|
||||
public:
|
||||
// A callback takes: device, consumer id, received.
|
||||
using Callback =
|
||||
std::function<void(Device*, uint32_t, std::shared_ptr<Literal>)>;
|
||||
std::function<void(PjRtDevice*, uint32_t, std::shared_ptr<Literal>)>;
|
||||
|
||||
// Constructs the receiver for the given clients and callback function.
|
||||
//
|
||||
|
||||
@ -40,7 +40,7 @@ class OutfeedReceiverForPython {
|
||||
public:
|
||||
// A callback to Python takes: consumer id, received literal.
|
||||
using CallbackToPython =
|
||||
std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>;
|
||||
std::function<void(ClientAndPtr<PjRtDevice>, uint32_t, pybind11::object)>;
|
||||
|
||||
OutfeedReceiverForPython(CallbackToPython callback_python,
|
||||
std::vector<std::shared_ptr<PyClient>> clients,
|
||||
@ -48,7 +48,7 @@ class OutfeedReceiverForPython {
|
||||
: callback_python_(std::move(callback_python)),
|
||||
clients_(std::move(clients)) {
|
||||
OutfeedReceiver::Callback callback =
|
||||
[this](Device* device, uint32_t consumer_id,
|
||||
[this](PjRtDevice* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> literal) {
|
||||
this->Callback(device, consumer_id, std::move(literal));
|
||||
};
|
||||
@ -86,7 +86,7 @@ class OutfeedReceiverForPython {
|
||||
arrays);
|
||||
}
|
||||
|
||||
void Callback(Device* device, uint32_t consumer_id,
|
||||
void Callback(PjRtDevice* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> literal) {
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
@ -106,7 +106,7 @@ class OutfeedReceiverForPython {
|
||||
LiteralToPython(std::move(literal)).ValueOrDie();
|
||||
// The callback_ should handle all exceptions in user-code. If we get
|
||||
// an exception here, it is a bug in the callback and we should stop.
|
||||
callback_python_(WrapWithClient<Device>(*it, device), consumer_id,
|
||||
callback_python_(WrapWithClient<PjRtDevice>(*it, device), consumer_id,
|
||||
std::move(literal_python));
|
||||
}
|
||||
|
||||
|
||||
@ -78,11 +78,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](PjRtDevice* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
@ -111,11 +111,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) {
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](PjRtDevice* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
@ -156,11 +156,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) {
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](PjRtDevice* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
@ -199,11 +199,11 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) {
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](PjRtDevice* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
@ -233,11 +233,11 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](PjRtDevice* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
|
||||
@ -51,12 +51,12 @@ PyBuffer::~PyBuffer() {
|
||||
}
|
||||
}
|
||||
|
||||
ClientAndPtr<Device> PyBuffer::device() const {
|
||||
ClientAndPtr<PjRtDevice> PyBuffer::device() const {
|
||||
return WrapWithClient(client_, buffer_->device());
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyBuffer>> PyBuffer::CopyToDevice(
|
||||
const ClientAndPtr<Device>& dst_device) const {
|
||||
const ClientAndPtr<PjRtDevice>& dst_device) const {
|
||||
CHECK(dst_device.get() != nullptr);
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
std::unique_ptr<PjRtBuffer> out;
|
||||
|
||||
@ -38,12 +38,12 @@ class PyBuffer {
|
||||
std::shared_ptr<PyClient> client() const { return client_; }
|
||||
PjRtBuffer* buffer() const { return buffer_.get(); }
|
||||
|
||||
ClientAndPtr<Device> device() const;
|
||||
ClientAndPtr<PjRtDevice> device() const;
|
||||
const std::string& platform_name() const { return buffer_->platform_name(); }
|
||||
bool is_deleted() const { return buffer_->IsDeleted(); }
|
||||
|
||||
StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice(
|
||||
const ClientAndPtr<Device>& dst_device) const;
|
||||
const ClientAndPtr<PjRtDevice>& dst_device) const;
|
||||
|
||||
void Delete() { return buffer_->Delete(); }
|
||||
|
||||
|
||||
@ -33,8 +33,8 @@ namespace pprof = tensorflow::tfprof::pprof;
|
||||
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
|
||||
: pjrt_client_(std::move(pjrt_client)) {}
|
||||
|
||||
std::vector<ClientAndPtr<Device>> PyClient::Devices() {
|
||||
std::vector<ClientAndPtr<Device>> devices;
|
||||
std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
|
||||
std::vector<ClientAndPtr<PjRtDevice>> devices;
|
||||
devices.reserve(pjrt_client_->devices().size());
|
||||
for (const auto& device : pjrt_client_->devices()) {
|
||||
devices.push_back(WrapWithClient(shared_from_this(), device.get()));
|
||||
@ -42,21 +42,21 @@ std::vector<ClientAndPtr<Device>> PyClient::Devices() {
|
||||
return devices;
|
||||
}
|
||||
|
||||
std::vector<ClientAndPtr<Device>> PyClient::LocalDevices() {
|
||||
std::vector<ClientAndPtr<Device>> devices;
|
||||
std::vector<ClientAndPtr<PjRtDevice>> PyClient::LocalDevices() {
|
||||
std::vector<ClientAndPtr<PjRtDevice>> devices;
|
||||
devices.reserve(pjrt_client_->local_devices().size());
|
||||
for (Device* device : pjrt_client_->local_devices()) {
|
||||
for (PjRtDevice* device : pjrt_client_->local_devices()) {
|
||||
devices.push_back(WrapWithClient(shared_from_this(), device));
|
||||
}
|
||||
return devices;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
|
||||
StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
|
||||
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
DeviceAssignment device_assignment,
|
||||
pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
|
||||
std::vector<std::vector<ClientAndPtr<Device>>> result;
|
||||
std::vector<std::vector<ClientAndPtr<PjRtDevice>>> result;
|
||||
result.resize(num_replicas);
|
||||
for (int r = 0; r < num_replicas; ++r) {
|
||||
result[r].resize(num_partitions);
|
||||
@ -70,12 +70,12 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<ClientAndPtr<Device>>>
|
||||
StatusOr<std::vector<ClientAndPtr<PjRtDevice>>>
|
||||
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
|
||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||
pjrt_client_->GetDefaultDeviceAssignment(
|
||||
num_replicas, /*num_partitions=*/1));
|
||||
std::vector<ClientAndPtr<Device>> result;
|
||||
std::vector<ClientAndPtr<PjRtDevice>> result;
|
||||
for (int i = 0; i < num_replicas; ++i) {
|
||||
int device_id = device_assignment(i, 0);
|
||||
auto iter = pjrt_client_->id_to_device().find(device_id);
|
||||
@ -86,7 +86,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
|
||||
const pybind11::object& argument, Device* device, bool force_copy,
|
||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
||||
PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
|
||||
if (device == nullptr) {
|
||||
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
|
||||
@ -206,7 +206,7 @@ namespace {
|
||||
struct HeapProfileKey {
|
||||
Traceback* traceback;
|
||||
int64 size;
|
||||
Device* device;
|
||||
PjRtDevice* device;
|
||||
bool operator==(const HeapProfileKey& other) const;
|
||||
};
|
||||
|
||||
|
||||
@ -100,14 +100,14 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
int device_count() const { return pjrt_client_->device_count(); }
|
||||
int host_id() const { return pjrt_client_->host_id(); }
|
||||
|
||||
std::vector<ClientAndPtr<Device>> Devices();
|
||||
std::vector<ClientAndPtr<Device>> LocalDevices();
|
||||
std::vector<ClientAndPtr<PjRtDevice>> Devices();
|
||||
std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();
|
||||
|
||||
StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>>
|
||||
StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
|
||||
GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
|
||||
|
||||
// TODO(skye): delete after all callers can handle 2D output
|
||||
StatusOr<std::vector<ClientAndPtr<Device>>> GetDefaultDeviceAssignment1D(
|
||||
StatusOr<std::vector<ClientAndPtr<PjRtDevice>>> GetDefaultDeviceAssignment1D(
|
||||
int num_replicas);
|
||||
|
||||
StatusOr<ChannelHandle> CreateChannelHandle() {
|
||||
@ -121,7 +121,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
|
||||
const pybind11::object& argument, Device* device, bool force_copy,
|
||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
||||
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
|
||||
|
||||
StatusOr<std::shared_ptr<PyExecutable>> Compile(
|
||||
|
||||
@ -58,10 +58,10 @@ PyExecutable::~PyExecutable() {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ClientAndPtr<Device>> PyExecutable::LocalDevices() const {
|
||||
std::vector<ClientAndPtr<Device>> devices;
|
||||
std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::LocalDevices() const {
|
||||
std::vector<ClientAndPtr<PjRtDevice>> devices;
|
||||
devices.reserve(executable_->local_devices().size());
|
||||
for (Device* device : executable_->local_devices()) {
|
||||
for (PjRtDevice* device : executable_->local_devices()) {
|
||||
devices.push_back(WrapWithClient(client_, device));
|
||||
}
|
||||
return devices;
|
||||
|
||||
@ -47,7 +47,7 @@ class PyExecutable {
|
||||
return executable_->local_logical_device_ids();
|
||||
}
|
||||
|
||||
std::vector<ClientAndPtr<Device>> LocalDevices() const;
|
||||
std::vector<ClientAndPtr<PjRtDevice>> LocalDevices() const;
|
||||
|
||||
int64 SizeOfGeneratedCodeInBytes() const {
|
||||
return executable_->SizeOfGeneratedCodeInBytes();
|
||||
|
||||
@ -37,8 +37,8 @@ namespace xla {
|
||||
|
||||
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||
int core_on_chip)
|
||||
: xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform,
|
||||
/*device_kind=*/"Cloud TPU", host_id),
|
||||
: xla::PjRtDevice(id, /*local_device_state=*/nullptr, kTpuPlatform,
|
||||
/*device_kind=*/"Cloud TPU", host_id),
|
||||
coords_(coords),
|
||||
core_on_chip_(core_on_chip) {}
|
||||
|
||||
@ -47,9 +47,9 @@ std::string TpuDevice::DebugString() const {
|
||||
coords_[0], coords_[1], coords_[2], core_on_chip_);
|
||||
}
|
||||
|
||||
xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>>
|
||||
xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
|
||||
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
|
||||
std::vector<std::shared_ptr<Device>> devices;
|
||||
std::vector<std::shared_ptr<PjRtDevice>> devices;
|
||||
for (const auto& chip : system_info.tpu_chip()) {
|
||||
auto& coord = chip.chip_coord();
|
||||
std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
|
||||
@ -78,7 +78,7 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||
tpu_driver::SystemInfo system_info;
|
||||
client->QuerySystemInfo(&system_info);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<Device>> devices,
|
||||
TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<PjRtDevice>> devices,
|
||||
TpuDevice::GetTpuDevices(system_info));
|
||||
|
||||
return std::make_shared<PyTpuClient>(kTpuPlatform, std::move(client),
|
||||
@ -88,13 +88,13 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||
|
||||
PyTpuClient::PyTpuClient(std::string platform_name,
|
||||
std::unique_ptr<tpu_driver::TpuDriver> driver,
|
||||
std::vector<std::shared_ptr<Device>> devices,
|
||||
std::vector<std::shared_ptr<PjRtDevice>> devices,
|
||||
int host_id)
|
||||
: platform_name_(std::move(platform_name)),
|
||||
driver_(std::move(driver)),
|
||||
devices_(std::move(devices)),
|
||||
host_id_(host_id) {
|
||||
for (const std::shared_ptr<Device>& device : devices_) {
|
||||
for (const std::shared_ptr<PjRtDevice>& device : devices_) {
|
||||
CHECK(id_to_device_.insert({device->id(), device}).second)
|
||||
<< "Duplicate device id: " << device->id();
|
||||
|
||||
@ -173,7 +173,7 @@ static Status CheckDataType(xla::PrimitiveType dtype) {
|
||||
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
|
||||
std::vector<BorrowingLiteral> leaves, const Shape& tuple_shape,
|
||||
std::shared_ptr<void> leaves_references,
|
||||
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device) {
|
||||
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
|
||||
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals");
|
||||
VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString()
|
||||
<< " device: " << device->DebugString();
|
||||
@ -229,7 +229,7 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
|
||||
absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
|
||||
std::shared_ptr<Device> device) {
|
||||
std::shared_ptr<PjRtDevice> device) {
|
||||
std::vector<Shape> child_shapes;
|
||||
std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
|
||||
std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
|
||||
@ -388,7 +388,7 @@ PyTpuBuffer::DestructureTuple() {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CopyToDevice(
|
||||
std::shared_ptr<Device> dst_device) {
|
||||
std::shared_ptr<PjRtDevice> dst_device) {
|
||||
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice");
|
||||
if (on_host_shape_.IsTuple()) {
|
||||
return Unimplemented("CopyToDevice for tuples is not supported.");
|
||||
@ -433,7 +433,7 @@ Status PyTpuBuffer::BlockHostUntilReady() {
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
|
||||
const Shape& shape, std::shared_ptr<PyTpuClient> client,
|
||||
std::shared_ptr<Device> device) {
|
||||
std::shared_ptr<PjRtDevice> device) {
|
||||
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer");
|
||||
VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString()
|
||||
<< " device: " << device->DebugString();
|
||||
@ -465,7 +465,7 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
|
||||
/*static*/
|
||||
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
|
||||
const Shape& non_tuple_shape, absl::optional<BufferInitializer> initializer,
|
||||
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device) {
|
||||
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device) {
|
||||
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer");
|
||||
VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: "
|
||||
<< non_tuple_shape.DebugString()
|
||||
@ -493,8 +493,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
|
||||
std::vector<std::shared_ptr<TpuSharedBuffer>>(), client);
|
||||
}
|
||||
|
||||
static std::shared_ptr<Device> LookupDevice(const PyTpuClient& client,
|
||||
int device_id) {
|
||||
static std::shared_ptr<PjRtDevice> LookupDevice(const PyTpuClient& client,
|
||||
int device_id) {
|
||||
auto it = client.id_to_device().find(device_id);
|
||||
CHECK(it != client.id_to_device().end())
|
||||
<< "Unknown device id: " << device_id;
|
||||
@ -516,7 +516,7 @@ PyTpuExecutable::PyTpuExecutable(
|
||||
for (int replica = 0; replica < num_replicas; ++replica) {
|
||||
for (int partition = 0; partition < num_partitions; ++partition) {
|
||||
int device_id = device_assignment_(replica, partition);
|
||||
std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
|
||||
std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
|
||||
if (device->host_id() != client_->host_id()) {
|
||||
VLOG(3) << "Non-local device: " << device_id;
|
||||
continue;
|
||||
@ -541,7 +541,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
|
||||
absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
|
||||
int partition, const RunId& run_id) {
|
||||
const int device_id = device_assignment_(replica, partition);
|
||||
std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
|
||||
std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
|
||||
CHECK_EQ(device->host_id(), client_->host_id());
|
||||
tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
|
||||
VLOG(3) << "Replica " << replica << ", partition " << partition
|
||||
|
||||
@ -38,7 +38,7 @@ namespace xla {

constexpr char kTpuPlatform[] = "tpu";

class TpuDevice : public Device {
class TpuDevice : public PjRtDevice {
public:
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip);

@ -48,8 +48,8 @@ class TpuDevice : public Device {

std::string DebugString() const override;

static xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>> GetTpuDevices(
const tpu_driver::SystemInfo& system_info);
static xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
GetTpuDevices(const tpu_driver::SystemInfo& system_info);

private:
const std::array<int, 3> coords_;

@ -66,7 +66,7 @@ class PyTpuClient {

explicit PyTpuClient(std::string platform_name,
std::unique_ptr<tpu_driver::TpuDriver> driver,
std::vector<std::shared_ptr<Device>> devices,
std::vector<std::shared_ptr<PjRtDevice>> devices,
int host_id);
virtual ~PyTpuClient() = default;

@ -83,11 +83,11 @@ class PyTpuClient {

int device_count() const { return devices_.size(); }
int local_device_count() const { return local_devices_.size(); }
const std::vector<std::shared_ptr<Device>>& devices() { return devices_; }
const std::vector<std::shared_ptr<Device>>& local_devices() {
const std::vector<std::shared_ptr<PjRtDevice>>& devices() { return devices_; }
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() {
return local_devices_;
}
const std::map<int, std::shared_ptr<Device>>& id_to_device() const {
const std::map<int, std::shared_ptr<PjRtDevice>>& id_to_device() const {
return id_to_device_;
}
int host_id() const { return host_id_; }

@ -110,11 +110,11 @@ class PyTpuClient {
std::unique_ptr<tpu_driver::TpuDriver> driver_;

// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::shared_ptr<Device>> devices_;
std::vector<std::shared_ptr<PjRtDevice>> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
std::map<int, std::shared_ptr<Device>> id_to_device_;
std::map<int, std::shared_ptr<PjRtDevice>> id_to_device_;
// Local devices indexed by local device ordinal.
std::vector<std::shared_ptr<Device>> local_devices_;
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
int host_id_;

// A thread pool for scheduling core executions in parallel.

@ -128,7 +128,7 @@ struct TpuSharedBuffer final {
TpuSharedBuffer(tpu_driver::TpuDriver* driver,
std::unique_ptr<tpu_driver::BufferHandle> handle,
std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use,
std::shared_ptr<Device> src_device)
std::shared_ptr<PjRtDevice> src_device)
: driver(driver),
device(std::move(src_device)),
handle(std::move(handle)),

@ -143,7 +143,7 @@ struct TpuSharedBuffer final {
}

tpu_driver::TpuDriver* const driver;
const std::shared_ptr<Device> device;
const std::shared_ptr<PjRtDevice> device;

std::unique_ptr<tpu_driver::BufferHandle> handle;
std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;

@ -162,12 +162,12 @@ class PyTpuBuffer {
static StatusOr<std::unique_ptr<PyTpuBuffer>> FromLiterals(
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
std::shared_ptr<void> leaves_reference,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

// Supports nested tuple creation.
static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
absl::Span<PyTpuBuffer* const> buffers,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

PyTpuBuffer() = delete;
PyTpuBuffer(Shape on_host_shape,

@ -181,7 +181,7 @@ class PyTpuBuffer {
PyTpuBuffer& operator=(PyTpuBuffer&&) = delete;

const Shape& on_host_shape() const { return on_host_shape_; }
std::shared_ptr<Device> device() const { return device_; }
std::shared_ptr<PjRtDevice> device() const { return device_; }
const std::string& platform_name() const { return client_->platform_name(); }
std::shared_ptr<PyTpuClient> client() const { return client_; }

@ -210,7 +210,7 @@ class PyTpuBuffer {
// Copies the buffer to target device `dst_device` and returns a PyTpuBuffer
// object holding the context to the target device buffer.
StatusOr<std::unique_ptr<PyTpuBuffer>> CopyToDevice(
std::shared_ptr<Device> dst_device);
std::shared_ptr<PjRtDevice> dst_device);

// Blocks the host until the buffer's value has been computed and is ready for
// immediate use on the device. Useful in particular for timing benchmarks.

@ -220,7 +220,7 @@ class PyTpuBuffer {
// tuple, the returned buffer corresponds to the root tuple buffer.
static StatusOr<std::unique_ptr<PyTpuBuffer>> AllocateBuffer(
const Shape& shape, std::shared_ptr<PyTpuClient> client,
std::shared_ptr<Device> device);
std::shared_ptr<PjRtDevice> device);

private:
// Initializes a just allocated device buffer. The returned event will be

@ -231,11 +231,11 @@ class PyTpuBuffer {
static StatusOr<std::unique_ptr<PyTpuBuffer>> CreateBuffer(
const Shape& non_tuple_shape,
absl::optional<BufferInitializer> initializer,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

const std::shared_ptr<PyTpuClient> client_;
const Shape on_host_shape_;
const std::shared_ptr<Device> device_;
const std::shared_ptr<PjRtDevice> device_;

// If this is a tuple, `device_buffer_` stores the tuple buffer and
// `child_buffers_` stores the child buffers; else, `device_buffer_` stores

@ -302,7 +302,7 @@ class PyTpuExecutable {
return local_logical_device_ids_;
}

const std::vector<std::shared_ptr<Device>>& local_devices() const {
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
return local_devices_;
}

@ -350,7 +350,7 @@ class PyTpuExecutable {
// assigned.
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
// (see xla.cc).
std::vector<std::shared_ptr<Device>> local_devices_;
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;

xla::Shape result_shape_;
};
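For reference, a minimal usage sketch against the renamed PyTpuClient accessors declared above; only `devices()`, `host_id()`, and `DebugString()` are taken from this header, and `client` is assumed to be an existing, valid client instance:

// Sketch only: enumerate devices through the renamed accessors.
// `client` is assumed to be a valid std::shared_ptr<PyTpuClient>.
for (const std::shared_ptr<PjRtDevice>& device : client->devices()) {
  if (device->host_id() == client->host_id()) {
    // This device is local to the current host and safe to target.
    VLOG(1) << "Local device: " << device->DebugString();
  }
}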
@ -40,11 +40,12 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def("host_id", &PyTpuClient::host_id)
.def("get_default_device_assignment",
[](PyTpuClient* client, int num_replicas, int num_partitions)
-> StatusOr<std::vector<std::vector<std::shared_ptr<Device>>>> {
-> StatusOr<
std::vector<std::vector<std::shared_ptr<PjRtDevice>>>> {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->GetDefaultDeviceAssignment(
num_replicas, num_partitions));
std::vector<std::vector<std::shared_ptr<Device>>> result;
std::vector<std::vector<std::shared_ptr<PjRtDevice>>> result;
result.resize(num_replicas);
for (int r = 0; r < num_replicas; ++r) {
result[r].resize(num_partitions);

@ -60,11 +61,11 @@ PYBIND11_MODULE(tpu_client_extension, m) {
// TODO(skye): delete after all callers can handle 2D output
.def("get_default_device_assignment",
[](PyTpuClient* client, int num_replicas)
-> StatusOr<std::vector<std::shared_ptr<Device>>> {
-> StatusOr<std::vector<std::shared_ptr<PjRtDevice>>> {
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->GetDefaultDeviceAssignment(
num_replicas, /*num_partitions=*/1));
std::vector<std::shared_ptr<Device>> result;
std::vector<std::shared_ptr<PjRtDevice>> result;
for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0);
auto iter = client->id_to_device().find(device_id);

@ -96,7 +97,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def(
"buffer_from_pyval",
[](std::shared_ptr<PyTpuClient> client,
const pybind11::object& argument, std::shared_ptr<Device> device,
const pybind11::object& argument,
std::shared_ptr<PjRtDevice> device,
bool force_copy) -> StatusOr<std::unique_ptr<PyTpuBuffer>> {
if (device == nullptr) {
TF_RET_CHECK(!client->local_devices().empty());

@ -145,7 +147,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::class_<PyTpuBuffer>(m, "PyTpuBuffer")
.def_property_readonly("client", &PyTpuBuffer::client)
.def("copy_to_device",
[](PyTpuBuffer* buffer, std::shared_ptr<Device> dst_device) {
[](PyTpuBuffer* buffer, std::shared_ptr<PjRtDevice> dst_device) {
CHECK(dst_device != nullptr);
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;

@ -202,7 +204,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def_property_readonly("traceback",
[](PyTpuExecutable*) { return py::none(); });

py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
py::class_<TpuDevice, PjRtDevice, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
.def_property_readonly("coords", &TpuDevice::coords)
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
.def("__repr__", [](const TpuDevice& device) {
@ -439,26 +439,26 @@ PYBIND11_MODULE(xla_extension, m) {
device_assignment);
});

py::class_<Device, ClientAndPtr<Device>>(
py::class_<PjRtDevice, ClientAndPtr<PjRtDevice>>(
m, "Device",
"A descriptor of an available device.\n\nSubclasses are used to "
"represent specific types of devices, e.g. CPUs, GPUs. Subclasses may "
"have additional properties specific to that device type.")
.def_property_readonly(
"id", &Device::id,
"id", &PjRtDevice::id,
"Integer ID of this device.\n\nUnique across all available devices "
"of this type, including remote devices on multi-host platforms.")
.def_property_readonly("host_id", &Device::host_id,
.def_property_readonly("host_id", &PjRtDevice::host_id,
"Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.")
.def_property_readonly("platform", &Device::platform_name)
.def_property_readonly("device_kind", &Device::device_kind)
.def_property_readonly("platform", &PjRtDevice::platform_name)
.def_property_readonly("device_kind", &PjRtDevice::device_kind)
.def_property_readonly(
"client",
[](const ClientAndPtr<Device>& device) { return device.client; })
.def("__str__", &Device::DebugString)
[](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
.def("__str__", &PjRtDevice::DebugString)
.def("transfer_to_infeed",
[](const Device& device, const LiteralSlice& literal) {
[](const PjRtDevice& device, const LiteralSlice& literal) {
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,

@ -468,7 +468,8 @@ PYBIND11_MODULE(xla_extension, m) {
})
.def(
"transfer_from_outfeed",
[](const Device& device, const Shape& shape) -> StatusOr<py::object> {
[](const PjRtDevice& device,
const Shape& shape) -> StatusOr<py::object> {
GlobalPyRefManager()->CollectGarbage();
std::shared_ptr<Literal> literal_shared;
{

@ -492,12 +493,12 @@ PYBIND11_MODULE(xla_extension, m) {
return LiteralToPython(std::move(literal_shared));
});

py::class_<CpuDevice, Device, ClientAndPtr<CpuDevice>>(m, "CpuDevice")
py::class_<CpuDevice, PjRtDevice, ClientAndPtr<CpuDevice>>(m, "CpuDevice")
.def("__repr__", [](const CpuDevice& device) {
return absl::StrFormat("CpuDevice(id=%i)", device.id());
});

py::class_<GpuDevice, Device, ClientAndPtr<GpuDevice>>(m, "GpuDevice")
py::class_<GpuDevice, PjRtDevice, ClientAndPtr<GpuDevice>>(m, "GpuDevice")
.def("__repr__", [](const GpuDevice& device) {
return absl::StrFormat("GpuDevice(id=%i)", device.id());
});
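Worth noting about the pybind11 hunks above: only the C++ template arguments change; the string names exposed to Python ("Device", "TpuDevice", "CpuDevice", "GpuDevice") are untouched, so Python callers see no difference. A minimal sketch of the registration pattern, using the types as registered above:

// The first template argument is the C++ type (now PjRtDevice); the string
// "Device" is the Python-visible class name and deliberately stays the same,
// which is why this rename is invisible to Python code.
py::class_<PjRtDevice, ClientAndPtr<PjRtDevice>>(m, "Device");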