[XLA:Python] Small refactoring to OutfeedReceiver.

Don't keep a std::shared_ptr<PjRtClient> in the OutfeedReceiver class, since it is possible that clients will not manage their PjRtClient using shared ownership.

Change in preparation for adding a Python-specific wrapper class around PjRtClient for use of XLA:Python bindings.

PiperOrigin-RevId: 314948133
Change-Id: I3d87242fc393272ef4b54fec2f39691765ff91a1
This commit is contained in:
Peter Hawkins 2020-06-05 10:07:33 -07:00 committed by TensorFlower Gardener
parent 96520cd3a6
commit 8a938dc21c
7 changed files with 93 additions and 84 deletions

View File

@ -187,6 +187,7 @@ PjRtClient::PjRtClient(
CHECK(local_devices_[idx] == nullptr) << idx;
local_devices_[idx] = device.get();
}
device->client_ = this;
}
for (int idx = 0; idx < local_devices_.size(); ++idx) {
CHECK(local_devices_[idx] != nullptr) << idx;

View File

@ -47,6 +47,8 @@ limitations under the License.
namespace xla {
class PjRtClient;
class Device {
public:
explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
@ -86,12 +88,17 @@ class Device {
virtual std::string DebugString() const;
PjRtClient* client() const { return client_; }
private:
friend class PjRtClient;
const int id_;
const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_;
const std::string platform_name_;
const std::string device_kind_;
PjRtClient* client_ = nullptr;
};
// Forward declaration.

View File

@ -297,7 +297,7 @@ cc_library(
)
tf_cc_test(
name = "cpu_outfeed_receiver_test",
name = "outfeed_receiver_test_cpu",
size = "small",
srcs = ["outfeed_receiver_test.cc"],
deps = [
@ -328,6 +328,7 @@ cc_library(
":types",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@pybind11",

View File

@ -98,23 +98,17 @@ uint32_t constexpr kOutfeedHeaderStart = 271828;
// Special consumer IDs, without outfeed payload.
uint32_t constexpr kOutfeedCidShutdown = 0;
// A Device and its PjRtClient.
struct DeviceWithClient {
Device* device;
std::shared_ptr<PjRtClient> client;
};
// Encapsulates data received from a device outfeed.
class OutfeedData {
public:
OutfeedData(DeviceWithClient device_client, uint32_t consumer_id, Shape shape)
: device_client_(device_client),
OutfeedData(Device* device, uint32_t consumer_id, Shape shape)
: device_(device),
consumer_id_(consumer_id),
shape_(shape),
literal_(nullptr),
literal_size_bytes_(0) {}
DeviceWithClient device_client() { return device_client_; }
Device* device() { return device_; }
uint32_t consumer_id() const { return consumer_id_; }
Shape shape() const { return shape_; }
std::unique_ptr<Literal> literal() {
@ -129,7 +123,7 @@ class OutfeedData {
std::string DebugString() const;
private:
DeviceWithClient device_client_;
Device* device_;
uint32_t consumer_id_;
Shape shape_;
std::unique_ptr<Literal> literal_;
@ -150,15 +144,14 @@ void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
}
std::string OutfeedData::DebugString() const {
return absl::StrFormat("dev=%s; cons=%d; shape=%s",
device_client_.device->DebugString(), consumer_id_,
shape_.ToString());
return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
consumer_id_, shape_.ToString());
}
class OutfeedReceiverImpl {
public:
OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
std::vector<std::shared_ptr<PjRtClient>> clients,
absl::Span<PjRtClient* const> clients,
ssize_t max_callback_queue_size_bytes);
OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
@ -206,8 +199,8 @@ class OutfeedReceiverImpl {
void Shutdown();
OutfeedReceiver::Callback callback_;
// The devices on which we are listening, with their clients.
std::vector<DeviceWithClient> devices_;
// The devices on which we are listening.
std::vector<Device*> devices_;
// Maximum bytes capacity of the callback queue.
uint64_t max_callback_queue_size_bytes_;
@ -232,14 +225,13 @@ class OutfeedReceiverImpl {
};
OutfeedReceiverImpl::OutfeedReceiverImpl(
OutfeedReceiver::Callback callback,
std::vector<std::shared_ptr<PjRtClient>> clients,
OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
ssize_t max_callback_queue_size_bytes) {
callback_ = callback;
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
for (const auto& client : clients) {
for (const auto& device : client->devices()) {
devices_.push_back(DeviceWithClient{device.get(), client});
devices_.push_back(device.get());
}
}
CHECK_GT(devices_.size(), 0);
@ -291,11 +283,11 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
absl::MutexLock lock(&mu_);
++num_listening_threads_;
}
DeviceWithClient device_client = devices_[device_idx];
Device* device = devices_[device_idx];
while (true) {
Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
std::unique_ptr<Literal> header =
ReceiveRawFromOutfeed(device_client.device, header_shape).ValueOrDie();
ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
absl::Span<uint32_t> header_data = header->data<uint32>();
CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
CHECK_EQ(header_data[0], kOutfeedHeaderStart);
@ -306,18 +298,17 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
auto registered_shape = shape_registry_.find(consumer_id);
if (registered_shape == shape_registry_.end()) {
LOG(FATAL)
<< "[" << device_client.device->DebugString()
<< "[" << device->DebugString()
<< "] Cannot find registered shape for consumer ID " << consumer_id
<< ". Perhaps the code was compiled with a different instance "
<< "of OutfeedReceiver.";
}
shape = registered_shape->second;
}
auto received =
absl::make_unique<OutfeedData>(device_client, consumer_id, shape);
auto received = absl::make_unique<OutfeedData>(device, consumer_id, shape);
VLOG(2) << "Listener received header " << received->DebugString();
if (consumer_id == kOutfeedCidShutdown) {
VLOG(2) << "[" << device_client.device->DebugString()
VLOG(2) << "[" << device->DebugString()
<< "] Listener received shutdown header";
absl::MutexLock lock(&mu_);
--num_listening_threads_;
@ -328,7 +319,7 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
return;
}
std::unique_ptr<Literal> data =
ReceiveRawFromOutfeed(device_client.device, shape).ValueOrDie();
ReceiveRawFromOutfeed(device, shape).ValueOrDie();
received->SetLiteral(std::move(data));
absl::MutexLock lock(&mu_);
EnqueueReceivedData(std::move(received));
@ -392,15 +383,14 @@ void OutfeedReceiverImpl::CallbackThreadLoop() {
}
{
tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
DeviceWithClient device_client = received->device_client();
callback_(device_client.device, std::move(device_client.client),
received->consumer_id(), received->literal());
callback_(received->device(), received->consumer_id(),
received->literal());
}
}
}
Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
const Device* device = devices_[device_idx].device;
const Device* device = devices_[device_idx];
constexpr int consumer_id = kOutfeedCidShutdown;
VLOG(2) << "[" << device->DebugString()
<< "] SendSpecialHeader cons=" << consumer_id;
@ -421,7 +411,7 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, devices_[device_idx].client.get(),
PjRtExecutable::Compile(computation, devices_[device_idx]->client(),
std::move(compile_options)));
ExecuteOptions execute_options;
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
@ -468,11 +458,11 @@ StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
return token;
}
OutfeedReceiver::OutfeedReceiver(
Callback callback, std::vector<std::shared_ptr<PjRtClient>> clients,
ssize_t max_callback_queue_size_bytes) {
OutfeedReceiver::OutfeedReceiver(Callback callback,
absl::Span<PjRtClient* const> clients,
ssize_t max_callback_queue_size_bytes) {
p_impl_ = absl::make_unique<OutfeedReceiverImpl>(
callback, std::move(clients), max_callback_queue_size_bytes);
callback, clients, max_callback_queue_size_bytes);
}
OutfeedReceiver::~OutfeedReceiver() {}

View File

@ -31,10 +31,9 @@ class OutfeedReceiverImpl;
// Implements a multithreaded receiver of outfeeds from devices.
class OutfeedReceiver {
public:
// A callback takes: device, client (for the device), consumer id, received.
// The client pointer should be alive while the device is used.
using Callback = std::function<void(Device*, std::shared_ptr<PjRtClient>,
uint32_t, std::shared_ptr<Literal>)>;
// A callback takes: device, consumer id, received.
using Callback =
std::function<void(Device*, uint32_t, std::shared_ptr<Literal>)>;
// Constructs the receiver for the given clients and callback function.
//
@ -45,8 +44,7 @@ class OutfeedReceiver {
// max_callback_queue_size_bytes: the maximum number of bytes for all
// received outfeeds queued to be processed. When this limit is reached
// we pause receiving outfeeds from devices.
OutfeedReceiver(Callback callback,
std::vector<std::shared_ptr<PjRtClient>> clients,
OutfeedReceiver(Callback callback, absl::Span<PjRtClient* const> clients,
ssize_t max_callback_queue_size_bytes);
OutfeedReceiver(const OutfeedReceiver&) = delete;

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "pybind11/functional.h"
@ -42,16 +43,20 @@ class OutfeedReceiverForPython {
OutfeedReceiverForPython(CallbackToPython callback_python,
std::vector<std::shared_ptr<PjRtClient>> clients,
ssize_t max_callback_queue_size_bytes) {
callback_python_ = callback_python;
outfeed_receiver_shutting_down_ = false;
ssize_t max_callback_queue_size_bytes)
: callback_python_(std::move(callback_python)),
clients_(std::move(clients)) {
OutfeedReceiver::Callback callback =
[this](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> literal) {
this->Callback(device, client, consumer_id, literal);
[this](Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> literal) {
this->Callback(device, consumer_id, std::move(literal));
};
std::vector<PjRtClient*> client_ptrs(clients.size());
absl::c_transform(
clients_, client_ptrs.begin(),
[](const std::shared_ptr<PjRtClient>& client) { return client.get(); });
outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
callback, std::move(clients), max_callback_queue_size_bytes);
callback, client_ptrs, max_callback_queue_size_bytes);
}
OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;
@ -79,8 +84,8 @@ class OutfeedReceiverForPython {
arrays);
}
void Callback(Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> literal) {
void Callback(Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> literal) {
{
absl::MutexLock lock(&mu_);
if (outfeed_receiver_shutting_down_) {
@ -88,19 +93,26 @@ class OutfeedReceiverForPython {
return;
}
}
// We expect the number of clients to be small, so an O(n) search is fine.
auto it = absl::c_find_if(
clients_, [device](const std::shared_ptr<PjRtClient>& client) {
return client.get() == device->client();
});
CHECK(it != clients_.end());
py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
py::object literal_python =
LiteralToPython(std::move(literal)).ValueOrDie();
// The callback_ should handle all exceptions in user-code. If we get
// an exception here, it is a bug in the callback and we should stop.
callback_python_(WrapWithClient<Device>(std::move(client), device),
consumer_id, std::move(literal_python));
callback_python_(WrapWithClient<Device>(*it, device), consumer_id,
std::move(literal_python));
}
private:
CallbackToPython callback_python_;
absl::Mutex mu_;
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_);
bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
std::vector<std::shared_ptr<PjRtClient>> clients_;
std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
};

View File

@ -75,14 +75,14 @@ class Accumulator {
TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@ -108,14 +108,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@ -153,14 +153,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) {
TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@ -196,14 +196,14 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) {
TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
@ -230,14 +230,14 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) {
TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client};
std::vector<PjRtClient*> clients{cpu_client.get()};
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](Device* device, std::shared_ptr<PjRtClient> client,
uint32_t consumer_id, std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
OutfeedReceiver::Callback callback = [&receiver](
Device* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();