diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index c56b8f50906..1b2a19ba2aa 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -439,7 +439,7 @@ StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
                       backend().stream_executor(device_ordinal));
   auto literal = Literal::CreateFromShape(shape);
   TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
-      executor, shape, &literal));
+      executor, &literal));
   return std::move(literal);
 }
 
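
The change above is the caller-side pattern for the whole series: the destination literal is allocated with the requested shape up front, and TransferLiteralFromOutfeed now reads the shape (and layout) from the literal instead of taking it as a separate argument. A minimal sketch of that pattern, as it would appear inside a Status-returning function, assuming the caller already has a `shape`, a `transfer_manager`, and an `executor` (those names are illustrative, not part of this change):

    // Allocate a destination literal of the desired shape; the transfer
    // manager derives everything it needs from literal.shape().
    xla::Literal literal = xla::Literal::CreateFromShape(shape);
    TF_RETURN_IF_ERROR(
        transfer_manager->TransferLiteralFromOutfeed(executor, &literal));
    // `literal` now holds the outfed data in the requested shape and layout.
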
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index bfd8e9e111a..f2a7044fc54 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -176,15 +176,14 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
 }
 
 Status CpuTransferManager::TransferLiteralFromOutfeed(
-    se::StreamExecutor* executor, const Shape& literal_shape,
-    MutableBorrowingLiteral literal) {
-  if (!literal_shape.IsTuple()) {
-    int64 size = GetByteSizeRequirement(literal_shape);
+    se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
+  if (!literal.shape().IsTuple()) {
+    int64 size = GetByteSizeRequirement(literal.shape());
     // Note: OSS build didn't like implicit conversion from
-    // literal_shape.dimensions() to the array slice on 2017-07-10.
+    // literal.shape().dimensions() to the array slice on 2017-07-10.
     absl::Span<const int64> dimensions(
-        absl::bit_cast<const int64*>(literal_shape.dimensions().data()),
-        literal_shape.dimensions().size());
+        absl::bit_cast<const int64*>(literal.shape().dimensions().data()),
+        literal.shape().dimensions().size());
     TF_ASSIGN_OR_RETURN(
         Shape received_shape,
         TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size));
@@ -192,21 +191,21 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
         << "Shape received from outfeed "
         << ShapeUtil::HumanString(received_shape)
         << " did not match the shape that was requested for outfeed: "
-        << ShapeUtil::HumanString(literal_shape);
+        << ShapeUtil::HumanString(literal.shape());
     TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
     *literal.mutable_shape_do_not_use() = received_shape;
     return Status::OK();
   }
 
-  if (ShapeUtil::IsNestedTuple(literal_shape)) {
+  if (ShapeUtil::IsNestedTuple(literal.shape())) {
     return Unimplemented(
         "Nested tuple outfeeds are not yet implemented on CPU.");
   }
 
   std::vector<std::pair<void*, int64>> buffer_data;
-  for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
+  for (int64 i = 0; i < literal.shape().tuple_shapes_size(); ++i) {
     const Shape& tuple_element_shape =
-        ShapeUtil::GetTupleElementShape(literal_shape, i);
+        ShapeUtil::GetTupleElementShape(literal.shape(), i);
     int64 size = GetByteSizeRequirement(tuple_element_shape);
     buffer_data.push_back({literal.untyped_data({i}), size});
   }
@@ -214,15 +213,14 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
 
   TF_ASSIGN_OR_RETURN(Shape received_shape,
                       TransferTupleBuffersFromOutfeed(executor, buffer_data));
 
-  TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal_shape))
+  TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
      << "Shape received from outfeed "
      << ShapeUtil::HumanString(received_shape)
      << " did not match the shape that was requested for outfeed: "
-      << ShapeUtil::HumanString(literal_shape);
-  TF_RET_CHECK(GetByteSizeRequirement(literal_shape) ==
+      << ShapeUtil::HumanString(literal.shape());
+  TF_RET_CHECK(GetByteSizeRequirement(literal.shape()) ==
                GetByteSizeRequirement(received_shape));
-  TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape));
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
index 43d2e0a3cab..faf1561ea86 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
@@ -42,7 +42,6 @@ class CpuTransferManager : public GenericTransferManager {
   Status TransferLiteralToInfeed(se::StreamExecutor* executor,
                                  const LiteralSlice& literal) override;
   Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
-                                    const Shape& literal_shape,
                                     MutableBorrowingLiteral literal) override;
 
   bool CanShapedBufferBeAccessedNow(
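
For the flat-tuple case handled above, the per-element buffers come straight out of the destination literal (`literal.untyped_data({i})`), so a caller only has to supply a tuple-shaped literal. A hedged sketch of such a caller; the element shapes and the `transfer_manager`/`executor` handles are illustrative only:

    // Illustrative flat tuple of two arrays; nested tuples are rejected above.
    xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
        {xla::ShapeUtil::MakeShape(xla::F32, {128}),
         xla::ShapeUtil::MakeShape(xla::S32, {4, 4})});
    xla::Literal literal = xla::Literal::CreateFromShape(tuple_shape);
    TF_RETURN_IF_ERROR(
        transfer_manager->TransferLiteralFromOutfeed(executor, &literal));
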
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index d2febb5fb73..c051bd11006 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -145,8 +145,7 @@ Status GenericTransferManager::TransferLiteralToInfeed(
 }
 
 Status GenericTransferManager::TransferLiteralFromOutfeed(
-    se::StreamExecutor* executor, const Shape& literal_shape,
-    MutableBorrowingLiteral literal) {
+    se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
   return Unimplemented("Generic transfer from Outfeed");
 }
 
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 9cc344be06c..79fa6ac67e0 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -53,7 +53,6 @@ class GenericTransferManager : public TransferManager {
   Status TransferLiteralToInfeed(se::StreamExecutor* executor,
                                  const LiteralSlice& literal) override;
   Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
-                                    const Shape& literal_shape,
                                     MutableBorrowingLiteral literal) override;
 
   Status ResetDevices(absl::Span<se::StreamExecutor* const> executors) override;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index a1be8cfb07c..436a45848fa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -114,13 +114,12 @@ StatusOr GpuTransferManager::TransferBufferToInfeedInternal(
 }
 
 Status GpuTransferManager::TransferLiteralFromOutfeed(
-    se::StreamExecutor* /*executor*/, const Shape& literal_shape,
-    MutableBorrowingLiteral literal) {
+    se::StreamExecutor* /*executor*/, MutableBorrowingLiteral literal) {
   ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers(
-      &literal_shape);
+      &literal.shape());
 
   for (auto& leaf : outfeed_buffers.leaves()) {
-    const Shape& shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
+    const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
     CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
     leaf.second =
         absl::make_unique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
@@ -135,7 +134,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
 
   // Now wait till all the buffers are written.
   for (auto& leaf : outfeed_buffers.leaves()) {
-    const Shape& shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
+    const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
     CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
     leaf.second->WaitUntilAvailable();
   }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index fa88816bc8b..acc301feddc 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -41,7 +41,6 @@ class GpuTransferManager : public GenericTransferManager {
   Status TransferLiteralToInfeed(se::StreamExecutor* executor,
                                  const LiteralSlice& literal) override;
   Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
-                                    const Shape& literal_shape,
                                     MutableBorrowingLiteral literal) override;
 
  private:
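
For backends outside this patch, the override simply loses one parameter; the shape that used to arrive as `literal_shape` is read from the literal. A sketch of what an out-of-tree subclass would look like under the new signature; the class name is hypothetical and the body is a placeholder, not a real backend:

    // Hypothetical backend, shown only to illustrate the narrowed override.
    class MyTransferManager : public xla::GenericTransferManager {
     public:
      using xla::GenericTransferManager::GenericTransferManager;  // inherit ctor

      xla::Status TransferLiteralFromOutfeed(
          stream_executor::StreamExecutor* executor,
          xla::MutableBorrowingLiteral literal) override {
        const xla::Shape& shape = literal.shape();  // formerly `literal_shape`
        // ... backend-specific dequeue into literal.untyped_data() ...
        return xla::Status::OK();
      }
    };
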
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 63736a899b6..988a0572589 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -285,9 +285,9 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicatedImpl(
       VLOG(1) << "Starting outfeed on device " << device;
       for (int64 step = 1;
            options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
-        Literal literal;
+        Literal literal(options.outfeed_shape);
         TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
-            executor, options.outfeed_shape, &literal));
+            executor, &literal));
         if (options.outfeed_values != nullptr) {
           options.outfeed_values->push_back(std::move(literal));
         }
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index c5d4d0415c9..275c82d4e8c 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1037,7 +1037,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
 
   TF_RETURN_IF_ERROR(
       execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
-          executor, Shape(arg->shape_with_layout()), &literal));
+          executor, &literal));
   *result->mutable_literal() = literal.ToProto();
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index d7636c30c36..67a61b0145d 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -203,10 +203,10 @@ class TransferManager {
                                          const LiteralSlice& literal) = 0;
 
   // Transfers the given literal from the Outfeed interface of the device,
-  // using the given executor.
+  // using the given executor. The shape and layout of the transferred data
+  // are determined by the shape and layout of `literal`.
   virtual Status TransferLiteralFromOutfeed(
-      se::StreamExecutor* executor, const Shape& literal_shape,
-      MutableBorrowingLiteral literal) = 0;
+      se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0;
 
   // Resets the devices associated with this transfer manager.
   virtual Status ResetDevices(
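
Since the interface comment above now ties the transfer to `literal`'s shape and layout, a caller that wants a specific layout just builds the destination literal from a layout-bearing shape. A sketch under that assumption; the dimensions, layout, and `transfer_manager`/`executor` handles are illustrative:

    // Request an f32[2,3] result with minor-to-major {0,1} by baking the
    // layout into the shape the literal is constructed from.
    xla::Shape shape_with_layout = xla::ShapeUtil::MakeShapeWithLayout(
        xla::F32, /*dimensions=*/{2, 3}, /*minor_to_major=*/{0, 1});
    xla::Literal literal(shape_with_layout);
    TF_RETURN_IF_ERROR(
        transfer_manager->TransferLiteralFromOutfeed(executor, &literal));
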
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index d1e9a9c7aa2..0ce4b105d01 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -573,7 +573,7 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
       LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3}));
   auto literal = Literal::CreateFromShape(expected.shape());
   TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
-      backend().default_stream_executor(), expected.shape(), &literal));
+      backend().default_stream_executor(), &literal));
   EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
 }
 
diff --git a/tensorflow/core/tpu/kernels/outfeed_ops.cc b/tensorflow/core/tpu/kernels/outfeed_ops.cc
index bc9a9d14db9..a0fea5e2e56 100644
--- a/tensorflow/core/tpu/kernels/outfeed_ops.cc
+++ b/tensorflow/core/tpu/kernels/outfeed_ops.cc
@@ -53,8 +53,8 @@ Status TpuOutfeedDequeueOp::DoWork(
   VLOG(1) << "TransferLiteralFromOutfeed "
           << xla::ShapeUtil::HumanStringWithLayout(xla_shape_);
 
-  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
-      stream_executor, xla_shape_, literal));
+  TF_RETURN_IF_ERROR(
+      transfer_manager->TransferLiteralFromOutfeed(stream_executor, literal));
 
   VLOG(1) << "TransferLiteralFromOutfeed complete.";
 
@@ -96,8 +96,8 @@ Status TpuOutfeedDequeueTupleOp::DoWork(
     xla::MutableBorrowingLiteral literal;
     TF_RETURN_IF_ERROR(
        HostTensorToMutableBorrowingLiteral(xla_shapes_[i], output, &literal));
-    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralFromOutfeed(
-        stream_executor, xla_shapes_[i], literal));
+    TF_RETURN_IF_ERROR(
+        transfer_manager->TransferLiteralFromOutfeed(stream_executor, literal));
   }
   return Status::OK();
 }
diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
index fc0645f8298..a7a02b709bd 100644
--- a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
+++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h
@@ -219,11 +219,9 @@ void TpuTransferManager_TransferBuffersToInfeed(XLA_TransferManager* manager,
                                                 int64_t* buffers_size_in_uint32,
                                                 int64_t buffers_array_size,
                                                 TF_Status* status);
-void TpuTransferManager_TransferLiteralFromOutfeed(XLA_TransferManager* manager,
-                                                   SE_StreamExecutor* executor,
-                                                   XLA_Shape* shape,
-                                                   XLA_Literal* c_literal,
-                                                   TF_Status* status);
+void TpuTransferManager_TransferLiteralFromOutfeed(
+    XLA_TransferManager* manager, SE_StreamExecutor* executor,
+    XLA_Shape* shape /*deprecated*/, XLA_Literal* c_literal, TF_Status* status);
 void TpuTransferManager_ResetDevices(XLA_TransferManager* manager,
                                      SE_StreamExecutor** executors,
                                      int64_t num_executors, TF_Status* status);
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc
index 1790c24a9f6..dd516cf58aa 100644
--- a/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc
@@ -127,14 +127,14 @@ Status TpuTransferManager::TransferBuffersToInfeed(
 }
 
 Status TpuTransferManager::TransferLiteralFromOutfeed(
-    stream_executor::StreamExecutor* executor, const xla::Shape& literal_shape,
+    stream_executor::StreamExecutor* executor,
     xla::MutableBorrowingLiteral literal) {
   StatusHelper status;
   XLA_Shape c_shape;
   XLA_Literal c_literal;
   auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
 
-  ApiConverter::ToC(literal_shape, &c_shape);
+  ApiConverter::ToC(literal.shape(), &c_shape);
   ApiConverter::ToC(literal, &c_literal);
 
   tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.h b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
index cef9b90c47a..3e1425b6cea 100644
--- a/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h
@@ -56,7 +56,6 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
 
   Status TransferLiteralFromOutfeed(
       stream_executor::StreamExecutor* executor,
-      const xla::Shape& literal_shape,
       xla::MutableBorrowingLiteral literal) override;
 
   Status TransferBuffersToInfeed(
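
Tying the TPU pieces together, the per-output pattern in outfeed_ops.cc above stays the same apart from the dropped shape argument: wrap the host tensor as a mutable borrowing literal, then hand only the literal to the transfer manager. A sketch, inside a Status-returning op body; `xla_shape` and `output_tensor` mirror the op code above and are otherwise illustrative:

    // One output tensor of the dequeue op; the XLA shape is already known and
    // becomes literal.shape() via the borrowing wrapper.
    xla::MutableBorrowingLiteral literal;
    TF_RETURN_IF_ERROR(
        HostTensorToMutableBorrowingLiteral(xla_shape, output_tensor, &literal));
    TF_RETURN_IF_ERROR(
        transfer_manager->TransferLiteralFromOutfeed(stream_executor, literal));
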