[XLA:Python] Small refactoring to OutfeedReceiver.
Don't keep a std::shared_ptr<PjRtClient> in the OutfeedReceiver class, since it is possible that clients will not manage their PjRtClient using shared ownership. This change is in preparation for adding a Python-specific wrapper class around PjRtClient for use by the XLA:Python bindings. PiperOrigin-RevId: 314948133 Change-Id: I3d87242fc393272ef4b54fec2f39691765ff91a1
This commit is contained in:
parent
96520cd3a6
commit
8a938dc21c
@ -187,6 +187,7 @@ PjRtClient::PjRtClient(
|
||||
CHECK(local_devices_[idx] == nullptr) << idx;
|
||||
local_devices_[idx] = device.get();
|
||||
}
|
||||
device->client_ = this;
|
||||
}
|
||||
for (int idx = 0; idx < local_devices_.size(); ++idx) {
|
||||
CHECK(local_devices_[idx] != nullptr) << idx;
|
||||
|
@ -47,6 +47,8 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
class PjRtClient;
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
||||
@ -86,12 +88,17 @@ class Device {
|
||||
|
||||
virtual std::string DebugString() const;
|
||||
|
||||
PjRtClient* client() const { return client_; }
|
||||
|
||||
private:
|
||||
friend class PjRtClient;
|
||||
|
||||
const int id_;
|
||||
const std::unique_ptr<LocalDeviceState> local_device_state_;
|
||||
const int host_id_;
|
||||
const std::string platform_name_;
|
||||
const std::string device_kind_;
|
||||
PjRtClient* client_ = nullptr;
|
||||
};
|
||||
|
||||
// Forward declaration.
|
||||
|
@ -297,7 +297,7 @@ cc_library(
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "cpu_outfeed_receiver_test",
|
||||
name = "outfeed_receiver_test_cpu",
|
||||
size = "small",
|
||||
srcs = ["outfeed_receiver_test.cc"],
|
||||
deps = [
|
||||
@ -328,6 +328,7 @@ cc_library(
|
||||
":types",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@pybind11",
|
||||
|
@ -98,23 +98,17 @@ uint32_t constexpr kOutfeedHeaderStart = 271828;
|
||||
// Special consumer IDs, without outfeed payload.
|
||||
uint32_t constexpr kOutfeedCidShutdown = 0;
|
||||
|
||||
// A Device and its PjRtClient.
|
||||
struct DeviceWithClient {
|
||||
Device* device;
|
||||
std::shared_ptr<PjRtClient> client;
|
||||
};
|
||||
|
||||
// Encapsulates data received from a device outfeed.
|
||||
class OutfeedData {
|
||||
public:
|
||||
OutfeedData(DeviceWithClient device_client, uint32_t consumer_id, Shape shape)
|
||||
: device_client_(device_client),
|
||||
OutfeedData(Device* device, uint32_t consumer_id, Shape shape)
|
||||
: device_(device),
|
||||
consumer_id_(consumer_id),
|
||||
shape_(shape),
|
||||
literal_(nullptr),
|
||||
literal_size_bytes_(0) {}
|
||||
|
||||
DeviceWithClient device_client() { return device_client_; }
|
||||
Device* device() { return device_; }
|
||||
uint32_t consumer_id() const { return consumer_id_; }
|
||||
Shape shape() const { return shape_; }
|
||||
std::unique_ptr<Literal> literal() {
|
||||
@ -129,7 +123,7 @@ class OutfeedData {
|
||||
std::string DebugString() const;
|
||||
|
||||
private:
|
||||
DeviceWithClient device_client_;
|
||||
Device* device_;
|
||||
uint32_t consumer_id_;
|
||||
Shape shape_;
|
||||
std::unique_ptr<Literal> literal_;
|
||||
@ -150,15 +144,14 @@ void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
|
||||
}
|
||||
|
||||
std::string OutfeedData::DebugString() const {
|
||||
return absl::StrFormat("dev=%s; cons=%d; shape=%s",
|
||||
device_client_.device->DebugString(), consumer_id_,
|
||||
shape_.ToString());
|
||||
return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
|
||||
consumer_id_, shape_.ToString());
|
||||
}
|
||||
|
||||
class OutfeedReceiverImpl {
|
||||
public:
|
||||
OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients,
|
||||
absl::Span<PjRtClient* const> clients,
|
||||
ssize_t max_callback_queue_size_bytes);
|
||||
|
||||
OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
|
||||
@ -206,8 +199,8 @@ class OutfeedReceiverImpl {
|
||||
void Shutdown();
|
||||
|
||||
OutfeedReceiver::Callback callback_;
|
||||
// The devices on which we are listening, with their clients.
|
||||
std::vector<DeviceWithClient> devices_;
|
||||
// The devices on which we are listening.
|
||||
std::vector<Device*> devices_;
|
||||
// Maximum bytes capacity of the callback queue.
|
||||
uint64_t max_callback_queue_size_bytes_;
|
||||
|
||||
@ -232,14 +225,13 @@ class OutfeedReceiverImpl {
|
||||
};
|
||||
|
||||
OutfeedReceiverImpl::OutfeedReceiverImpl(
|
||||
OutfeedReceiver::Callback callback,
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients,
|
||||
OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
|
||||
ssize_t max_callback_queue_size_bytes) {
|
||||
callback_ = callback;
|
||||
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
|
||||
for (const auto& client : clients) {
|
||||
for (const auto& device : client->devices()) {
|
||||
devices_.push_back(DeviceWithClient{device.get(), client});
|
||||
devices_.push_back(device.get());
|
||||
}
|
||||
}
|
||||
CHECK_GT(devices_.size(), 0);
|
||||
@ -291,11 +283,11 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
++num_listening_threads_;
|
||||
}
|
||||
DeviceWithClient device_client = devices_[device_idx];
|
||||
Device* device = devices_[device_idx];
|
||||
while (true) {
|
||||
Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
|
||||
std::unique_ptr<Literal> header =
|
||||
ReceiveRawFromOutfeed(device_client.device, header_shape).ValueOrDie();
|
||||
ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
|
||||
absl::Span<uint32_t> header_data = header->data<uint32>();
|
||||
CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
|
||||
CHECK_EQ(header_data[0], kOutfeedHeaderStart);
|
||||
@ -306,18 +298,17 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
|
||||
auto registered_shape = shape_registry_.find(consumer_id);
|
||||
if (registered_shape == shape_registry_.end()) {
|
||||
LOG(FATAL)
|
||||
<< "[" << device_client.device->DebugString()
|
||||
<< "[" << device->DebugString()
|
||||
<< "] Cannot find registered shape for consumer ID " << consumer_id
|
||||
<< ". Perhaps the code was compiled with a different instance "
|
||||
<< "of OutfeedReceiver.";
|
||||
}
|
||||
shape = registered_shape->second;
|
||||
}
|
||||
auto received =
|
||||
absl::make_unique<OutfeedData>(device_client, consumer_id, shape);
|
||||
auto received = absl::make_unique<OutfeedData>(device, consumer_id, shape);
|
||||
VLOG(2) << "Listener received header " << received->DebugString();
|
||||
if (consumer_id == kOutfeedCidShutdown) {
|
||||
VLOG(2) << "[" << device_client.device->DebugString()
|
||||
VLOG(2) << "[" << device->DebugString()
|
||||
<< "] Listener received shutdown header";
|
||||
absl::MutexLock lock(&mu_);
|
||||
--num_listening_threads_;
|
||||
@ -328,7 +319,7 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<Literal> data =
|
||||
ReceiveRawFromOutfeed(device_client.device, shape).ValueOrDie();
|
||||
ReceiveRawFromOutfeed(device, shape).ValueOrDie();
|
||||
received->SetLiteral(std::move(data));
|
||||
absl::MutexLock lock(&mu_);
|
||||
EnqueueReceivedData(std::move(received));
|
||||
@ -392,15 +383,14 @@ void OutfeedReceiverImpl::CallbackThreadLoop() {
|
||||
}
|
||||
{
|
||||
tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
|
||||
DeviceWithClient device_client = received->device_client();
|
||||
callback_(device_client.device, std::move(device_client.client),
|
||||
received->consumer_id(), received->literal());
|
||||
callback_(received->device(), received->consumer_id(),
|
||||
received->literal());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
|
||||
const Device* device = devices_[device_idx].device;
|
||||
const Device* device = devices_[device_idx];
|
||||
constexpr int consumer_id = kOutfeedCidShutdown;
|
||||
VLOG(2) << "[" << device->DebugString()
|
||||
<< "] SendSpecialHeader cons=" << consumer_id;
|
||||
@ -421,7 +411,7 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
PjRtExecutable::Compile(computation, devices_[device_idx].client.get(),
|
||||
PjRtExecutable::Compile(computation, devices_[device_idx]->client(),
|
||||
std::move(compile_options)));
|
||||
ExecuteOptions execute_options;
|
||||
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
|
||||
@ -468,11 +458,11 @@ StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
|
||||
return token;
|
||||
}
|
||||
|
||||
OutfeedReceiver::OutfeedReceiver(
|
||||
Callback callback, std::vector<std::shared_ptr<PjRtClient>> clients,
|
||||
ssize_t max_callback_queue_size_bytes) {
|
||||
OutfeedReceiver::OutfeedReceiver(Callback callback,
|
||||
absl::Span<PjRtClient* const> clients,
|
||||
ssize_t max_callback_queue_size_bytes) {
|
||||
p_impl_ = absl::make_unique<OutfeedReceiverImpl>(
|
||||
callback, std::move(clients), max_callback_queue_size_bytes);
|
||||
callback, clients, max_callback_queue_size_bytes);
|
||||
}
|
||||
|
||||
OutfeedReceiver::~OutfeedReceiver() {}
|
||||
|
@ -31,10 +31,9 @@ class OutfeedReceiverImpl;
|
||||
// Implements a multithreaded receiver of outfeeds from devices.
|
||||
class OutfeedReceiver {
|
||||
public:
|
||||
// A callback takes: device, client (for the device), consumer id, received.
|
||||
// The client pointer should be alive while the device is used.
|
||||
using Callback = std::function<void(Device*, std::shared_ptr<PjRtClient>,
|
||||
uint32_t, std::shared_ptr<Literal>)>;
|
||||
// A callback takes: device, consumer id, received.
|
||||
using Callback =
|
||||
std::function<void(Device*, uint32_t, std::shared_ptr<Literal>)>;
|
||||
|
||||
// Constructs the receiver for the given clients and callback function.
|
||||
//
|
||||
@ -45,8 +44,7 @@ class OutfeedReceiver {
|
||||
// max_callback_queue_size_bytes: the maximum number of bytes for all
|
||||
// received outfeeds queued to be processed. When this limit is reached
|
||||
// we pause receiving outfeeds from devices.
|
||||
OutfeedReceiver(Callback callback,
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients,
|
||||
OutfeedReceiver(Callback callback, absl::Span<PjRtClient* const> clients,
|
||||
ssize_t max_callback_queue_size_bytes);
|
||||
|
||||
OutfeedReceiver(const OutfeedReceiver&) = delete;
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "pybind11/functional.h"
|
||||
@ -42,16 +43,20 @@ class OutfeedReceiverForPython {
|
||||
|
||||
OutfeedReceiverForPython(CallbackToPython callback_python,
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients,
|
||||
ssize_t max_callback_queue_size_bytes) {
|
||||
callback_python_ = callback_python;
|
||||
outfeed_receiver_shutting_down_ = false;
|
||||
ssize_t max_callback_queue_size_bytes)
|
||||
: callback_python_(std::move(callback_python)),
|
||||
clients_(std::move(clients)) {
|
||||
OutfeedReceiver::Callback callback =
|
||||
[this](Device* device, std::shared_ptr<PjRtClient> client,
|
||||
uint32_t consumer_id, std::shared_ptr<Literal> literal) {
|
||||
this->Callback(device, client, consumer_id, literal);
|
||||
[this](Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> literal) {
|
||||
this->Callback(device, consumer_id, std::move(literal));
|
||||
};
|
||||
std::vector<PjRtClient*> client_ptrs(clients.size());
|
||||
absl::c_transform(
|
||||
clients_, client_ptrs.begin(),
|
||||
[](const std::shared_ptr<PjRtClient>& client) { return client.get(); });
|
||||
outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
|
||||
callback, std::move(clients), max_callback_queue_size_bytes);
|
||||
callback, client_ptrs, max_callback_queue_size_bytes);
|
||||
}
|
||||
OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
|
||||
OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;
|
||||
@ -79,8 +84,8 @@ class OutfeedReceiverForPython {
|
||||
arrays);
|
||||
}
|
||||
|
||||
void Callback(Device* device, std::shared_ptr<PjRtClient> client,
|
||||
uint32_t consumer_id, std::shared_ptr<Literal> literal) {
|
||||
void Callback(Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> literal) {
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
if (outfeed_receiver_shutting_down_) {
|
||||
@ -88,19 +93,26 @@ class OutfeedReceiverForPython {
|
||||
return;
|
||||
}
|
||||
}
|
||||
// We expect the number of clients to be small, so an O(n) search is fine.
|
||||
auto it = absl::c_find_if(
|
||||
clients_, [device](const std::shared_ptr<PjRtClient>& client) {
|
||||
return client.get() == device->client();
|
||||
});
|
||||
CHECK(it != clients_.end());
|
||||
py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
|
||||
py::object literal_python =
|
||||
LiteralToPython(std::move(literal)).ValueOrDie();
|
||||
// The callback_ should handle all exceptions in user-code. If we get
|
||||
// an exception here, it is a bug in the callback and we should stop.
|
||||
callback_python_(WrapWithClient<Device>(std::move(client), device),
|
||||
consumer_id, std::move(literal_python));
|
||||
callback_python_(WrapWithClient<Device>(*it, device), consumer_id,
|
||||
std::move(literal_python));
|
||||
}
|
||||
|
||||
private:
|
||||
CallbackToPython callback_python_;
|
||||
absl::Mutex mu_;
|
||||
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_);
|
||||
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients_;
|
||||
std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
|
||||
};
|
||||
|
||||
|
@ -75,14 +75,14 @@ class Accumulator {
|
||||
TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||
GetCpuClient(true));
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
|
||||
uint32_t consumer_id, std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
@ -108,14 +108,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
|
||||
TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||
GetCpuClient(true));
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
|
||||
uint32_t consumer_id, std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
@ -153,14 +153,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) {
|
||||
TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||
GetCpuClient(true));
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
|
||||
uint32_t consumer_id, std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
@ -196,14 +196,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) {
|
||||
TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||
GetCpuClient(true));
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
|
||||
uint32_t consumer_id, std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
@ -230,14 +230,14 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) {
|
||||
TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||
GetCpuClient(true));
|
||||
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
|
||||
uint32_t consumer_id, std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
OutfeedReceiver::Callback callback = [&receiver](
|
||||
Device* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
|
Loading…
Reference in New Issue
Block a user