[XLA] Don't pass on_host_shape to ShapedBuffer/ScopedShapedBuffer inside XLA.
PiperOrigin-RevId: 336133292
Change-Id: I47a6fa5a5f2c6a460bdaeb1acc5125ff20710230
parent 17dc4b07a0
commit e7f6b0c7ee
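
For orientation, here is a minimal, self-contained C++ sketch of the pattern this change applies across the files below. It uses hypothetical stand-in types (Shape, ShapedBuffer, DeviceShapeToHostShape) rather than the real XLA headers: the preferred constructor takes only the device shape and derives the host shape internally, while the old two-shape signature survives only as a deprecated forwarding shim (the TODO(b/170310047) overloads in the diff).

// Hypothetical mock of the migration; not the actual XLA classes.
#include <cassert>
#include <string>
#include <utility>

struct Shape {
  std::string repr;
  bool operator==(const Shape& other) const { return repr == other.repr; }
};

// Stand-in for ShapeUtil::DeviceShapeToHostShape: the host shape is a pure
// function of the device shape, so storing both is redundant.
Shape DeviceShapeToHostShape(const Shape& device_shape) { return device_shape; }

class ShapedBuffer {
 public:
  // Preferred constructor after this change: device shape only.
  ShapedBuffer(Shape on_device_shape, int device_ordinal)
      : on_device_shape_(std::move(on_device_shape)),
        on_host_shape_(DeviceShapeToHostShape(on_device_shape_)),
        device_ordinal_(device_ordinal) {}

  // Deprecated shim mirroring the TODO(b/170310047) overloads: the host-shape
  // argument is accepted but ignored.
  ShapedBuffer(Shape /*on_host_shape*/, Shape on_device_shape,
               int device_ordinal)
      : ShapedBuffer(std::move(on_device_shape), device_ordinal) {}

  const Shape& on_host_shape() const { return on_host_shape_; }
  const Shape& on_device_shape() const { return on_device_shape_; }
  int device_ordinal() const { return device_ordinal_; }

 private:
  Shape on_device_shape_;
  Shape on_host_shape_;
  int device_ordinal_;
};

int main() {
  Shape s{"f32[1]"};
  ShapedBuffer before(s, s, /*device_ordinal=*/0);  // old-style call site
  ShapedBuffer after(s, /*device_ordinal=*/0);      // migrated call site
  assert(before.on_host_shape() == after.on_host_shape());
  return 0;
}

The hunks below then drop the redundant host-shape argument at each XLA-internal call site (LocalExecutable, AllocationTracker, CpuExecutable, GpuExecutable, TransferManager, and the tests), keeping the two-shape overloads only as deprecated shims for callers outside XLA.
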
@@ -267,9 +267,9 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
 }
 
 static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
-    Shape const& on_host_shape, const ShapeTree<MaybeOwningDeviceMemory>& tree,
-    se::Platform* platform, int device_ordinal) {
-  ShapedBuffer result(on_host_shape, tree.shape(), platform, device_ordinal);
+    const ShapeTree<MaybeOwningDeviceMemory>& tree, se::Platform* platform,
+    int device_ordinal) {
+  ShapedBuffer result(tree.shape(), platform, device_ordinal);
   auto it = tree.begin();
   auto out_it = result.buffers().begin();
   for (; it != tree.end(); ++it, ++out_it) {
@@ -299,8 +299,8 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
   shaped_buffer_ptrs.reserve(arguments.size());
   for (size_t i = 0; i < arguments.size(); ++i) {
     shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
-        *argument_host_shapes[i], arguments[i].Buffers(),
-        backend_->platform(), stream->parent()->device_ordinal()));
+        arguments[i].Buffers(), backend_->platform(),
+        stream->parent()->device_ordinal()));
     shaped_buffer_ptrs.push_back(&shaped_buffers.back());
   }
 
@@ -143,13 +143,10 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
   // We only need to care about replica id 0 here, since the GlobalDataHandle is
   // the same for all buffers across replicas.
   const ShapedBuffer* shaped_buffer = replicated_buffers[0];
-  if (!shaped_buffer->on_host_shape().IsTuple()) {
+  if (!shaped_buffer->on_device_shape().IsTuple()) {
     return InvalidArgument("global data handle %d is not a tuple",
                            data.handle());
   }
-  // If the on-host representation is a tuple, then the on-device one should be
-  // as well.
-  TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple());
 
   if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
     return Unimplemented("Deconstructing nested tuples is not implemented.");
@@ -160,7 +157,6 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
        i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
        ++i) {
     auto element_buffer = ShapedBuffer(
-        ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
         ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
         shaped_buffer->platform(), shaped_buffer->device_ordinal());
     element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
@@ -210,8 +210,7 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
     absl::Span<MaybeOwningDeviceMemory> buffers,
     absl::Span<ExecutionInput> arguments) {
   se::Stream* stream = run_options->stream();
-  ExecutionOutput result(/*on_host_shape=*/result_shape(),
-                         /*on_device_shape=*/result_shape(),
+  ExecutionOutput result(/*on_device_shape=*/result_shape(),
                          run_options->allocator(),
                          stream->parent()->device_ordinal());
   const HloInputOutputAliasConfig& input_output_alias =
@@ -59,11 +59,11 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index,
   unowned_indices_.insert(index);
 }
 
-xla::StatusOr<xla::ShapedBuffer> ExecutionInput::ToShapedBuffer(
+StatusOr<ShapedBuffer> ExecutionInput::ToShapedBuffer(
     se::DeviceMemoryAllocator* allocator, int device_ordinal) const {
   const Shape& input_shape = shape();
-  xla::ShapedBuffer shaped_buffer(input_shape, input_shape,
-                                  allocator->platform(), device_ordinal);
+  ShapedBuffer shaped_buffer(input_shape, allocator->platform(),
+                             device_ordinal);
   for (const auto& index_buffer : Buffers()) {
     const tensorflow::se::OwningDeviceMemory* mem =
         index_buffer.second.AsOwningDeviceMemory();
@@ -93,8 +93,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
 
 static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
     const ShapedBuffer& shaped_buffer) {
-  ExecutionInput result(shaped_buffer.on_device_shape(),
-                        shaped_buffer.on_host_shape());
+  ExecutionInput result(shaped_buffer.on_device_shape());
   shaped_buffer.buffers().ForEachElement(
       [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
         result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
@@ -153,13 +153,13 @@ class ExecutionOutput {
                   std::vector<se::OwningDeviceMemory> to_be_released)
       : result_(std::move(result)),
         to_be_released_(std::move(to_be_released)) {}
+  // TODO(b/170310047): remove this overload.
   ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
                   se::DeviceMemoryAllocator* allocator, int device_ordinal)
-      : result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
-                device_ordinal) {}
+      : result_(std::move(on_device_shape), allocator, device_ordinal) {}
+  ExecutionOutput(Shape on_device_shape, se::DeviceMemoryAllocator* allocator,
+                  int device_ordinal)
+      : result_(std::move(on_device_shape), allocator, device_ordinal) {}
   ExecutionOutput(ExecutionOutput&&) = default;
   ExecutionOutput& operator=(ExecutionOutput&&) = default;
@@ -69,13 +69,8 @@ void GenericTransferManager::TransferLiteralFromDevice(
   TF_RET_CHECK(stream->parent()->device_ordinal() ==
                device_buffer.device_ordinal());
 
-  // The on-host and on-device shape should always be the same for the generic
-  // transfer manager.
-  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()));
-
   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
-      device_buffer.on_host_shape(),
+      device_buffer.on_device_shape(),
       [&](const Shape& subshape, const ShapeIndex& index) -> Status {
         if (subshape.IsArray()) {
           stream->ThenMemcpy(
@@ -103,21 +98,15 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
       << ShapeUtil::HumanString(shape)
       << "; device buffer: " << device_buffer;
 
-  // The on-host and on-device shape should always be the same for the generic
-  // transfer manager.
-  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()))
-      << device_buffer.ToString();
-
   TF_RET_CHECK(
-      ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
+      ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape()));
   TF_RET_CHECK(stream->parent()->device_ordinal() ==
                device_buffer.device_ordinal());
 
   TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
 
   return ShapeUtil::ForEachSubshapeWithStatus(
-      device_buffer.on_host_shape(),
+      device_buffer.on_device_shape(),
       [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
         se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
         if (device_subshape.IsArray()) {
@@ -450,8 +450,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
   HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
   const Shape& root_shape = root->shape();
   auto device_ordinal = executor->device_ordinal();
-  ExecutionOutput result(/*on_host_shape=*/root->shape(),
-                         /*on_device_shape=*/root->shape(), memory_allocator,
+  ExecutionOutput result(/*on_device_shape=*/root->shape(), memory_allocator,
                          device_ordinal);
 
   TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations,
@@ -211,8 +211,7 @@ static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
           *buffer_tree.mutable_element(index) = execution_input_buffer;
         }
       });
-    execution_inputs.emplace_back(std::move(buffer_tree),
-                                  input_buffer.on_host_shape());
+    execution_inputs.emplace_back(std::move(buffer_tree));
   }
   return execution_inputs;
 }
@@ -56,7 +56,7 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
   }
   for (auto& argument : arguments) {
     const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
-    argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
+    argument_buffers.push_back(ShapedBuffer(buffers.shape(),
                                             /*platform=*/nullptr,
                                             /*device_ordinal=*/device_ordinal));
     auto in_it = buffers.begin();
@@ -31,10 +31,6 @@ limitations under the License.
 
 namespace xla {
 
-ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
-                           const se::Platform* platform, int device_ordinal)
-    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
-
 ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
                            int device_ordinal)
     : on_device_shape_(std::move(on_device_shape)),
@@ -44,6 +40,10 @@ ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
   on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
 }
 
+ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
+                           const se::Platform* platform, int device_ordinal)
+    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
+
 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
     : on_host_shape_(std::move(s.on_host_shape_)),
      on_device_shape_(std::move(s.on_device_shape_)),
@@ -90,12 +90,11 @@ void ShapedBuffer::clear() {
 }
 
 string ShapedBuffer::ToString() const {
-  string s = absl::StrCat(
-      "ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
-      "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
-      ", on-device shape=" +
-          ShapeUtil::HumanStringWithLayout(on_device_shape()),
-      ":\n");
+  string s =
+      absl::StrCat("ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
+                   "), on-device shape=" +
+                       ShapeUtil::HumanStringWithLayout(on_device_shape()),
+                   ":\n");
   ShapeUtil::ForEachSubshape(
       on_device_shape(),
       [this, &s](const Shape& subshape, const ShapeIndex& index) {
@@ -118,14 +117,6 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
   return out;
 }
 
-ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
-                                       Shape on_device_shape,
-                                       se::DeviceMemoryAllocator* allocator,
-                                       int device_ordinal)
-    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
-                   device_ordinal),
-      allocator_(allocator) {}
-
 ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
                                        se::DeviceMemoryAllocator* allocator,
                                        int device_ordinal)
@@ -133,6 +124,13 @@ ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
                    device_ordinal),
       allocator_(allocator) {}
 
+ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
+                                       Shape on_device_shape,
+                                       se::DeviceMemoryAllocator* allocator,
+                                       int device_ordinal)
+    : ScopedShapedBuffer(std::move(on_device_shape), allocator,
+                         device_ordinal) {}
+
 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                                        se::DeviceMemoryAllocator* allocator)
     : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}
@@ -45,6 +45,7 @@ class ShapedBuffer {
   // ShapedBuffer.
   ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
                int device_ordinal);
 
+  // TODO(b/170310047): remove this overload.
   ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                const se::Platform* platform, int device_ordinal);
@@ -100,7 +101,7 @@ class ShapedBuffer {
   // Reset the shape of this shaped buffer and underlying buffer structure.
   //
   // Precondition: EqualStructure(this->on_device_shape_, on_device_shape).
-  void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
+  void set_shapes(const Shape& on_device_shape) {
     CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
         << "Structures are not the same. new: " << on_device_shape
         << ", old: " << on_device_shape_;
@@ -108,6 +109,10 @@ class ShapedBuffer {
     on_device_shape_ = on_device_shape;
     buffers_.replace_shape_ptr(&on_device_shape_);
   }
+  // TODO(b/170310047): remove this overload.
+  void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
+    set_shapes(on_device_shape);
+  }
 
   // Returns the underlying ShapeTree containing all the device addresses in the
   // ShapedBuffer.
@@ -97,12 +97,12 @@ class TestAllocator : public se::DeviceMemoryAllocator {
 TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) {
   Shape s = ShapeUtil::MakeShape(F32, {1});
   TestAllocator allocator;
-  ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0);
+  ScopedShapedBuffer sb1(s, &allocator, /*device_ordinal=*/0);
   sb1.set_buffer(
       allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(),
       /*index=*/{});
 
-  ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1);
+  ScopedShapedBuffer sb2(s, &allocator, /*device_ordinal=*/1);
   sb2.set_buffer(
       allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(),
       /*index=*/{});
@@ -119,7 +119,7 @@ TEST(ScopedShapedBufferTest, TestTakeSubTree) {
   s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(2, s));
   s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(3, s));
 
-  ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0);
+  ScopedShapedBuffer sb(s, &allocator, /*device_ordinal=*/0);
   sb.buffers().ForEachMutableElement(
       [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
         TF_ASSERT_OK_AND_ASSIGN(
@@ -156,8 +156,7 @@ TEST(ScopedShapedBufferTest, TestSubShapeTree) {
   Shape tuple_shape =
       xla::ShapeUtil::MakeTupleShape({array_shape, array_shape});
   TestAllocator allocator;
-  ScopedShapedBuffer sb(tuple_shape, tuple_shape, &allocator,
-                        /*device_ordinal=*/0);
+  ScopedShapedBuffer sb(tuple_shape, &allocator, /*device_ordinal=*/0);
   sb.buffers().ForEachMutableElement(
       [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
         TF_ASSERT_OK_AND_ASSIGN(
@@ -182,7 +181,7 @@ void BM_TakeSubTree(int iters, int depth, int fan_out) {
     std::vector<xla::Shape> shapes(fan_out, shape);
     shape = xla::ShapeUtil::MakeTupleShape(shapes);
   }
-  xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator,
+  xla::ScopedShapedBuffer shaped_buffer(shape, /*allocator=*/&allocator,
                                         /*device_ordinal=*/0);
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
@@ -169,8 +169,7 @@ Status TransferManager::TransferArrayToDeviceAsync(
         "%d < %d",
         dest.size(), GetByteSizeRequirement(on_device_shape));
   }
-  ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
-                             stream->parent()->platform(),
+  ShapedBuffer shaped_buffer(on_device_shape, stream->parent()->platform(),
                              stream->parent()->device_ordinal());
   shaped_buffer.set_buffer(dest, /*index=*/{});
   return TransferLiteralToDevice(stream, literal, shaped_buffer,
@@ -194,8 +193,7 @@ void TransferManager::TransferArrayFromDevice(
         "%d < %d",
         source.size(), GetByteSizeRequirement(shape)));
   }
-  ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
-                             stream->parent()->platform(),
+  ShapedBuffer shaped_buffer(shape, stream->parent()->platform(),
                              stream->parent()->device_ordinal());
   shaped_buffer.set_buffer(source, /*index=*/{});
   return TransferLiteralFromDevice(stream, shaped_buffer, literal,
@@ -406,8 +404,8 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
   Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
   TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
 
-  ScopedShapedBuffer shaped_buffer(on_host_shape, std::move(on_device_shape),
-                                   allocator, device_ordinal);
+  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
+                                   device_ordinal);
 
   // Allocate an appropriate sized buffer for each element in the shape
   // including the tuple pointer arrays.
@@ -193,6 +193,7 @@ class TransferManager {
   // shapes, and returns static shapes with dynamic shapes updated.
   // The shape of the buffer also have to be compatible with the host shape and
   // device shape.
+  // TODO(b/170310047): remove host_shape.
   virtual Status ReadDynamicShapes(se::Stream* stream,
                                    ShapedBuffer* device_buffer,
                                    Shape* host_shape, Shape* device_shape);
@@ -119,8 +119,7 @@ class BufferDonationTest : public HloTestBase {
       }
     });
 
-    args.emplace_back(
-        ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
+    args.emplace_back(ExecutionInput(std::move(owned_buffers)));
   }
 
   StatusOr<ExecutionOutput> output_status =