[XLA] Change LocalClient::TransferFromOutfeedLocal to write into a caller-provided literal rather than allocating its own literal.
PiperOrigin-RevId: 353904746
Change-Id: Ia9868bee8c90c5e2a99fbb787eede16dbe75d2d1

parent 8e985babf8
commit f5d89fe581
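In practice the new contract looks like the sketch below: the caller allocates the destination Literal with the shape it expects from the outfeed, and TransferFromOutfeedLocal fills it in place (a Literal* converts to the MutableBorrowingLiteral parameter, as the updated tests further down show). The helper name and its parameters are hypothetical, for illustration only, not part of this change.

#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal.h"

namespace xla {

// Hypothetical helper, not part of this change: reads one value of `shape`
// from the outfeed of `device_ordinal` into caller-owned storage.
Status ReadOneOutfeedValue(LocalClient* client, const Shape& shape,
                           int device_ordinal, Literal* out) {
  *out = Literal(shape);  // the caller allocates the destination
  // New argument order: device ordinal first, then the destination literal.
  return client->TransferFromOutfeedLocal(device_ordinal, out);
}

}  // namespace xla

The apparent benefit is that callers control allocation and can reuse or borrow storage instead of receiving a freshly allocated Literal on every transfer.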
@@ -433,14 +433,12 @@ Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
                                                       literal);
 }
 
-StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
-                                                        int device_ordinal) {
+Status LocalClient::TransferFromOutfeedLocal(int device_ordinal,
+                                             MutableBorrowingLiteral literal) {
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                       backend().stream_executor(device_ordinal));
-  auto literal = Literal::CreateFromShape(shape);
-  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
-      executor, &literal));
-  return std::move(literal);
+  return backend().transfer_manager()->TransferLiteralFromOutfeed(executor,
+                                                                  literal);
 }
 
 StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
@@ -172,13 +172,13 @@ class LocalClient : public Client {
   // Client::TransferToInfeed.
   Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal);
 
-  // Transfer and return a value of the given shape from the outfeed of the
-  // given device.
+  // Transfer and return a value from the outfeed of the given device. The
+  // shape of the object to transfer is determined by `literal`'s shape.
   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
   // not inherit from Client and there is no possibility of confusion with
   // Client::TransferFromOutfeed.
-  StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
-                                             int device_ordinal);
+  Status TransferFromOutfeedLocal(int device_ordinal,
+                                  MutableBorrowingLiteral literal);
 
   // Returns the device ordinal that corresponds to the given replica number.
   //
@@ -132,6 +132,7 @@ cc_library(
     hdrs = ["pjrt_client.h"],
     visibility = ["//tensorflow/compiler/xla:friends"],
     deps = [
+        "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:statusor",
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
+#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/shape.h"
@@ -93,7 +94,7 @@ class PjRtDevice {
   virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
 
   // Transfer and return a value of the given shape from the outfeed queue.
-  virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
+  virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) const = 0;
 };
 
 // Forward declaration.
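The same pattern applies at the PjRtDevice level. A minimal sketch of a caller of the new pure virtual follows; the helper is hypothetical and essentially mirrors the OutfeedReceiverImpl and Python-binding hunks later in this diff, and it assumes the usual TensorFlow status macros are in scope.

#include <memory>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/core/platform/errors.h"  // TF_RETURN_IF_ERROR (assumed location)

namespace xla {

// Hypothetical helper, not part of this change: allocates a Literal of the
// expected shape and lets the device write the outfeed value into it.
StatusOr<std::unique_ptr<Literal>> ReadOutfeedValue(const PjRtDevice& device,
                                                    const Shape& shape) {
  auto literal = std::make_unique<Literal>(shape);  // caller-owned destination
  // Literal* converts to MutableBorrowingLiteral; the device fills it in place.
  TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
  return literal;
}

}  // namespace xla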
@@ -910,11 +910,11 @@ Status PjRtStreamExecutorDevice::TransferToInfeed(
       literal, local_device->device_ordinal());
 }
 
-StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
-    const Shape& shape) const {
+Status PjRtStreamExecutorDevice::TransferFromOutfeed(
+    MutableBorrowingLiteral literal) const {
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
   return local_device->client()->TransferFromOutfeedLocal(
-      shape, local_device->device_ordinal());
+      local_device->device_ordinal(), literal);
 }
 
 StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
+#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
@@ -106,7 +107,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
 
   Status TransferToInfeed(const LiteralSlice& literal) const override;
 
-  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
+  Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override;
 
  private:
   const int id_;
@@ -341,11 +341,9 @@ void OutfeedReceiverImpl::EnqueueReceivedData(
 
 StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
     const PjRtDevice* device, const Shape& shape) {
-  std::shared_ptr<Literal> literal_shared;
-
-  TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
-
-  return absl::make_unique<Literal>(std::move(literal));
+  auto literal = std::make_unique<Literal>(shape);
+  TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
+  return literal;
 }
 
 void OutfeedReceiverImpl::CallbackThreadLoop() {
@@ -69,7 +69,7 @@ class TpuDevice : public PjRtDevice {
     return Unimplemented("Infeed not yet implemented via this API");
   }
 
-  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override {
+  Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override {
     return Unimplemented("Outfeed not yet implemented via this API");
   }
 
@@ -146,7 +146,7 @@ PYBIND11_MODULE(xla_extension, m) {
       [](const PjRtDevice& device,
          const Shape& shape) -> StatusOr<py::object> {
         GlobalPyRefManager()->CollectGarbage();
-        std::shared_ptr<Literal> literal_shared;
+        std::shared_ptr<Literal> literal;
         {
           py::gil_scoped_release gil_release;
           Shape shape_with_layout = shape;
@@ -156,12 +156,10 @@ PYBIND11_MODULE(xla_extension, m) {
             LayoutUtil::SetToDefaultLayout(subshape);
           }
         });
-        TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed(
-                                                 shape_with_layout));
-
-        literal_shared = std::make_shared<Literal>(std::move(literal));
+        literal = std::make_shared<Literal>(shape_with_layout);
+        TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
       }
-      return LiteralToPython(std::move(literal_shared));
+      return LiteralToPython(std::move(literal));
     });
 
   py::class_<CpuDevice, PjRtDevice, ClientAndPtr<CpuDevice>>(m, "CpuDevice")
@@ -937,9 +937,9 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
       LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
       local_client_->default_device_ordinal()));
 
-  TF_ASSERT_OK_AND_ASSIGN(Literal result,
-                          local_client_->TransferFromOutfeedLocal(
-                              shape, local_client_->default_device_ordinal()));
+  Literal result(shape);
+  ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal(
+      local_client_->default_device_ordinal(), &result));
 
   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
 }
@@ -92,9 +92,8 @@ void TestWithDeviceCount(const int device_count) {
 
   for (int device_ordinal = 0; device_ordinal < device_count;
        device_ordinal++) {
-    TF_ASSERT_OK_AND_ASSIGN(Literal outfeed,
-                            client->TransferFromOutfeedLocal(
-                                ShapeUtil::MakeShape(S32, {}), device_ordinal));
+    Literal outfeed(ShapeUtil::MakeShape(S32, {}));
+    TF_ASSERT_OK(client->TransferFromOutfeedLocal(device_ordinal, &outfeed));
     EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1));
   }
 
@@ -282,9 +282,9 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
     outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed",
                                 /*num_threads=*/1);
     auto consume_outfeed = [client, outfeed_shape] {
-      TF_CHECK_OK(
-          client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0)
-              .status());
+      Literal outfeed(*outfeed_shape);
+      TF_CHECK_OK(
+          client->TransferFromOutfeedLocal(/*device_ordinal=*/0, &outfeed));
       VLOG(1) << "Received outfeed data of shape "
               << ShapeUtil::HumanStringWithLayout(*outfeed_shape);
     };