[XLA] Don't pass on_host_shape to ShapedBuffer/ScopedShapedBuffer inside XLA.
PiperOrigin-RevId: 336133292 Change-Id: I47a6fa5a5f2c6a460bdaeb1acc5125ff20710230
parent 17dc4b07a0
commit e7f6b0c7ee
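In short: ShapedBuffer and ScopedShapedBuffer are now constructed from the on-device shape alone; the on-host shape is derived internally via ShapeUtil::DeviceShapeToHostShape, and the old two-shape overloads remain only as deprecated forwarders (TODO(b/170310047)). A minimal caller-side sketch, assuming the usual XLA headers and an already-initialized allocator (not part of this commit):

    #include "tensorflow/compiler/xla/service/shaped_buffer.h"

    // Builds a buffer from the device shape only; before this change the call
    // site also had to pass the matching on-host shape as the first argument.
    xla::ScopedShapedBuffer MakeDeviceBuffer(
        const xla::Shape& on_device_shape,
        stream_executor::DeviceMemoryAllocator* allocator, int device_ordinal) {
      return xla::ScopedShapedBuffer(on_device_shape, allocator, device_ordinal);
    }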
@@ -267,9 +267,9 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
 }
 
 static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
-    Shape const& on_host_shape, const ShapeTree<MaybeOwningDeviceMemory>& tree,
-    se::Platform* platform, int device_ordinal) {
-  ShapedBuffer result(on_host_shape, tree.shape(), platform, device_ordinal);
+    const ShapeTree<MaybeOwningDeviceMemory>& tree, se::Platform* platform,
+    int device_ordinal) {
+  ShapedBuffer result(tree.shape(), platform, device_ordinal);
   auto it = tree.begin();
   auto out_it = result.buffers().begin();
   for (; it != tree.end(); ++it, ++out_it) {
@@ -299,8 +299,8 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
   shaped_buffer_ptrs.reserve(arguments.size());
   for (size_t i = 0; i < arguments.size(); ++i) {
     shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
-        *argument_host_shapes[i], arguments[i].Buffers(),
-        backend_->platform(), stream->parent()->device_ordinal()));
+        arguments[i].Buffers(), backend_->platform(),
+        stream->parent()->device_ordinal()));
     shaped_buffer_ptrs.push_back(&shaped_buffers.back());
   }
 
@@ -143,13 +143,10 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
   // We only need to care about replica id 0 here, since the GlobalDataHandle is
   // the same for all buffers across replicas.
   const ShapedBuffer* shaped_buffer = replicated_buffers[0];
-  if (!shaped_buffer->on_host_shape().IsTuple()) {
+  if (!shaped_buffer->on_device_shape().IsTuple()) {
     return InvalidArgument("global data handle %d is not a tuple",
                            data.handle());
   }
-  // If the on-host representation is a tuple, then the on-device one should be
-  // as well.
-  TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple());
 
   if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
     return Unimplemented("Deconstructing nested tuples is not implemented.");
@@ -160,7 +157,6 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
        i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
        ++i) {
     auto element_buffer = ShapedBuffer(
-        ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
         ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
         shaped_buffer->platform(), shaped_buffer->device_ordinal());
     element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
@@ -210,8 +210,7 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
     absl::Span<MaybeOwningDeviceMemory> buffers,
     absl::Span<ExecutionInput> arguments) {
   se::Stream* stream = run_options->stream();
-  ExecutionOutput result(/*on_host_shape=*/result_shape(),
-                         /*on_device_shape=*/result_shape(),
+  ExecutionOutput result(/*on_device_shape=*/result_shape(),
                          run_options->allocator(),
                          stream->parent()->device_ordinal());
   const HloInputOutputAliasConfig& input_output_alias =
@@ -59,11 +59,11 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index,
   unowned_indices_.insert(index);
 }
 
-xla::StatusOr<xla::ShapedBuffer> ExecutionInput::ToShapedBuffer(
+StatusOr<ShapedBuffer> ExecutionInput::ToShapedBuffer(
     se::DeviceMemoryAllocator* allocator, int device_ordinal) const {
   const Shape& input_shape = shape();
-  xla::ShapedBuffer shaped_buffer(input_shape, input_shape,
-                                  allocator->platform(), device_ordinal);
+  ShapedBuffer shaped_buffer(input_shape, allocator->platform(),
+                             device_ordinal);
   for (const auto& index_buffer : Buffers()) {
     const tensorflow::se::OwningDeviceMemory* mem =
         index_buffer.second.AsOwningDeviceMemory();
@@ -93,8 +93,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
 
 static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
     const ShapedBuffer& shaped_buffer) {
-  ExecutionInput result(shaped_buffer.on_device_shape(),
-                        shaped_buffer.on_host_shape());
+  ExecutionInput result(shaped_buffer.on_device_shape());
   shaped_buffer.buffers().ForEachElement(
       [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
         result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
@@ -153,13 +153,13 @@ class ExecutionOutput {
                   std::vector<se::OwningDeviceMemory> to_be_released)
       : result_(std::move(result)),
         to_be_released_(std::move(to_be_released)) {}
+  // TODO(b/170310047): remove this overload.
+  ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
+                  se::DeviceMemoryAllocator* allocator, int device_ordinal)
+      : result_(std::move(on_device_shape), allocator, device_ordinal) {}
   ExecutionOutput(Shape on_device_shape, se::DeviceMemoryAllocator* allocator,
                   int device_ordinal)
       : result_(std::move(on_device_shape), allocator, device_ordinal) {}
-  ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
-                  se::DeviceMemoryAllocator* allocator, int device_ordinal)
-      : result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
-                device_ordinal) {}
   ExecutionOutput(ExecutionOutput&&) = default;
   ExecutionOutput& operator=(ExecutionOutput&&) = default;
 
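The same pattern applies to ExecutionOutput: the CPU and GPU executables below now construct it from the device shape only, and the deprecated two-shape constructor above simply drops the host shape. A hedged sketch of the new call shape (helper name and argument sources are illustrative, not part of this commit):

    // Illustrative only: mirrors the ExecutionOutput construction used in the
    // CPU/GPU executable hunks after this change.
    xla::ExecutionOutput MakeResult(
        const xla::Shape& result_device_shape,
        stream_executor::DeviceMemoryAllocator* allocator, int device_ordinal) {
      return xla::ExecutionOutput(result_device_shape, allocator, device_ordinal);
    }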
@@ -69,13 +69,8 @@ void GenericTransferManager::TransferLiteralFromDevice(
   TF_RET_CHECK(stream->parent()->device_ordinal() ==
                device_buffer.device_ordinal());
 
-  // The on-host and on-device shape should always be the same for the generic
-  // transfer manager.
-  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()));
-
   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
-      device_buffer.on_host_shape(),
+      device_buffer.on_device_shape(),
       [&](const Shape& subshape, const ShapeIndex& index) -> Status {
         if (subshape.IsArray()) {
           stream->ThenMemcpy(
@@ -103,21 +98,15 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
       << ShapeUtil::HumanString(shape)
       << "; device buffer: " << device_buffer;
 
-  // The on-host and on-device shape should always be the same for the generic
-  // transfer manager.
-  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()))
-      << device_buffer.ToString();
-
   TF_RET_CHECK(
-      ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
+      ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape()));
   TF_RET_CHECK(stream->parent()->device_ordinal() ==
                device_buffer.device_ordinal());
 
   TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
 
   return ShapeUtil::ForEachSubshapeWithStatus(
-      device_buffer.on_host_shape(),
+      device_buffer.on_device_shape(),
       [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
         se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
         if (device_subshape.IsArray()) {
@@ -450,8 +450,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
   HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
   const Shape& root_shape = root->shape();
   auto device_ordinal = executor->device_ordinal();
-  ExecutionOutput result(/*on_host_shape=*/root->shape(),
-                         /*on_device_shape=*/root->shape(), memory_allocator,
+  ExecutionOutput result(/*on_device_shape=*/root->shape(), memory_allocator,
                          device_ordinal);
 
   TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations,
@@ -211,8 +211,7 @@ static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
           *buffer_tree.mutable_element(index) = execution_input_buffer;
         }
       });
-    execution_inputs.emplace_back(std::move(buffer_tree),
-                                  input_buffer.on_host_shape());
+    execution_inputs.emplace_back(std::move(buffer_tree));
   }
   return execution_inputs;
 }
@@ -56,7 +56,7 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
   }
   for (auto& argument : arguments) {
     const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
-    argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
+    argument_buffers.push_back(ShapedBuffer(buffers.shape(),
                                             /*platform=*/nullptr,
                                             /*device_ordinal=*/device_ordinal));
     auto in_it = buffers.begin();
@@ -31,10 +31,6 @@ limitations under the License.
 
 namespace xla {
 
-ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
-                           const se::Platform* platform, int device_ordinal)
-    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
-
 ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
                            int device_ordinal)
     : on_device_shape_(std::move(on_device_shape)),
@@ -44,6 +40,10 @@ ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
   on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
 }
 
+ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
+                           const se::Platform* platform, int device_ordinal)
+    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
+
 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
     : on_host_shape_(std::move(s.on_host_shape_)),
       on_device_shape_(std::move(s.on_device_shape_)),
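Note that on_host_shape() remains available: as the context above shows, the single-shape constructor derives it with ShapeUtil::DeviceShapeToHostShape. A small sketch of that invariant (assumed test-style code, not in this commit):

    // Assumed check: the derived host shape matches DeviceShapeToHostShape.
    void CheckDerivedHostShape(const xla::Shape& on_device_shape,
                               const stream_executor::Platform* platform) {
      xla::ShapedBuffer buffer(on_device_shape, platform, /*device_ordinal=*/0);
      CHECK(xla::ShapeUtil::Equal(
          buffer.on_host_shape(),
          xla::ShapeUtil::DeviceShapeToHostShape(on_device_shape)));
    }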
@@ -90,10 +90,9 @@ void ShapedBuffer::clear() {
 }
 
 string ShapedBuffer::ToString() const {
-  string s = absl::StrCat(
-      "ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
-      "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
-      ", on-device shape=" +
+  string s =
+      absl::StrCat("ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
+                   "), on-device shape=" +
           ShapeUtil::HumanStringWithLayout(on_device_shape()),
       ":\n");
   ShapeUtil::ForEachSubshape(
@@ -118,14 +117,6 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
   return out;
 }
 
-ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
-                                       Shape on_device_shape,
-                                       se::DeviceMemoryAllocator* allocator,
-                                       int device_ordinal)
-    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
-                   device_ordinal),
-      allocator_(allocator) {}
-
 ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
                                        se::DeviceMemoryAllocator* allocator,
                                        int device_ordinal)
@@ -133,6 +124,13 @@ ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
                    device_ordinal),
       allocator_(allocator) {}
 
+ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
+                                       Shape on_device_shape,
+                                       se::DeviceMemoryAllocator* allocator,
+                                       int device_ordinal)
+    : ScopedShapedBuffer(std::move(on_device_shape), allocator,
+                         device_ordinal) {}
+
 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                                        se::DeviceMemoryAllocator* allocator)
     : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}
@@ -45,6 +45,7 @@ class ShapedBuffer {
   // ShapedBuffer.
   ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
                int device_ordinal);
 
+  // TODO(b/170310047): remove this overload.
   ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                const se::Platform* platform, int device_ordinal);
@@ -100,7 +101,7 @@ class ShapedBuffer {
   // Reset the shape of this shaped buffer and underlying buffer structure.
   //
   // Precondition: EqualStructure(this->on_device_shape_, on_device_shape).
-  void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
+  void set_shapes(const Shape& on_device_shape) {
     CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
         << "Structures are not the same. new: " << on_device_shape
         << ", old: " << on_device_shape_;
@@ -108,6 +109,10 @@ class ShapedBuffer {
     on_device_shape_ = on_device_shape;
     buffers_.replace_shape_ptr(&on_device_shape_);
   }
+  // TODO(b/170310047): remove this overload.
+  void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
+    set_shapes(on_device_shape);
+  }
 
   // Returns the underlying ShapeTree containing all the device addresses in the
   // ShapedBuffer.
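For dynamic-shape updates, set_shapes() now needs only the device shape; the two-argument form is kept purely for source compatibility and forwards to it. A usage sketch (hypothetical caller, not from this commit):

    // Hypothetical caller: update a buffer's shape after resolving dynamic sizes.
    void UpdateDeviceShape(xla::ShapedBuffer* buffer,
                           const xla::Shape& new_device_shape) {
      buffer->set_shapes(new_device_shape);  // old two-argument call still compiles
    }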
@@ -97,12 +97,12 @@ class TestAllocator : public se::DeviceMemoryAllocator {
 TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) {
   Shape s = ShapeUtil::MakeShape(F32, {1});
   TestAllocator allocator;
-  ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0);
+  ScopedShapedBuffer sb1(s, &allocator, /*device_ordinal=*/0);
   sb1.set_buffer(
       allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(),
       /*index=*/{});
 
-  ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1);
+  ScopedShapedBuffer sb2(s, &allocator, /*device_ordinal=*/1);
   sb2.set_buffer(
       allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(),
       /*index=*/{});
@@ -119,7 +119,7 @@ TEST(ScopedShapedBufferTest, TestTakeSubTree) {
   s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(2, s));
   s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(3, s));
 
-  ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0);
+  ScopedShapedBuffer sb(s, &allocator, /*device_ordinal=*/0);
   sb.buffers().ForEachMutableElement(
       [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
         TF_ASSERT_OK_AND_ASSIGN(
@@ -156,8 +156,7 @@ TEST(ScopedShapedBufferTest, TestSubShapeTree) {
   Shape tuple_shape =
       xla::ShapeUtil::MakeTupleShape({array_shape, array_shape});
   TestAllocator allocator;
-  ScopedShapedBuffer sb(tuple_shape, tuple_shape, &allocator,
-                        /*device_ordinal=*/0);
+  ScopedShapedBuffer sb(tuple_shape, &allocator, /*device_ordinal=*/0);
   sb.buffers().ForEachMutableElement(
       [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
         TF_ASSERT_OK_AND_ASSIGN(
@@ -182,7 +181,7 @@ void BM_TakeSubTree(int iters, int depth, int fan_out) {
     std::vector<xla::Shape> shapes(fan_out, shape);
     shape = xla::ShapeUtil::MakeTupleShape(shapes);
   }
-  xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator,
+  xla::ScopedShapedBuffer shaped_buffer(shape, /*allocator=*/&allocator,
                                         /*device_ordinal=*/0);
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
@@ -169,8 +169,7 @@ Status TransferManager::TransferArrayToDeviceAsync(
         "%d < %d",
         dest.size(), GetByteSizeRequirement(on_device_shape));
   }
-  ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
-                             stream->parent()->platform(),
+  ShapedBuffer shaped_buffer(on_device_shape, stream->parent()->platform(),
                              stream->parent()->device_ordinal());
   shaped_buffer.set_buffer(dest, /*index=*/{});
   return TransferLiteralToDevice(stream, literal, shaped_buffer,
@@ -194,8 +193,7 @@ void TransferManager::TransferArrayFromDevice(
         "%d < %d",
         source.size(), GetByteSizeRequirement(shape)));
   }
-  ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
-                             stream->parent()->platform(),
+  ShapedBuffer shaped_buffer(shape, stream->parent()->platform(),
                              stream->parent()->device_ordinal());
   shaped_buffer.set_buffer(source, /*index=*/{});
   return TransferLiteralFromDevice(stream, shaped_buffer, literal,
@@ -406,8 +404,8 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
   Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
   TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
 
-  ScopedShapedBuffer shaped_buffer(on_host_shape, std::move(on_device_shape),
-                                   allocator, device_ordinal);
+  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
+                                   device_ordinal);
 
   // Allocate an appropriate sized buffer for each element in the shape
   // including the tuple pointer arrays.
@@ -193,6 +193,7 @@ class TransferManager {
   // shapes, and returns static shapes with dynamic shapes updated.
   // The shape of the buffer also have to be compatible with the host shape and
   // device shape.
+  // TODO(b/170310047): remove host_shape.
   virtual Status ReadDynamicShapes(se::Stream* stream,
                                    ShapedBuffer* device_buffer,
                                    Shape* host_shape, Shape* device_shape);
@@ -119,8 +119,7 @@ class BufferDonationTest : public HloTestBase {
       }
     });
 
-    args.emplace_back(
-        ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
+    args.emplace_back(ExecutionInput(std::move(owned_buffers)));
   }
 
   StatusOr<ExecutionOutput> output_status =