Refactor PJRT client.

Add a few methods to ensure that third_party/tensorflow/compiler/xla/python depends only on PJRT interfaces rather than on PJRT implementation details.
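
For illustration, a minimal sketch of the new construction pattern (simplified from the CPU diff below; `local_client` and `devices` stand in for the values the factory functions build):

  // Platform is now identified by an enum, not a free-form string.
  auto client = std::make_unique<PjRtClient>(
      PjRtPlatformId::kCpu, local_client, std::move(devices), /*host_id=*/0,
      /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
      /*should_stage_host_to_device_transfers=*/false,
      /*gpu_run_options=*/nullptr);
  // The platform name string is derived from the id via Name():
  CHECK_EQ(client->platform_name(), "cpu");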

PiperOrigin-RevId: 338546086
Change-Id: I859af094d9c56d00b180dbbf367132630d59510c
Qiao Zhang authored on 2020-10-22 14:11:23 -07:00; committed by TensorFlower Gardener
parent 7f9ce6eae5
commit 3ee3eea626
13 changed files with 148 additions and 65 deletions

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
namespace xla {
@@ -25,7 +26,7 @@ static const char kCpuPlatformName[] = "cpu";
CpuDevice::CpuDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state)
: PjRtDevice(id, std::move(local_device_state), kCpuPlatformName,
: PjRtDevice(id, std::move(local_device_state),
/*device_kind=*/kCpuPlatformName) {}
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
@@ -57,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
}
return std::make_unique<PjRtClient>(
kCpuPlatformName, client, std::move(devices), /*host_id=*/0,
PjRtPlatformId::kCpu, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
namespace xla {
@@ -25,7 +26,7 @@ static const char kInterpreterPlatformName[] = "interpreter";
InterpreterDevice::InterpreterDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state)
: PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName,
: PjRtDevice(id, std::move(local_device_state),
/*device_kind=*/kInterpreterPlatformName) {}
StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
@@ -51,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
devices.push_back(std::move(device));
return std::make_unique<PjRtClient>(
kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
PjRtPlatformId::kInterpreter, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);

View File

@@ -30,8 +30,6 @@ limitations under the License.
namespace xla {
namespace {
static const char kGpuPlatformName[] = "gpu";
// A custom PjRtClient that overrides the device assignment method.
class GpuClient : public xla::PjRtClient {
public:
@@ -298,8 +296,8 @@ Status BuildDistributedDevices(
GpuDevice::GpuDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int node_id)
: PjRtDevice(id, std::move(local_device_state), kGpuPlatformName,
std::move(device_kind), node_id) {}
: PjRtDevice(id, std::move(local_device_state), std::move(device_kind),
node_id) {}
StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
bool asynchronous, const GpuAllocatorConfig& allocator_config,
@@ -325,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
}
return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>(
"gpu", xla_client, std::move(devices),
PjRtPlatformId::kNvidiaGpu, xla_client, std::move(devices),
/*node_id=*/node_id, std::move(allocator),
std::move(host_memory_allocator),
/*should_stage_host_to_device_transfers=*/true,

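As an aside on the hunks above: GpuClient still subclasses PjRtClient, and only the constructor's first argument changes from the "gpu" string to PjRtPlatformId::kNvidiaGpu. A rough sketch of the subclass shape; the override's signature is assumed here, since this diff elides the class body:

  class GpuClient : public xla::PjRtClient {
   public:
    using xla::PjRtClient::PjRtClient;
    // Custom multi-node GPU device assignment (body not shown in this diff).
    StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
        int num_replicas, int num_partitions) const override;
  };
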
View File

@@ -113,6 +113,13 @@ limitations under the License.
namespace xla {
PjRtPlatformId PjRtDevice::platform_id() const {
return client_->platform_id();
}
const std::string& PjRtDevice::platform_name() const {
return client_->platform_name();
}
StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
if (local_device_state_) {
return local_device_state_.get();
@@ -145,8 +152,8 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
devices[replica].size(), replica, devices[0].size());
}
for (int partition = 0; partition < devices[replica].size(); ++partition) {
if (devices[0][0]->platform_name() !=
devices[replica][partition]->platform_name()) {
if (devices[0][0]->platform_id() !=
devices[replica][partition]->platform_id()) {
return InvalidArgument(
"Device assignment passed to Compile() must have devices of a "
"single kind, got %s for replica 0 partition 0 and %s for replica "
@@ -175,13 +182,14 @@ class CpuAllocator : public tensorflow::Allocator {
};
PjRtClient::PjRtClient(
std::string platform_name, LocalClient* client,
PjRtPlatformId platform_id, LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
: platform_name_(std::move(platform_name)),
: platform_id_(platform_id),
platform_name_(Name(platform_id)),
client_(client),
host_memory_allocator_(std::move(host_memory_allocator)),
devices_(std::move(devices)),
@@ -206,15 +214,15 @@ PjRtClient::PjRtClient(
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
<< "Duplicate device id: " << device->id();
if (device->local_device_state()) {
int idx = device->local_device_state()->device_ordinal();
if (device->IsLocalDevice()) {
int idx = device->local_device_id();
if (idx >= local_devices_.size()) {
local_devices_.resize(idx + 1);
}
CHECK(local_devices_[idx] == nullptr) << idx;
local_devices_[idx] = device.get();
}
device->client_ = this;
device->SetClient(this);
}
for (int idx = 0; idx < local_devices_.size(); ++idx) {
CHECK(local_devices_[idx] != nullptr) << idx;
@@ -576,6 +584,10 @@ void PjRtBuffer::ScopedHold::AddToInput(
}
}
bool PjRtBuffer::IsOnCpu() const {
return client()->platform_id() == PjRtPlatformId::kCpu;
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
@@ -865,6 +877,16 @@ StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
shape, local_device->device_ordinal());
}
StatusOr<PjRtDevice*> PjRtClient::LookupLocalDevice(int local_device_id) const {
for (auto* device : local_devices_) {
if (local_device_id == device->local_device_id()) {
return device;
}
}
return InvalidArgument("No matching device found for local_device_id %d",
local_device_id);
}
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device)
@@ -1985,6 +2007,19 @@ PjRtExecutable::ExecuteOnLocalDevices(
return wrapped_results;
}
StatusOr<std::vector<std::shared_ptr<HloModule>>>
PjRtExecutable::GetHloModules() {
std::vector<std::shared_ptr<HloModule>> modules;
modules.reserve(executables().size());
for (const auto& local_exec : executables()) {
if (!local_exec->executable()->has_module()) {
return InvalidArgument("Executable does not have HLO modules.");
}
modules.push_back(local_exec->executable()->shared_module());
}
return std::move(modules);
}
namespace {
StatusOr<Shape> GetShardedShape(const Shape& shape,

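Hypothetical usage of the two additions above (error handling elided; assumes an enclosing function returning a Status, with `client` and `executable` standing in for real instances):

  // Map a host-local device ordinal back to a PjRtDevice.
  TF_ASSIGN_OR_RETURN(PjRtDevice* device, client->LookupLocalDevice(0));
  // Fetch one optimized HloModule per partition without touching
  // LocalExecutable, the PJRT implementation detail this change hides.
  TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<HloModule>> modules,
                      executable->GetHloModules());
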
View File

@@ -36,11 +36,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -50,26 +52,63 @@ limitations under the License.
namespace xla {
// TODO(zhangqiaorjc): Add a registration mechanism to add new platforms.
enum class PjRtPlatformId : int {
kCpu = 0,
kNvidiaGpu = 1,
kAmdGpu = 2,
kTpu = 3,
kEdgeTpu = 4,
kInterpreter = 5
};
constexpr const char* Name(PjRtPlatformId platform_id) {
switch (platform_id) {
case PjRtPlatformId::kCpu:
return "cpu";
case PjRtPlatformId::kNvidiaGpu:
// TODO(zhangqiaorjc): Rename to nvidia_gpu when we add AMD support.
return "gpu";
case PjRtPlatformId::kAmdGpu:
return "amd_gpu";
case PjRtPlatformId::kTpu:
return "tpu";
case PjRtPlatformId::kEdgeTpu:
return "edge_tpu";
case PjRtPlatformId::kInterpreter:
return "interpreter";
}
}
class PjRtClient;
class PjRtDevice {
public:
explicit PjRtDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state,
std::string platform_name, std::string device_kind,
int host_id = 0)
std::string device_kind, int host_id = 0)
: id_(id),
local_device_id_(
local_device_state ? local_device_state->device_ordinal() : -1),
local_device_state_(std::move(local_device_state)),
host_id_(host_id),
platform_name_(std::move(platform_name)),
device_kind_(std::move(device_kind)) {}
virtual ~PjRtDevice() {}
// Must set client exactly once.
void SetClient(PjRtClient* client) {
CHECK(client_ == nullptr);
client_ = client;
}
// The ID of this device. IDs are unique among devices of this type
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
int id() const { return id_; }
bool IsLocalDevice() const { return local_device_id_ != -1; }
int local_device_id() const { return local_device_id_; }
// If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns nullptr if the device is
// not local to this host.
@@ -85,7 +124,11 @@ class PjRtDevice {
// The ID of this device's host. This is always 0 on single-host platforms.
int host_id() const { return host_id_; }
const std::string& platform_name() const { return platform_name_; }
// Return `platform_id` from client.
PjRtPlatformId platform_id() const;
// Return `platform_name` from client.
const std::string& platform_name() const;
// A vendor-dependent string that uniquely identifies the kind of device.
const std::string& device_kind() const { return device_kind_; }
@@ -102,12 +145,10 @@ class PjRtDevice {
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
private:
friend class PjRtClient;
const int id_;
const int local_device_id_; // -1 means not local.
const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_;
const std::string platform_name_;
const std::string device_kind_;
PjRtClient* client_ = nullptr;
};
@@ -155,7 +196,7 @@ class PjRtClient {
public:
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtClient(
std::string platform_name, LocalClient* client,
PjRtPlatformId platform_id, LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
@@ -178,12 +219,16 @@ class PjRtClient {
return id_to_device_;
}
int host_id() const { return host_id_; }
PjRtPlatformId platform_id() const { return platform_id_; }
const std::string& platform_name() const { return platform_name_; }
LocalDeviceState& device_state(int device_ordinal) const {
return *local_devices_.at(device_ordinal)->local_device_state();
}
// Return a local PjRtDevice for a given `local_device_id`.
virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const;
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
@@ -280,6 +325,16 @@ class PjRtClient {
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
virtual StatusOr<ChannelHandle> CreateChannelHandle() {
return client()->CreateChannelHandle();
}
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
return client()->CreateDeviceToHostChannelHandle();
}
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
return client()->CreateHostToDeviceChannelHandle();
}
protected:
friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive(
@@ -293,7 +348,8 @@ class PjRtClient {
return Unimplemented("Cross host sends not implemented.");
}
std::string platform_name_;
const PjRtPlatformId platform_id_;
const std::string platform_name_;
LocalClient* client_;
// Allocator to be used for staging memory transfers to devices.
@@ -509,7 +565,7 @@ class PjRtBuffer {
PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device);
~PjRtBuffer();
virtual ~PjRtBuffer();
PjRtBuffer(const PjRtBuffer&) = delete;
PjRtBuffer(PjRtBuffer&&) = delete;
@@ -519,6 +575,7 @@ class PjRtBuffer {
const Shape& on_host_shape() const { return on_host_shape_; }
const Shape& on_device_shape() const { return on_device_shape_; }
PjRtDevice* device() const { return device_; }
PjRtPlatformId platform_id() const { return client_->platform_id(); }
const std::string& platform_name() const { return client_->platform_name(); }
PjRtClient* client() const { return client_; }
bool IsEmptyTuple() const {
@@ -611,6 +668,9 @@ class PjRtBuffer {
// immediate use on the device. Useful in particular for timing benchmarks.
Status BlockHostUntilReady();
// Whether this buffer is on CPU and thus allows for certain optimizations.
bool IsOnCpu() const;
private:
friend class PjRtClient;
// The cached value of the buffer on the host, produced either from a call to
@@ -782,6 +842,9 @@ class PjRtExecutable {
const string& name() const;
// Return an HloModule per partition.
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules();
protected:
bool parameter_is_tupled_arguments() const {
return parameter_is_tupled_arguments_;

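Taken together, callers of this header can branch on platform without string comparisons. A small sketch against the declarations above:

  // Branch on platform by enum rather than by name:
  if (device->platform_id() == PjRtPlatformId::kTpu) { /* TPU-only path */ }
  // Remote devices appear in DeviceAssignments but carry no LocalDeviceState:
  if (device->IsLocalDevice()) {
    int ordinal = device->local_device_id();  // -1 only for non-local devices.
  }
  const char* name = Name(PjRtPlatformId::kNvidiaGpu);  // "gpu", per the TODO.
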
View File

@@ -118,7 +118,7 @@ PjRtTpuClient::PjRtTpuClient(LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices,
int host_id,
tf_tpu::TpuPlatformInterface* tpu_platform)
: PjRtClient("tpu", client, std::move(devices), host_id,
: PjRtClient(PjRtPlatformId::kTpu, client, std::move(devices), host_id,
/*allocator=*/nullptr,
/*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
@@ -145,7 +145,7 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
return InvalidArgument(
"Passed executable from different client (platform '%s') to "
"PjRtTpuClient::ExecutableFingerprint",
executable.client()->platform_name());
Name(executable.client()->platform_id()));
}
if (executable.executables().size() > 1) {
LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "

View File

@@ -33,7 +33,7 @@ class PjRtTpuDevice : public PjRtDevice {
int host_id, const std::array<int, 3>& coords,
std::string device_kind)
: PjRtDevice(core.Id(), std::move(local_device_state),
/*platform_name=*/"tpu", std::move(device_kind), host_id),
std::move(device_kind), host_id),
core_(core),
coords_(coords) {}

View File

@@ -228,35 +228,31 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
DLContext context;
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
context.device_id = device.local_device_state()->device_ordinal();
context.device_id = device.local_device_id();
return context;
}
StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
const DLContext& context) {
se::Platform::Id platform_id;
switch (context.device_type) {
case kDLCPU:
platform_id = se::host::kHostPlatformId;
break;
if (client.platform_id() != PjRtPlatformId::kCpu) {
return InvalidArgument(
"DLPack CPU device type mismatch with PjRtClient platform %s",
client.platform_name());
}
return client.LookupLocalDevice(context.device_id);
case kDLGPU:
platform_id = se::cuda::kCudaPlatformId;
break;
if (client.platform_id() != PjRtPlatformId::kNvidiaGpu) {
return InvalidArgument(
"DLPack GPU device type mismatch with PjRtClient platform %s",
client.platform_name());
}
return client.LookupLocalDevice(context.device_id);
default:
return InvalidArgument("Unknown/unsupported DLPack device type %d",
context.device_type);
}
auto it = absl::c_find_if(client.local_devices(), [&](PjRtDevice* device) {
return device->local_device_state()->executor()->platform()->id() ==
platform_id &&
device->local_device_state()->device_ordinal() == context.device_id;
});
if (it == client.local_devices().end()) {
return InvalidArgument(
"No matching device found for DLPack device_type %d device_id %d",
context.device_type, context.device_id);
}
return *it;
}
} // namespace
@@ -301,8 +297,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
pack->tensor.manager_ctx = pack.get();
pack->tensor.deleter = DLPackTensorDeleter;
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
dt.ctx.device_id =
buffer->buffer()->device()->local_device_state()->device_ordinal();
dt.ctx.device_id = buffer->buffer()->device()->local_device_id();
dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
TF_ASSIGN_OR_RETURN(dt.dtype,
PrimitiveTypeToDLDataType(

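The lookup above no longer walks local_devices() comparing StreamExecutor platform ids; it checks the client's PjRtPlatformId and delegates to LookupLocalDevice. An illustrative call, noting that DeviceForDLContext lives in an anonymous namespace, so this is a sketch of the flow rather than of a public API:

  DLContext ctx;
  ctx.device_type = kDLCPU;
  ctx.device_id = 0;
  // Fails with InvalidArgument unless client.platform_id() is kCpu.
  StatusOr<PjRtDevice*> device = DeviceForDLContext(client, ctx);
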
View File

@@ -144,7 +144,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
// Additionally we call BlockHostUntilReady() below, which may block.
py::gil_scoped_release gil_release;
if (buffer.device()->platform_name() != "cpu") {
if (!buffer.IsOnCpu()) {
return InvalidArgument(
"Python buffer protocol is only defined for CPU buffers.");
}
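
The new check is equivalent to the old platform-name string test but goes through the enum; a sketch of that equivalence:

  // What IsOnCpu() does, per the pjrt_client.cc hunk earlier in this commit:
  bool on_cpu = buffer.client()->platform_id() == PjRtPlatformId::kCpu;
  DCHECK_EQ(on_cpu, buffer.IsOnCpu());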

View File

@@ -112,13 +112,13 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
int num_replicas);
StatusOr<ChannelHandle> CreateChannelHandle() {
return pjrt_client_->client()->CreateChannelHandle();
return pjrt_client_->CreateChannelHandle();
}
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
return pjrt_client_->client()->CreateDeviceToHostChannelHandle();
return pjrt_client_->CreateDeviceToHostChannelHandle();
}
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
return pjrt_client_->client()->CreateHostToDeviceChannelHandle();
return pjrt_client_->CreateHostToDeviceChannelHandle();
}
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(

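Because the channel-handle factories are now virtual on PjRtClient, PyClient no longer reaches through to LocalClient, and a backend without channel support can override them. A hypothetical override (class name invented for illustration):

  class ChannelLessClient : public PjRtClient {
   public:
    using PjRtClient::PjRtClient;
    StatusOr<ChannelHandle> CreateChannelHandle() override {
      return Unimplemented("Channels are not supported on this platform.");
    }
  };
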
View File

@@ -135,15 +135,7 @@ PyExecutable::ExecuteOnLocalDevices(
StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
const {
std::vector<std::shared_ptr<HloModule>> modules;
modules.reserve(executable_->executables().size());
for (const auto& local_exec : executable_->executables()) {
if (!local_exec->executable()->has_module()) {
return InvalidArgument("Executable does not have HLO modules.");
}
modules.push_back(local_exec->executable()->shared_module());
}
return std::move(modules);
return executable_->GetHloModules();
}
} // namespace xla

View File

@@ -37,7 +37,7 @@ namespace xla {
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip)
: xla::PjRtDevice(id, /*local_device_state=*/nullptr, kTpuPlatform,
: xla::PjRtDevice(id, /*local_device_state=*/nullptr,
/*device_kind=*/"Cloud TPU", host_id),
coords_(coords),
core_on_chip_(core_on_chip) {}

View File

@@ -641,9 +641,7 @@ PYBIND11_MODULE(xla_extension, m) {
[](py::object buffer_obj) -> StatusOr<py::object> {
GlobalPyRefManager()->CollectGarbage();
PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
LocalDeviceState* state =
buffer->buffer()->device()->local_device_state();
if (state->executor()->platform_kind() == se::PlatformKind::kHost &&
if (buffer->buffer()->IsOnCpu() &&
buffer->buffer()->on_device_shape().IsArray() &&
buffer->buffer()->on_device_shape().element_type() != BF16) {
py::object out = py::reinterpret_steal<py::object>(