[XLA:Python] Add infeed/outfeed methods to Device.

Change in preparation for removing them from PyLocalClient.

PiperOrigin-RevId: 292016124
Change-Id: I0f9d0fb4e053f24b2e313611731afaa796388f55
This commit is contained in:
Peter Hawkins 2020-01-28 14:38:03 -08:00 committed by TensorFlower Gardener
parent 9c6e8d5d19
commit 7a4123bda5
6 changed files with 45 additions and 12 deletions

View File

@ -152,6 +152,7 @@ cc_library(
":worker_thread",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
"@com_google_absl//absl/memory",

View File

@ -197,7 +197,7 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
se::StreamExecutor* executor =
client->backend().stream_executor(i).ValueOrDie();
auto device_state = absl::make_unique<LocalDeviceState>(
executor, synchronous_deallocation, asynchronous,
executor, client, synchronous_deallocation, asynchronous,
/*allow_event_reuse=*/gpu_platform);
devices.push_back(MakeDevice(platform_name, i, std::move(device_state)));
}

View File

@ -25,12 +25,14 @@ limitations under the License.
namespace xla {
LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
LocalClient* client,
bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse)
: synchronous_deallocation_(synchronous_deallocation),
event_pool_(allow_event_reuse),
compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1),
executor_(executor),
client_(client),
prng_seed_generator_(prng_seed_device_()),
prng_seed_distribution_(std::numeric_limits<int>::min(),
std::numeric_limits<int>::max()) {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/python/event_pool.h"
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/compiler/xla/python/worker_thread.h"
@ -41,14 +42,17 @@ class LocalDeviceState {
//
// If asynchronous is false, the host will synchronize to the device after
// each execution or transfer. This is intended for debugging only.
LocalDeviceState(se::StreamExecutor* executor, bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse);
LocalDeviceState(se::StreamExecutor* executor, LocalClient* client,
bool synchronous_deallocation, bool asynchronous,
bool allow_event_reuse);
virtual ~LocalDeviceState();
se::StreamExecutor* executor() const { return executor_; }
// StreamExecutor (local) device ordinal.
int device_ordinal() const { return executor_->device_ordinal(); }
LocalClient* client() const { return client_; }
bool synchronous_deallocation() const { return synchronous_deallocation_; }
EventPool& event_pool() { return event_pool_; }
@ -113,7 +117,8 @@ class LocalDeviceState {
// stream by the host ahead of the device.
Semaphore compute_semaphore_;
se::StreamExecutor* executor_;
se::StreamExecutor* const executor_;
LocalClient* const client_;
std::unique_ptr<se::Stream> compute_stream_;
std::unique_ptr<se::Stream> host_to_device_stream_;
std::vector<std::unique_ptr<se::Stream>> device_to_host_streams_;

View File

@ -340,7 +340,34 @@ PYBIND11_MODULE(xla_extension, m) {
"Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.")
.def_property_readonly("platform", &Device::platform_name)
.def("__str__", &Device::DebugString);
.def("__str__", &Device::DebugString)
.def("TransferToInfeed",
[](const Device& device, const LiteralSlice& literal) {
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device.GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
literal, local_device->device_ordinal());
})
.def(
"TransferFromOutfeed",
[](const Device& device, const Shape& shape) -> StatusOr<py::object> {
GlobalPyRefManager()->CollectGarbage();
std::shared_ptr<Literal> literal_shared;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device.GetLocalDeviceState());
TF_ASSIGN_OR_RETURN(
Literal literal,
local_device->client()->TransferFromOutfeedLocal(
shape, local_device->device_ordinal()));
literal_shared = std::make_shared<Literal>(std::move(literal));
}
return LiteralToPython(std::move(literal_shared));
});
py::class_<CpuDevice, Device, std::shared_ptr<CpuDevice>>(m, "CpuDevice")
.def("__repr__", [](const CpuDevice& device) {
@ -411,8 +438,7 @@ PYBIND11_MODULE(xla_extension, m) {
}
return result;
})
// TODO(phawkins): delete overload that accepts a device_ordinal after
// all callers have been updated to pass a Device.
// TODO(phawkins): delete these methods in favor of the versions on Device
.def("TransferToInfeed",
[](PyLocalClient* client, const LiteralSlice& literal,
int device_ordinal) {
@ -430,8 +456,7 @@ PYBIND11_MODULE(xla_extension, m) {
py::gil_scoped_release gil_release;
return client->TransferToInfeed(literal, device);
})
// TODO(phawkins): delete overload that accepts a device_ordinal after
// all callers have been updated to pass a Device.
// TODO(phawkins): delete these methods in favor of the versions on Device
.def("TransferFromOutfeed",
[](PyLocalClient* client, const Shape& shape,
int device_ordinal) -> StatusOr<py::object> {

View File

@ -459,7 +459,7 @@ def transfer_to_infeed(value, device=None):
# TODO(phawkins): support non-default backends.
backend = get_local_backend()
device = device or backend.local_devices()[0]
backend.client.TransferToInfeed(value, device)
device.TransferToInfeed(value)
def transfer_from_outfeed(shape, device=None):
@ -476,8 +476,8 @@ def transfer_from_outfeed(shape, device=None):
# TODO(phawkins): support non-default backends.
backend = get_local_backend()
device = device or backend.local_devices()[0]
return backend.client.TransferFromOutfeed(
shape.with_major_to_minor_layout_if_absent(), device)
return device.TransferFromOutfeed(
shape.with_major_to_minor_layout_if_absent())
DeviceAssignment = _xla.DeviceAssignment