STT-tensorflow/tensorflow/compiler/xla/python/outfeed_receiver_py.cc
Peter Hawkins 3f7adb0d48 [XLA:Python] Add a PyClient wrapper object around PjRtClient.
Refactoring in preparation for adding Python-specific logic to clients, but it also is more readable than inlining this kind of logic into the binding code.

PiperOrigin-RevId: 314994078
Change-Id: I3b7253d4a14ff418068e49f9bd321d85695f9c8b
2020-06-05 14:18:04 -07:00

171 lines
6.8 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
#include <memory>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/outfeed_receiver.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/types.h"
namespace xla {
namespace py = pybind11;
namespace {
// A wrapper for OutfeedReceiver for use from Python, useful for ensuring
// that the GIL is released before destroying the OutfeedReceiver.
// A wrapper for OutfeedReceiver for use from Python, useful for ensuring
// that the GIL is released before destroying the OutfeedReceiver.
class OutfeedReceiverForPython {
 public:
  // A callback to Python takes: the device the data was received from
  // (wrapped together with its owning client), the consumer id, and the
  // received literal converted to a Python object.
  using CallbackToPython =
      std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>;

  // Creates (but does not start; see Start()) an OutfeedReceiver listening
  // on the PjRt clients underlying `clients`. `max_callback_queue_size_bytes`
  // is forwarded unchanged to the OutfeedReceiver.
  OutfeedReceiverForPython(CallbackToPython callback_python,
                           std::vector<std::shared_ptr<PyClient>> clients,
                           ssize_t max_callback_queue_size_bytes)
      : callback_python_(std::move(callback_python)),
        clients_(std::move(clients)) {
    // Capturing `this` is safe: the callback is owned by `outfeed_receiver_`,
    // which the destructor explicitly destroys before `*this` goes away.
    OutfeedReceiver::Callback callback =
        [this](Device* device, uint32_t consumer_id,
               std::shared_ptr<Literal> literal) {
          this->Callback(device, consumer_id, std::move(literal));
        };
    // `clients` was moved into `clients_` by the member initializer, so the
    // buffer must be sized from `clients_`. Sizing it from the moved-from
    // parameter (typically empty) made c_transform write past its end.
    std::vector<PjRtClient*> client_ptrs(clients_.size());
    absl::c_transform(clients_, client_ptrs.begin(),
                      [](const std::shared_ptr<PyClient>& client) {
                        return client->pjrt_client();
                      });
    outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
        std::move(callback), client_ptrs, max_callback_queue_size_bytes);
  }

  OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
  OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;

  ~OutfeedReceiverForPython() {
    // This destructor is called from the Python GC. First mark the receiver
    // as shutting down so that callbacks arriving from now on do not call
    // back into Python (doing so during teardown has produced the exception
    // "std::runtime_error: scoped_acquire::dec_ref(): thread state must
    // be current!").
    {
      absl::MutexLock lock(&mu_);
      outfeed_receiver_shutting_down_ = true;
    }
    // Release the GIL for the rest of the destruction, including the
    // destruction of the OutfeedReceiver, when we may actually have to wait
    // for its threads to end (a callback already past the shutdown check may
    // be waiting to acquire the GIL in Callback()).
    py::gil_scoped_release gil_release;
    outfeed_receiver_ = nullptr;  // Shutdown the outfeed receiver.
  }

  // Starts the underlying OutfeedReceiver.
  void Start() { outfeed_receiver_->Start(); }

  // Adds an outfeed of `arrays` for `consumer_id` to `builder`, threaded on
  // `token`; delegates to OutfeedReceiver::AddOutfeedToBuilder.
  StatusOr<XlaOp> AddOutfeed(XlaBuilder* builder, XlaOp token,
                             uint32_t consumer_id, std::vector<XlaOp> arrays) {
    return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id,
                                                  std::move(arrays));
  }

  // Invoked by the OutfeedReceiver for each received literal. Unless the
  // receiver is shutting down, finds the PyClient owning `device`, acquires
  // the GIL, converts `literal` to Python and forwards it to the Python
  // callback.
  void Callback(Device* device, uint32_t consumer_id,
                std::shared_ptr<Literal> literal) {
    {
      absl::MutexLock lock(&mu_);
      if (outfeed_receiver_shutting_down_) {
        VLOG(2) << "Ignoring unsafe callback to Python during shutdown";
        return;
      }
    }
    // We expect the number of clients to be small, so an O(n) search is fine.
    auto it = absl::c_find_if(
        clients_, [device](const std::shared_ptr<PyClient>& client) {
          return client->pjrt_client() == device->client();
        });
    CHECK(it != clients_.end());
    py::gil_scoped_acquire gil_acquire;  // Need GIL also for LiteralToPython
    py::object literal_python =
        LiteralToPython(std::move(literal)).ValueOrDie();
    // The callback_ should handle all exceptions in user-code. If we get
    // an exception here, it is a bug in the callback and we should stop.
    callback_python_(WrapWithClient<Device>(*it, device), consumer_id,
                     std::move(literal_python));
  }

 private:
  CallbackToPython callback_python_;
  absl::Mutex mu_;
  bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
  std::vector<std::shared_ptr<PyClient>> clients_;
  // Declared last so it is destroyed first, before the state its callback
  // reads (already reset explicitly in the destructor).
  std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
};
} // namespace
// Registers the `outfeed_receiver` submodule on `m`, exposing a `start`
// factory function and the OutfeedReceiverForPython class with its
// `add_outfeed` method. Both bindings run with the GIL released.
void BuildOutfeedReceiverSubmodule(py::module* m) {
  py::module outfeed_receiver =
      m->def_submodule("outfeed_receiver", "Outfeed receiver");
  outfeed_receiver.def(
      "start",
      [](OutfeedReceiverForPython::CallbackToPython callback_to_python,
         std::vector<std::shared_ptr<PyClient>> clients,
         ssize_t max_callback_queue_size_bytes)
          -> std::unique_ptr<OutfeedReceiverForPython> {
        // Move the by-value arguments into the receiver. Copying them would
        // bump shared_ptr refcounts and copy the std::function wrapping the
        // Python callable, all while the GIL is released by the call guard.
        auto server = absl::make_unique<OutfeedReceiverForPython>(
            std::move(callback_to_python), std::move(clients),
            max_callback_queue_size_bytes);
        server->Start();
        return server;
      },
      py::arg("callback_to_python"), py::arg("backends"),
      py::arg("max_queue_size_bytes") = 256 * 1024 * 1024,
      R"(Starts a multithreaded outfeed receiver.
There is one thread for each of the specified devices. When Python
drops the last reference to the returned object, the receiver is shut
down. The destructor will block until all data is received from
devices.
Args:
* callback_to_python: a Python callback to call, with <device>,
<consumer_id> and the data received.
* backends: the list of backends to listen on.
* max_queue_size_bytes: an optional integer to bound the maximum size
of arrays in the callback queue. When this limit is reached the
device listener pauses.
)",
      py::call_guard<py::gil_scoped_release>());

  py::class_<OutfeedReceiverForPython> outfeed_receiver_class(
      outfeed_receiver, "OutfeedReceiverForPython");
  outfeed_receiver_class.def(
      "add_outfeed", &OutfeedReceiverForPython::AddOutfeed, py::arg("builder"),
      py::arg("token"), py::arg("consumer_id"), py::arg("arrays"),
      R"(Adds an outfeed into the given computation builder.
Has the side-effect of registering the sent shape along with the consumer
ID. Returns error if the outfeed shape is not compatible with previously
used shape for the same consumer ID.)",
      py::call_guard<py::gil_scoped_release>());
}
} // namespace xla