[XLA] Don't pass on_host_shape to ShapedBuffer/ScopedShapedBuffer inside XLA.

PiperOrigin-RevId: 336133292
Change-Id: I47a6fa5a5f2c6a460bdaeb1acc5125ff20710230
Authored by Peter Hawkins on 2020-10-08 11:54:04 -07:00; committed by TensorFlower Gardener
parent 17dc4b07a0
commit e7f6b0c7ee
15 changed files with 54 additions and 73 deletions
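For context, a minimal sketch of what this change looks like at a call site. The helper name and variables below are illustrative, not part of this commit; only the constructor signatures match the diff. The on-host shape is no longer passed explicitly and is instead derived internally via ShapeUtil::DeviceShapeToHostShape.

// Hypothetical call-site sketch (not from this commit).
#include "tensorflow/compiler/xla/service/shaped_buffer.h"

xla::ScopedShapedBuffer MakeResultBuffer(
    const xla::Shape& on_device_shape,
    stream_executor::DeviceMemoryAllocator* allocator, int device_ordinal) {
  // Before: ScopedShapedBuffer(on_host_shape, on_device_shape, allocator,
  //                            device_ordinal);
  // After: only the on-device shape is passed; the host shape is computed
  // inside ShapedBuffer's constructor.
  return xla::ScopedShapedBuffer(on_device_shape, allocator, device_ordinal);
}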


@@ -267,9 +267,9 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
 }
 
 static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
-    Shape const& on_host_shape, const ShapeTree<MaybeOwningDeviceMemory>& tree,
-    se::Platform* platform, int device_ordinal) {
-  ShapedBuffer result(on_host_shape, tree.shape(), platform, device_ordinal);
+    const ShapeTree<MaybeOwningDeviceMemory>& tree, se::Platform* platform,
+    int device_ordinal) {
+  ShapedBuffer result(tree.shape(), platform, device_ordinal);
   auto it = tree.begin();
   auto out_it = result.buffers().begin();
   for (; it != tree.end(); ++it, ++out_it) {
@@ -299,8 +299,8 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
   shaped_buffer_ptrs.reserve(arguments.size());
   for (size_t i = 0; i < arguments.size(); ++i) {
     shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
-        *argument_host_shapes[i], arguments[i].Buffers(),
-        backend_->platform(), stream->parent()->device_ordinal()));
+        arguments[i].Buffers(), backend_->platform(),
+        stream->parent()->device_ordinal()));
     shaped_buffer_ptrs.push_back(&shaped_buffers.back());
   }


@@ -143,13 +143,10 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
   // We only need to care about replica id 0 here, since the GlobalDataHandle is
   // the same for all buffers across replicas.
   const ShapedBuffer* shaped_buffer = replicated_buffers[0];
-  if (!shaped_buffer->on_host_shape().IsTuple()) {
+  if (!shaped_buffer->on_device_shape().IsTuple()) {
     return InvalidArgument("global data handle %d is not a tuple",
                            data.handle());
   }
-  // If the on-host representation is a tuple, then the on-device one should be
-  // as well.
-  TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple());
 
   if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
     return Unimplemented("Deconstructing nested tuples is not implemented.");
@@ -160,7 +157,6 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
        i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
        ++i) {
     auto element_buffer = ShapedBuffer(
-        ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
         ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
         shaped_buffer->platform(), shaped_buffer->device_ordinal());
     element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),


@@ -210,8 +210,7 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
     absl::Span<MaybeOwningDeviceMemory> buffers,
     absl::Span<ExecutionInput> arguments) {
   se::Stream* stream = run_options->stream();
-  ExecutionOutput result(/*on_host_shape=*/result_shape(),
-                         /*on_device_shape=*/result_shape(),
+  ExecutionOutput result(/*on_device_shape=*/result_shape(),
                          run_options->allocator(),
                          stream->parent()->device_ordinal());
   const HloInputOutputAliasConfig& input_output_alias =


@@ -59,11 +59,11 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index,
   unowned_indices_.insert(index);
 }
 
-xla::StatusOr<xla::ShapedBuffer> ExecutionInput::ToShapedBuffer(
+StatusOr<ShapedBuffer> ExecutionInput::ToShapedBuffer(
     se::DeviceMemoryAllocator* allocator, int device_ordinal) const {
   const Shape& input_shape = shape();
-  xla::ShapedBuffer shaped_buffer(input_shape, input_shape,
-                                  allocator->platform(), device_ordinal);
+  ShapedBuffer shaped_buffer(input_shape, allocator->platform(),
+                             device_ordinal);
   for (const auto& index_buffer : Buffers()) {
     const tensorflow::se::OwningDeviceMemory* mem =
         index_buffer.second.AsOwningDeviceMemory();
@@ -93,8 +93,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
 
 static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
     const ShapedBuffer& shaped_buffer) {
-  ExecutionInput result(shaped_buffer.on_device_shape(),
-                        shaped_buffer.on_host_shape());
+  ExecutionInput result(shaped_buffer.on_device_shape());
   shaped_buffer.buffers().ForEachElement(
       [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
         result.SetBuffer(index, MaybeOwningDeviceMemory(mem));


@@ -153,13 +153,13 @@ class ExecutionOutput {
                   std::vector<se::OwningDeviceMemory> to_be_released)
       : result_(std::move(result)),
         to_be_released_(std::move(to_be_released)) {}
+  // TODO(b/170310047): remove this overload.
+  ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
+                  se::DeviceMemoryAllocator* allocator, int device_ordinal)
+      : result_(std::move(on_device_shape), allocator, device_ordinal) {}
   ExecutionOutput(Shape on_device_shape, se::DeviceMemoryAllocator* allocator,
                   int device_ordinal)
       : result_(std::move(on_device_shape), allocator, device_ordinal) {}
-  ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
-                  se::DeviceMemoryAllocator* allocator, int device_ordinal)
-      : result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
-                device_ordinal) {}
 
   ExecutionOutput(ExecutionOutput&&) = default;
   ExecutionOutput& operator=(ExecutionOutput&&) = default;


@@ -69,13 +69,8 @@ void GenericTransferManager::TransferLiteralFromDevice(
   TF_RET_CHECK(stream->parent()->device_ordinal() ==
                device_buffer.device_ordinal());
 
-  // The on-host and on-device shape should always be the same for the generic
-  // transfer manager.
-  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()));
-
   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
-      device_buffer.on_host_shape(),
+      device_buffer.on_device_shape(),
       [&](const Shape& subshape, const ShapeIndex& index) -> Status {
         if (subshape.IsArray()) {
           stream->ThenMemcpy(
@@ -103,21 +98,15 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
           << ShapeUtil::HumanString(shape)
           << "; device buffer: " << device_buffer;
 
-  // The on-host and on-device shape should always be the same for the generic
-  // transfer manager.
-  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()))
-      << device_buffer.ToString();
-
   TF_RET_CHECK(
-      ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
+      ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape()));
   TF_RET_CHECK(stream->parent()->device_ordinal() ==
                device_buffer.device_ordinal());
 
   TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
 
   return ShapeUtil::ForEachSubshapeWithStatus(
-      device_buffer.on_host_shape(),
+      device_buffer.on_device_shape(),
       [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
         se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
         if (device_subshape.IsArray()) {


@@ -450,8 +450,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
   HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
   const Shape& root_shape = root->shape();
   auto device_ordinal = executor->device_ordinal();
-  ExecutionOutput result(/*on_host_shape=*/root->shape(),
-                         /*on_device_shape=*/root->shape(), memory_allocator,
+  ExecutionOutput result(/*on_device_shape=*/root->shape(), memory_allocator,
                          device_ordinal);
 
   TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations,


@@ -211,8 +211,7 @@ static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
             *buffer_tree.mutable_element(index) = execution_input_buffer;
           }
         });
-    execution_inputs.emplace_back(std::move(buffer_tree),
-                                  input_buffer.on_host_shape());
+    execution_inputs.emplace_back(std::move(buffer_tree));
   }
   return execution_inputs;
 }


@@ -56,7 +56,7 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
   }
   for (auto& argument : arguments) {
     const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
-    argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
+    argument_buffers.push_back(ShapedBuffer(buffers.shape(),
                                             /*platform=*/nullptr,
                                             /*device_ordinal=*/device_ordinal));
     auto in_it = buffers.begin();


@@ -31,10 +31,6 @@ limitations under the License.
 
 namespace xla {
 
-ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
-                           const se::Platform* platform, int device_ordinal)
-    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
-
 ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
                            int device_ordinal)
     : on_device_shape_(std::move(on_device_shape)),
@@ -44,6 +40,10 @@ ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
   on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
 }
 
+ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
+                           const se::Platform* platform, int device_ordinal)
+    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
+
 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
     : on_host_shape_(std::move(s.on_host_shape_)),
       on_device_shape_(std::move(s.on_device_shape_)),
@@ -90,10 +90,9 @@ void ShapedBuffer::clear() {
 }
 
 string ShapedBuffer::ToString() const {
-  string s = absl::StrCat(
-      "ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
-      "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
-      ", on-device shape=" +
+  string s =
+      absl::StrCat("ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
+                   "), on-device shape=" +
                        ShapeUtil::HumanStringWithLayout(on_device_shape()),
                    ":\n");
   ShapeUtil::ForEachSubshape(
@@ -118,14 +117,6 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
   return out;
 }
 
-ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
-                                       Shape on_device_shape,
-                                       se::DeviceMemoryAllocator* allocator,
-                                       int device_ordinal)
-    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
-                   device_ordinal),
-      allocator_(allocator) {}
-
 ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
                                        se::DeviceMemoryAllocator* allocator,
                                        int device_ordinal)
@@ -133,6 +124,13 @@ ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
                    device_ordinal),
       allocator_(allocator) {}
 
+ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
+                                       Shape on_device_shape,
+                                       se::DeviceMemoryAllocator* allocator,
+                                       int device_ordinal)
+    : ScopedShapedBuffer(std::move(on_device_shape), allocator,
+                         device_ordinal) {}
+
 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                                        se::DeviceMemoryAllocator* allocator)
     : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}


@@ -45,6 +45,7 @@ class ShapedBuffer {
   // ShapedBuffer.
   ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
                int device_ordinal);
+
   // TODO(b/170310047): remove this overload.
   ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                const se::Platform* platform, int device_ordinal);
@@ -100,7 +101,7 @@ class ShapedBuffer {
   // Reset the shape of this shaped buffer and underlying buffer structure.
   //
   // Precondition: EqualStructure(this->on_device_shape_, on_device_shape).
-  void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
+  void set_shapes(const Shape& on_device_shape) {
     CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
         << "Structures are not the same. new: " << on_device_shape
         << ", old: " << on_device_shape_;
@@ -108,6 +109,10 @@ class ShapedBuffer {
     on_device_shape_ = on_device_shape;
     buffers_.replace_shape_ptr(&on_device_shape_);
   }
+  // TODO(b/170310047): remove this overload.
+  void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
+    set_shapes(on_device_shape);
+  }
 
   // Returns the underlying ShapeTree containing all the device addresses in the
   // ShapedBuffer.
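To illustrate the narrowed set_shapes() API above, here is a hypothetical usage sketch; the function and variable names are assumptions, not code from this commit. The new shape must keep the same tuple structure as the existing on-device shape, which the CHECK inside set_shapes enforces.

// Hypothetical usage sketch (not from this commit).
void UpdateDeviceShape(xla::ShapedBuffer* buffer,
                       const xla::Shape& updated_device_shape) {
  // New single-shape form: only the on-device shape is supplied.
  buffer->set_shapes(updated_device_shape);
  // The deprecated two-shape overload still compiles during the migration, but
  // it simply forwards to the single-shape form and ignores the host shape.
}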


@@ -97,12 +97,12 @@ class TestAllocator : public se::DeviceMemoryAllocator {
 TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) {
   Shape s = ShapeUtil::MakeShape(F32, {1});
   TestAllocator allocator;
-  ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0);
+  ScopedShapedBuffer sb1(s, &allocator, /*device_ordinal=*/0);
   sb1.set_buffer(
       allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(),
       /*index=*/{});
-  ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1);
+  ScopedShapedBuffer sb2(s, &allocator, /*device_ordinal=*/1);
   sb2.set_buffer(
       allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(),
       /*index=*/{});
@@ -119,7 +119,7 @@ TEST(ScopedShapedBufferTest, TestTakeSubTree) {
   s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(2, s));
   s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(3, s));
-  ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0);
+  ScopedShapedBuffer sb(s, &allocator, /*device_ordinal=*/0);
   sb.buffers().ForEachMutableElement(
       [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
         TF_ASSERT_OK_AND_ASSIGN(
@@ -156,8 +156,7 @@ TEST(ScopedShapedBufferTest, TestSubShapeTree) {
   Shape tuple_shape =
       xla::ShapeUtil::MakeTupleShape({array_shape, array_shape});
   TestAllocator allocator;
-  ScopedShapedBuffer sb(tuple_shape, tuple_shape, &allocator,
-                        /*device_ordinal=*/0);
+  ScopedShapedBuffer sb(tuple_shape, &allocator, /*device_ordinal=*/0);
   sb.buffers().ForEachMutableElement(
       [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
         TF_ASSERT_OK_AND_ASSIGN(
@@ -182,7 +181,7 @@ void BM_TakeSubTree(int iters, int depth, int fan_out) {
     std::vector<xla::Shape> shapes(fan_out, shape);
     shape = xla::ShapeUtil::MakeTupleShape(shapes);
   }
-  xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator,
+  xla::ScopedShapedBuffer shaped_buffer(shape, /*allocator=*/&allocator,
                                         /*device_ordinal=*/0);
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {


@@ -169,8 +169,7 @@ Status TransferManager::TransferArrayToDeviceAsync(
                            "%d < %d",
                            dest.size(), GetByteSizeRequirement(on_device_shape));
   }
-  ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
-                             stream->parent()->platform(),
+  ShapedBuffer shaped_buffer(on_device_shape, stream->parent()->platform(),
                              stream->parent()->device_ordinal());
   shaped_buffer.set_buffer(dest, /*index=*/{});
   return TransferLiteralToDevice(stream, literal, shaped_buffer,
@@ -194,8 +193,7 @@ void TransferManager::TransferArrayFromDevice(
                            "%d < %d",
                            source.size(), GetByteSizeRequirement(shape)));
   }
-  ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
-                             stream->parent()->platform(),
+  ShapedBuffer shaped_buffer(shape, stream->parent()->platform(),
                              stream->parent()->device_ordinal());
   shaped_buffer.set_buffer(source, /*index=*/{});
   return TransferLiteralFromDevice(stream, shaped_buffer, literal,
@@ -406,8 +404,8 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
   Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
   TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
 
-  ScopedShapedBuffer shaped_buffer(on_host_shape, std::move(on_device_shape),
-                                   allocator, device_ordinal);
+  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
+                                   device_ordinal);
 
   // Allocate an appropriate sized buffer for each element in the shape
   // including the tuple pointer arrays.
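Call sites that allocate buffers are unaffected by the hunk above: AllocateScopedShapedBuffer still takes a host shape and derives the device shape itself; only the internal ScopedShapedBuffer construction dropped the host-shape argument. A hypothetical usage sketch, with illustrative names:

// Hypothetical usage sketch (not from this commit).
xla::StatusOr<xla::ScopedShapedBuffer> AllocateForHostShape(
    xla::TransferManager* transfer_manager,
    stream_executor::DeviceMemoryAllocator* allocator, int device_ordinal,
    const xla::Shape& host_shape) {
  return transfer_manager->AllocateScopedShapedBuffer(host_shape, allocator,
                                                      device_ordinal);
}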


@@ -193,6 +193,7 @@ class TransferManager {
   // shapes, and returns static shapes with dynamic shapes updated.
   // The shape of the buffer also have to be compatible with the host shape and
   // device shape.
+  // TODO(b/170310047): remove host_shape.
   virtual Status ReadDynamicShapes(se::Stream* stream,
                                    ShapedBuffer* device_buffer,
                                    Shape* host_shape, Shape* device_shape);


@@ -119,8 +119,7 @@ class BufferDonationTest : public HloTestBase {
           }
         });
-    args.emplace_back(
-        ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
+    args.emplace_back(ExecutionInput(std::move(owned_buffers)));
   }
 
   StatusOr<ExecutionOutput> output_status =