[XLA] Drop useless shape argument from TransferManager::TransferLiteralFromOutfeed.

The shape is redundant with the shape in the output literal, and it turned out that some implementations (e.g., TPU) simply ignored this duplicate shape.

Leave the outfeed shape in the TPU C API to maintain backward compatibility.

PiperOrigin-RevId: 353765398
Change-Id: Ide493732676a35c061f81c5c2ce50c0f664c5ccf
This commit is contained in:
Peter Hawkins 2021-01-25 17:24:06 -08:00 committed by TensorFlower Gardener
parent f4a8ef169d
commit 588f5a7f60
15 changed files with 36 additions and 45 deletions

View File

@ -439,7 +439,7 @@ StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
backend().stream_executor(device_ordinal));
auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, shape, &literal));
executor, &literal));
return std::move(literal);
}

View File

@ -176,15 +176,14 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
}
Status CpuTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
MutableBorrowingLiteral literal) {
if (!literal_shape.IsTuple()) {
int64 size = GetByteSizeRequirement(literal_shape);
se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
if (!literal.shape().IsTuple()) {
int64 size = GetByteSizeRequirement(literal.shape());
// Note: OSS build didn't like implicit conversion from
// literal_shape.dimensions() to the array slice on 2017-07-10.
// literal.shape().dimensions() to the array slice on 2017-07-10.
absl::Span<const int64> dimensions(
absl::bit_cast<const int64*>(literal_shape.dimensions().data()),
literal_shape.dimensions().size());
absl::bit_cast<const int64*>(literal.shape().dimensions().data()),
literal.shape().dimensions().size());
TF_ASSIGN_OR_RETURN(
Shape received_shape,
TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size));
@ -192,21 +191,21 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
<< "Shape received from outfeed "
<< ShapeUtil::HumanString(received_shape)
<< " did not match the shape that was requested for outfeed: "
<< ShapeUtil::HumanString(literal_shape);
<< ShapeUtil::HumanString(literal.shape());
TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
*literal.mutable_shape_do_not_use() = received_shape;
return Status::OK();
}
if (ShapeUtil::IsNestedTuple(literal_shape)) {
if (ShapeUtil::IsNestedTuple(literal.shape())) {
return Unimplemented(
"Nested tuple outfeeds are not yet implemented on CPU.");
}
std::vector<std::pair<void*, int64>> buffer_data;
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
for (int64 i = 0; i < literal.shape().tuple_shapes_size(); ++i) {
const Shape& tuple_element_shape =
ShapeUtil::GetTupleElementShape(literal_shape, i);
ShapeUtil::GetTupleElementShape(literal.shape(), i);
int64 size = GetByteSizeRequirement(tuple_element_shape);
buffer_data.push_back({literal.untyped_data({i}), size});
}
@ -214,15 +213,15 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
TF_ASSIGN_OR_RETURN(Shape received_shape,
TransferTupleBuffersFromOutfeed(executor, buffer_data));
TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal_shape))
TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
<< "Shape received from outfeed "
<< ShapeUtil::HumanString(received_shape)
<< " did not match the shape that was requested for outfeed: "
<< ShapeUtil::HumanString(literal_shape);
TF_RET_CHECK(GetByteSizeRequirement(literal_shape) ==
<< ShapeUtil::HumanString(literal.shape());
TF_RET_CHECK(GetByteSizeRequirement(literal.shape()) ==
GetByteSizeRequirement(received_shape));
TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape));
TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal.shape()));
return Status::OK();
}

View File

@ -42,7 +42,6 @@ class CpuTransferManager : public GenericTransferManager {
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
MutableBorrowingLiteral literal) override;
bool CanShapedBufferBeAccessedNow(

View File

@ -145,8 +145,7 @@ Status GenericTransferManager::TransferLiteralToInfeed(
}
Status GenericTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
MutableBorrowingLiteral literal) {
se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
return Unimplemented("Generic transfer from Outfeed");
}

View File

@ -53,7 +53,6 @@ class GenericTransferManager : public TransferManager {
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
MutableBorrowingLiteral literal) override;
Status ResetDevices(absl::Span<se::StreamExecutor* const> executors) override;

View File

@ -114,13 +114,12 @@ StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
}
Status GpuTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* /*executor*/, const Shape& literal_shape,
MutableBorrowingLiteral literal) {
se::StreamExecutor* /*executor*/, MutableBorrowingLiteral literal) {
ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers(
&literal_shape);
&literal.shape());
for (auto& leaf : outfeed_buffers.leaves()) {
const Shape& shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
leaf.second =
absl::make_unique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
@ -135,7 +134,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
// Now wait till all the buffers are written.
for (auto& leaf : outfeed_buffers.leaves()) {
const Shape& shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
leaf.second->WaitUntilAvailable();
}

View File

@ -41,7 +41,6 @@ class GpuTransferManager : public GenericTransferManager {
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
MutableBorrowingLiteral literal) override;
private:

View File

@ -285,9 +285,9 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicatedImpl(
VLOG(1) << "Starting outfeed on device " << device;
for (int64 step = 1;
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
Literal literal;
Literal literal(options.outfeed_shape);
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, options.outfeed_shape, &literal));
executor, &literal));
if (options.outfeed_values != nullptr) {
options.outfeed_values->push_back(std::move(literal));
}

View File

@ -1037,7 +1037,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, Shape(arg->shape_with_layout()), &literal));
executor, &literal));
*result->mutable_literal() = literal.ToProto();
return Status::OK();
}

View File

@ -203,10 +203,10 @@ class TransferManager {
const LiteralSlice& literal) = 0;
// Transfers the given literal from the Outfeed interface of the device,
// using the given executor.
// using the given executor. The shape and layout are determined by the
// shape and layout of `literal`.
virtual Status TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
MutableBorrowingLiteral literal) = 0;
se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0;
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(

View File

@ -573,7 +573,7 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
auto literal = Literal::CreateFromShape(expected.shape());
TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
backend().default_stream_executor(), expected.shape(), &literal));
backend().default_stream_executor(), &literal));
EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}

View File

@ -53,8 +53,8 @@ Status TpuOutfeedDequeueOp<T>::DoWork(
VLOG(1) << "TransferLiteralFromOutfeed "
<< xla::ShapeUtil::HumanStringWithLayout(xla_shape_);
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
stream_executor, xla_shape_, literal));
TF_RETURN_IF_ERROR(
transfer_manager->TransferLiteralFromOutfeed(stream_executor, literal));
VLOG(1) << "TransferLiteralFromOutfeed complete.";
@ -96,8 +96,8 @@ Status TpuOutfeedDequeueTupleOp<T>::DoWork(
xla::MutableBorrowingLiteral literal;
TF_RETURN_IF_ERROR(
HostTensorToMutableBorrowingLiteral(xla_shapes_[i], output, &literal));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
stream_executor, xla_shapes_[i], literal));
TF_RETURN_IF_ERROR(
transfer_manager->TransferLiteralFromOutfeed(stream_executor, literal));
}
return Status::OK();
}

View File

@ -219,11 +219,9 @@ void TpuTransferManager_TransferBuffersToInfeed(XLA_TransferManager* manager,
int64_t* buffers_size_in_uint32,
int64_t buffers_array_size,
TF_Status* status);
void TpuTransferManager_TransferLiteralFromOutfeed(XLA_TransferManager* manager,
SE_StreamExecutor* executor,
XLA_Shape* shape,
XLA_Literal* c_literal,
TF_Status* status);
void TpuTransferManager_TransferLiteralFromOutfeed(
XLA_TransferManager* manager, SE_StreamExecutor* executor,
XLA_Shape* shape /*deprecated*/, XLA_Literal* c_literal, TF_Status* status);
void TpuTransferManager_ResetDevices(XLA_TransferManager* manager,
SE_StreamExecutor** executors,
int64_t num_executors, TF_Status* status);

View File

@ -127,14 +127,14 @@ Status TpuTransferManager::TransferBuffersToInfeed(
}
Status TpuTransferManager::TransferLiteralFromOutfeed(
stream_executor::StreamExecutor* executor, const xla::Shape& literal_shape,
stream_executor::StreamExecutor* executor,
xla::MutableBorrowingLiteral literal) {
StatusHelper status;
XLA_Shape c_shape;
XLA_Literal c_literal;
auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
ApiConverter::ToC(literal_shape, &c_shape);
ApiConverter::ToC(literal.shape(), &c_shape);
ApiConverter::ToC(literal, &c_literal);
tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(

View File

@ -56,7 +56,6 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
Status TransferLiteralFromOutfeed(
stream_executor::StreamExecutor* executor,
const xla::Shape& literal_shape,
xla::MutableBorrowingLiteral literal) override;
Status TransferBuffersToInfeed(