[XLA] Drop useless shape argument from TransferManager::TransferLiteralFromOutfeed.
The shape is redundant with the shape in the output literal, and it turned out that some implementations (e.g., TPU) simply ignored this duplicate shape. The outfeed shape parameter is left in place (marked deprecated) in the TPU C API to maintain backward compatibility.
PiperOrigin-RevId: 353765398
Change-Id: Ide493732676a35c061f81c5c2ce50c0f664c5ccf
This commit is contained in:
parent
f4a8ef169d
commit
588f5a7f60
@ -439,7 +439,7 @@ StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
|
||||
backend().stream_executor(device_ordinal));
|
||||
auto literal = Literal::CreateFromShape(shape);
|
||||
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
|
||||
executor, shape, &literal));
|
||||
executor, &literal));
|
||||
return std::move(literal);
|
||||
}
|
||||
|
||||
|
@ -176,15 +176,14 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
|
||||
}
|
||||
|
||||
Status CpuTransferManager::TransferLiteralFromOutfeed(
|
||||
se::StreamExecutor* executor, const Shape& literal_shape,
|
||||
MutableBorrowingLiteral literal) {
|
||||
if (!literal_shape.IsTuple()) {
|
||||
int64 size = GetByteSizeRequirement(literal_shape);
|
||||
se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
|
||||
if (!literal.shape().IsTuple()) {
|
||||
int64 size = GetByteSizeRequirement(literal.shape());
|
||||
// Note: OSS build didn't like implicit conversion from
|
||||
// literal_shape.dimensions() to the array slice on 2017-07-10.
|
||||
// literal.shape().dimensions() to the array slice on 2017-07-10.
|
||||
absl::Span<const int64> dimensions(
|
||||
absl::bit_cast<const int64*>(literal_shape.dimensions().data()),
|
||||
literal_shape.dimensions().size());
|
||||
absl::bit_cast<const int64*>(literal.shape().dimensions().data()),
|
||||
literal.shape().dimensions().size());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape received_shape,
|
||||
TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size));
|
||||
@ -192,21 +191,21 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
|
||||
<< "Shape received from outfeed "
|
||||
<< ShapeUtil::HumanString(received_shape)
|
||||
<< " did not match the shape that was requested for outfeed: "
|
||||
<< ShapeUtil::HumanString(literal_shape);
|
||||
<< ShapeUtil::HumanString(literal.shape());
|
||||
TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
|
||||
*literal.mutable_shape_do_not_use() = received_shape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ShapeUtil::IsNestedTuple(literal_shape)) {
|
||||
if (ShapeUtil::IsNestedTuple(literal.shape())) {
|
||||
return Unimplemented(
|
||||
"Nested tuple outfeeds are not yet implemented on CPU.");
|
||||
}
|
||||
|
||||
std::vector<std::pair<void*, int64>> buffer_data;
|
||||
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
|
||||
for (int64 i = 0; i < literal.shape().tuple_shapes_size(); ++i) {
|
||||
const Shape& tuple_element_shape =
|
||||
ShapeUtil::GetTupleElementShape(literal_shape, i);
|
||||
ShapeUtil::GetTupleElementShape(literal.shape(), i);
|
||||
int64 size = GetByteSizeRequirement(tuple_element_shape);
|
||||
buffer_data.push_back({literal.untyped_data({i}), size});
|
||||
}
|
||||
@ -214,15 +213,15 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
|
||||
TF_ASSIGN_OR_RETURN(Shape received_shape,
|
||||
TransferTupleBuffersFromOutfeed(executor, buffer_data));
|
||||
|
||||
TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal_shape))
|
||||
TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
|
||||
<< "Shape received from outfeed "
|
||||
<< ShapeUtil::HumanString(received_shape)
|
||||
<< " did not match the shape that was requested for outfeed: "
|
||||
<< ShapeUtil::HumanString(literal_shape);
|
||||
TF_RET_CHECK(GetByteSizeRequirement(literal_shape) ==
|
||||
<< ShapeUtil::HumanString(literal.shape());
|
||||
TF_RET_CHECK(GetByteSizeRequirement(literal.shape()) ==
|
||||
GetByteSizeRequirement(received_shape));
|
||||
|
||||
TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape));
|
||||
TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal.shape()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -42,7 +42,6 @@ class CpuTransferManager : public GenericTransferManager {
|
||||
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
||||
const LiteralSlice& literal) override;
|
||||
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
|
||||
const Shape& literal_shape,
|
||||
MutableBorrowingLiteral literal) override;
|
||||
|
||||
bool CanShapedBufferBeAccessedNow(
|
||||
|
@ -145,8 +145,7 @@ Status GenericTransferManager::TransferLiteralToInfeed(
|
||||
}
|
||||
|
||||
Status GenericTransferManager::TransferLiteralFromOutfeed(
|
||||
se::StreamExecutor* executor, const Shape& literal_shape,
|
||||
MutableBorrowingLiteral literal) {
|
||||
se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
|
||||
return Unimplemented("Generic transfer from Outfeed");
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,6 @@ class GenericTransferManager : public TransferManager {
|
||||
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
||||
const LiteralSlice& literal) override;
|
||||
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
|
||||
const Shape& literal_shape,
|
||||
MutableBorrowingLiteral literal) override;
|
||||
|
||||
Status ResetDevices(absl::Span<se::StreamExecutor* const> executors) override;
|
||||
|
@ -114,13 +114,12 @@ StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
|
||||
}
|
||||
|
||||
Status GpuTransferManager::TransferLiteralFromOutfeed(
|
||||
se::StreamExecutor* /*executor*/, const Shape& literal_shape,
|
||||
MutableBorrowingLiteral literal) {
|
||||
se::StreamExecutor* /*executor*/, MutableBorrowingLiteral literal) {
|
||||
ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers(
|
||||
&literal_shape);
|
||||
&literal.shape());
|
||||
|
||||
for (auto& leaf : outfeed_buffers.leaves()) {
|
||||
const Shape& shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
|
||||
const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
|
||||
CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
|
||||
leaf.second =
|
||||
absl::make_unique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
|
||||
@ -135,7 +134,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
|
||||
|
||||
// Now wait till all the buffers are written.
|
||||
for (auto& leaf : outfeed_buffers.leaves()) {
|
||||
const Shape& shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
|
||||
const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
|
||||
CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
|
||||
leaf.second->WaitUntilAvailable();
|
||||
}
|
||||
|
@ -41,7 +41,6 @@ class GpuTransferManager : public GenericTransferManager {
|
||||
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
||||
const LiteralSlice& literal) override;
|
||||
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
|
||||
const Shape& literal_shape,
|
||||
MutableBorrowingLiteral literal) override;
|
||||
|
||||
private:
|
||||
|
@ -285,9 +285,9 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicatedImpl(
|
||||
VLOG(1) << "Starting outfeed on device " << device;
|
||||
for (int64 step = 1;
|
||||
options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
|
||||
Literal literal;
|
||||
Literal literal(options.outfeed_shape);
|
||||
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
|
||||
executor, options.outfeed_shape, &literal));
|
||||
executor, &literal));
|
||||
if (options.outfeed_values != nullptr) {
|
||||
options.outfeed_values->push_back(std::move(literal));
|
||||
}
|
||||
|
@ -1037,7 +1037,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
|
||||
executor, Shape(arg->shape_with_layout()), &literal));
|
||||
executor, &literal));
|
||||
*result->mutable_literal() = literal.ToProto();
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -203,10 +203,10 @@ class TransferManager {
|
||||
const LiteralSlice& literal) = 0;
|
||||
|
||||
// Transfers the given literal from the Outfeed interface of the device,
|
||||
// using the given executor.
|
||||
// using the given executor. The shape and layout are determined by the
|
||||
// shape and layout of `literal`.
|
||||
virtual Status TransferLiteralFromOutfeed(
|
||||
se::StreamExecutor* executor, const Shape& literal_shape,
|
||||
MutableBorrowingLiteral literal) = 0;
|
||||
se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0;
|
||||
|
||||
// Resets the devices associated with this transfer manager.
|
||||
virtual Status ResetDevices(
|
||||
|
@ -573,7 +573,7 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
|
||||
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
|
||||
auto literal = Literal::CreateFromShape(expected.shape());
|
||||
TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
|
||||
backend().default_stream_executor(), expected.shape(), &literal));
|
||||
backend().default_stream_executor(), &literal));
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
|
||||
}
|
||||
|
||||
|
@ -53,8 +53,8 @@ Status TpuOutfeedDequeueOp<T>::DoWork(
|
||||
VLOG(1) << "TransferLiteralFromOutfeed "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(xla_shape_);
|
||||
|
||||
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
|
||||
stream_executor, xla_shape_, literal));
|
||||
TF_RETURN_IF_ERROR(
|
||||
transfer_manager->TransferLiteralFromOutfeed(stream_executor, literal));
|
||||
|
||||
VLOG(1) << "TransferLiteralFromOutfeed complete.";
|
||||
|
||||
@ -96,8 +96,8 @@ Status TpuOutfeedDequeueTupleOp<T>::DoWork(
|
||||
xla::MutableBorrowingLiteral literal;
|
||||
TF_RETURN_IF_ERROR(
|
||||
HostTensorToMutableBorrowingLiteral(xla_shapes_[i], output, &literal));
|
||||
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
|
||||
stream_executor, xla_shapes_[i], literal));
|
||||
TF_RETURN_IF_ERROR(
|
||||
transfer_manager->TransferLiteralFromOutfeed(stream_executor, literal));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -219,11 +219,9 @@ void TpuTransferManager_TransferBuffersToInfeed(XLA_TransferManager* manager,
|
||||
int64_t* buffers_size_in_uint32,
|
||||
int64_t buffers_array_size,
|
||||
TF_Status* status);
|
||||
void TpuTransferManager_TransferLiteralFromOutfeed(XLA_TransferManager* manager,
|
||||
SE_StreamExecutor* executor,
|
||||
XLA_Shape* shape,
|
||||
XLA_Literal* c_literal,
|
||||
TF_Status* status);
|
||||
void TpuTransferManager_TransferLiteralFromOutfeed(
|
||||
XLA_TransferManager* manager, SE_StreamExecutor* executor,
|
||||
XLA_Shape* shape /*deprecated*/, XLA_Literal* c_literal, TF_Status* status);
|
||||
void TpuTransferManager_ResetDevices(XLA_TransferManager* manager,
|
||||
SE_StreamExecutor** executors,
|
||||
int64_t num_executors, TF_Status* status);
|
||||
|
@ -127,14 +127,14 @@ Status TpuTransferManager::TransferBuffersToInfeed(
|
||||
}
|
||||
|
||||
Status TpuTransferManager::TransferLiteralFromOutfeed(
|
||||
stream_executor::StreamExecutor* executor, const xla::Shape& literal_shape,
|
||||
stream_executor::StreamExecutor* executor,
|
||||
xla::MutableBorrowingLiteral literal) {
|
||||
StatusHelper status;
|
||||
XLA_Shape c_shape;
|
||||
XLA_Literal c_literal;
|
||||
auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
|
||||
|
||||
ApiConverter::ToC(literal_shape, &c_shape);
|
||||
ApiConverter::ToC(literal.shape(), &c_shape);
|
||||
ApiConverter::ToC(literal, &c_literal);
|
||||
|
||||
tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(
|
||||
|
@ -56,7 +56,6 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
|
||||
|
||||
Status TransferLiteralFromOutfeed(
|
||||
stream_executor::StreamExecutor* executor,
|
||||
const xla::Shape& literal_shape,
|
||||
xla::MutableBorrowingLiteral literal) override;
|
||||
|
||||
Status TransferBuffersToInfeed(
|
||||
|
Loading…
x
Reference in New Issue
Block a user