[XLA] Don't pass on_host_shape to ShapedBuffer/ScopedShapedBuffer inside XLA.

PiperOrigin-RevId: 336133292
Change-Id: I47a6fa5a5f2c6a460bdaeb1acc5125ff20710230
Peter Hawkins authored on 2020-10-08 11:54:04 -07:00; committed by TensorFlower Gardener
parent 17dc4b07a0
commit e7f6b0c7ee
15 changed files with 54 additions and 73 deletions
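
For readers skimming the diff, the core change is that ShapedBuffer, ScopedShapedBuffer, ExecutionInput, and ExecutionOutput are now constructed from the on-device shape alone inside XLA; the on-host shape is derived internally via ShapeUtil::DeviceShapeToHostShape, and the old two-shape overloads remain only as deprecated shims (TODO(b/170310047)). The snippet below is a minimal illustrative sketch of a call site after this change, not part of the diff; the helper name is hypothetical.

// Illustrative only -- not taken from this commit.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"

namespace se = ::stream_executor;

xla::ShapedBuffer MakeDeviceBuffer(const xla::Shape& on_device_shape,
                                   const se::Platform* platform,
                                   int device_ordinal) {
  // Before: xla::ShapedBuffer(on_host_shape, on_device_shape, platform, ordinal).
  // After: only the on-device shape is passed; the constructor derives the
  // host shape itself via ShapeUtil::DeviceShapeToHostShape(on_device_shape).
  return xla::ShapedBuffer(on_device_shape, platform, device_ordinal);
}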


@@ -267,9 +267,9 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
}
static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
Shape const& on_host_shape, const ShapeTree<MaybeOwningDeviceMemory>& tree,
se::Platform* platform, int device_ordinal) {
ShapedBuffer result(on_host_shape, tree.shape(), platform, device_ordinal);
const ShapeTree<MaybeOwningDeviceMemory>& tree, se::Platform* platform,
int device_ordinal) {
ShapedBuffer result(tree.shape(), platform, device_ordinal);
auto it = tree.begin();
auto out_it = result.buffers().begin();
for (; it != tree.end(); ++it, ++out_it) {
@@ -299,8 +299,8 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
shaped_buffer_ptrs.reserve(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
*argument_host_shapes[i], arguments[i].Buffers(),
backend_->platform(), stream->parent()->device_ordinal()));
arguments[i].Buffers(), backend_->platform(),
stream->parent()->device_ordinal()));
shaped_buffer_ptrs.push_back(&shaped_buffers.back());
}


@@ -143,13 +143,10 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
// We only need to care about replica id 0 here, since the GlobalDataHandle is
// the same for all buffers across replicas.
const ShapedBuffer* shaped_buffer = replicated_buffers[0];
if (!shaped_buffer->on_host_shape().IsTuple()) {
if (!shaped_buffer->on_device_shape().IsTuple()) {
return InvalidArgument("global data handle %d is not a tuple",
data.handle());
}
// If the on-host representation is a tuple, then the on-device one should be
// as well.
TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple());
if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
return Unimplemented("Deconstructing nested tuples is not implemented.");
@@ -160,7 +157,6 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
++i) {
auto element_buffer = ShapedBuffer(
ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i),
ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
shaped_buffer->platform(), shaped_buffer->device_ordinal());
element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),


@@ -210,8 +210,7 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
absl::Span<MaybeOwningDeviceMemory> buffers,
absl::Span<ExecutionInput> arguments) {
se::Stream* stream = run_options->stream();
ExecutionOutput result(/*on_host_shape=*/result_shape(),
/*on_device_shape=*/result_shape(),
ExecutionOutput result(/*on_device_shape=*/result_shape(),
run_options->allocator(),
stream->parent()->device_ordinal());
const HloInputOutputAliasConfig& input_output_alias =


@@ -59,11 +59,11 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index,
unowned_indices_.insert(index);
}
xla::StatusOr<xla::ShapedBuffer> ExecutionInput::ToShapedBuffer(
StatusOr<ShapedBuffer> ExecutionInput::ToShapedBuffer(
se::DeviceMemoryAllocator* allocator, int device_ordinal) const {
const Shape& input_shape = shape();
xla::ShapedBuffer shaped_buffer(input_shape, input_shape,
allocator->platform(), device_ordinal);
ShapedBuffer shaped_buffer(input_shape, allocator->platform(),
device_ordinal);
for (const auto& index_buffer : Buffers()) {
const tensorflow::se::OwningDeviceMemory* mem =
index_buffer.second.AsOwningDeviceMemory();
@@ -93,8 +93,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
const ShapedBuffer& shaped_buffer) {
ExecutionInput result(shaped_buffer.on_device_shape(),
shaped_buffer.on_host_shape());
ExecutionInput result(shaped_buffer.on_device_shape());
shaped_buffer.buffers().ForEachElement(
[&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
result.SetBuffer(index, MaybeOwningDeviceMemory(mem));


@@ -153,13 +153,13 @@ class ExecutionOutput {
std::vector<se::OwningDeviceMemory> to_be_released)
: result_(std::move(result)),
to_be_released_(std::move(to_be_released)) {}
// TODO(b/170310047): remove this overload.
ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
se::DeviceMemoryAllocator* allocator, int device_ordinal)
: result_(std::move(on_device_shape), allocator, device_ordinal) {}
ExecutionOutput(Shape on_device_shape, se::DeviceMemoryAllocator* allocator,
int device_ordinal)
: result_(std::move(on_device_shape), allocator, device_ordinal) {}
ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
se::DeviceMemoryAllocator* allocator, int device_ordinal)
: result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
device_ordinal) {}
ExecutionOutput(ExecutionOutput&&) = default;
ExecutionOutput& operator=(ExecutionOutput&&) = default;
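
As a quick illustration of the ExecutionOutput change above: backends now pass only the result's device shape, and the overload that still accepts a host shape simply discards it. A hedged sketch, not taken from the diff; the helper name is hypothetical.

// Hypothetical backend helper -- illustrative only.
#include "tensorflow/compiler/xla/service/executable.h"

namespace se = ::stream_executor;

xla::ExecutionOutput MakeOutput(const xla::Shape& result_device_shape,
                                se::DeviceMemoryAllocator* allocator,
                                int device_ordinal) {
  // The separate on_host_shape argument is gone; the deprecated two-shape
  // overload above forwards only the device shape.
  return xla::ExecutionOutput(result_device_shape, allocator, device_ordinal);
}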


@@ -69,13 +69,8 @@ void GenericTransferManager::TransferLiteralFromDevice(
TF_RET_CHECK(stream->parent()->device_ordinal() ==
device_buffer.device_ordinal());
// The on-host and on-device shape should always be the same for the generic
// transfer manager.
TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
device_buffer.on_host_shape()));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
device_buffer.on_device_shape(),
[&](const Shape& subshape, const ShapeIndex& index) -> Status {
if (subshape.IsArray()) {
stream->ThenMemcpy(
@@ -103,21 +98,15 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
<< ShapeUtil::HumanString(shape)
<< "; device buffer: " << device_buffer;
// The on-host and on-device shape should always be the same for the generic
// transfer manager.
TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
device_buffer.on_host_shape()))
<< device_buffer.ToString();
TF_RET_CHECK(
ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape()));
TF_RET_CHECK(stream->parent()->device_ordinal() ==
device_buffer.device_ordinal());
TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
return ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
device_buffer.on_device_shape(),
[&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
if (device_subshape.IsArray()) {


@@ -450,8 +450,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
const Shape& root_shape = root->shape();
auto device_ordinal = executor->device_ordinal();
ExecutionOutput result(/*on_host_shape=*/root->shape(),
/*on_device_shape=*/root->shape(), memory_allocator,
ExecutionOutput result(/*on_device_shape=*/root->shape(), memory_allocator,
device_ordinal);
TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations,


@@ -211,8 +211,7 @@ static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
*buffer_tree.mutable_element(index) = execution_input_buffer;
}
});
execution_inputs.emplace_back(std::move(buffer_tree),
input_buffer.on_host_shape());
execution_inputs.emplace_back(std::move(buffer_tree));
}
return execution_inputs;
}


@@ -56,7 +56,7 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
}
for (auto& argument : arguments) {
const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
argument_buffers.push_back(ShapedBuffer(buffers.shape(),
/*platform=*/nullptr,
/*device_ordinal=*/device_ordinal));
auto in_it = buffers.begin();


@@ -31,10 +31,6 @@ limitations under the License.
namespace xla {
ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
const se::Platform* platform, int device_ordinal)
: ShapedBuffer(on_device_shape, platform, device_ordinal) {}
ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
int device_ordinal)
: on_device_shape_(std::move(on_device_shape)),
@@ -44,6 +40,10 @@ ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
}
ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
const se::Platform* platform, int device_ordinal)
: ShapedBuffer(on_device_shape, platform, device_ordinal) {}
ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
: on_host_shape_(std::move(s.on_host_shape_)),
on_device_shape_(std::move(s.on_device_shape_)),
@@ -90,12 +90,11 @@ void ShapedBuffer::clear() {
}
string ShapedBuffer::ToString() const {
string s = absl::StrCat(
"ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
"), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()),
", on-device shape=" +
ShapeUtil::HumanStringWithLayout(on_device_shape()),
":\n");
string s =
absl::StrCat("ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
"), on-device shape=" +
ShapeUtil::HumanStringWithLayout(on_device_shape()),
":\n");
ShapeUtil::ForEachSubshape(
on_device_shape(),
[this, &s](const Shape& subshape, const ShapeIndex& index) {
@@ -118,14 +117,6 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
return out;
}
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
Shape on_device_shape,
se::DeviceMemoryAllocator* allocator,
int device_ordinal)
: ShapedBuffer(std::move(on_device_shape), allocator->platform(),
device_ordinal),
allocator_(allocator) {}
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
se::DeviceMemoryAllocator* allocator,
int device_ordinal)
@@ -133,6 +124,13 @@ ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
device_ordinal),
allocator_(allocator) {}
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
Shape on_device_shape,
se::DeviceMemoryAllocator* allocator,
int device_ordinal)
: ScopedShapedBuffer(std::move(on_device_shape), allocator,
device_ordinal) {}
ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
se::DeviceMemoryAllocator* allocator)
: ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}


@@ -45,6 +45,7 @@ class ShapedBuffer {
// ShapedBuffer.
ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
int device_ordinal);
// TODO(b/170310047): remove this overload.
ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
const se::Platform* platform, int device_ordinal);
@@ -100,7 +101,7 @@ class ShapedBuffer {
// Reset the shape of this shaped buffer and underlying buffer structure.
//
// Precondition: EqualStructure(this->on_device_shape_, on_device_shape).
void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
void set_shapes(const Shape& on_device_shape) {
CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
<< "Structures are not the same. new: " << on_device_shape
<< ", old: " << on_device_shape_;
@@ -108,6 +109,10 @@
on_device_shape_ = on_device_shape;
buffers_.replace_shape_ptr(&on_device_shape_);
}
// TODO(b/170310047): remove this overload.
void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
set_shapes(on_device_shape);
}
// Returns the underlying ShapeTree containing all the device addresses in the
// ShapedBuffer.
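
A small usage sketch of the new single-argument set_shapes (the caller and helper name here are hypothetical; only set_shapes itself comes from this diff): callers that used to pass both shapes now pass just the updated device shape, whose tuple structure must match the existing one (CHECKed internally).

// Hypothetical helper -- illustrative only.
#include "tensorflow/compiler/xla/service/shaped_buffer.h"

void ApplyUpdatedShape(xla::ShapedBuffer* buffer,
                       const xla::Shape& updated_device_shape) {
  // Replaces the old buffer->set_shapes(on_host_shape, on_device_shape);
  // the host shape no longer needs to be threaded through by the caller.
  buffer->set_shapes(updated_device_shape);
}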


@@ -97,12 +97,12 @@ class TestAllocator : public se::DeviceMemoryAllocator {
TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) {
Shape s = ShapeUtil::MakeShape(F32, {1});
TestAllocator allocator;
ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0);
ScopedShapedBuffer sb1(s, &allocator, /*device_ordinal=*/0);
sb1.set_buffer(
allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(),
/*index=*/{});
ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1);
ScopedShapedBuffer sb2(s, &allocator, /*device_ordinal=*/1);
sb2.set_buffer(
allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(),
/*index=*/{});
@@ -119,7 +119,7 @@ TEST(ScopedShapedBufferTest, TestTakeSubTree) {
s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(2, s));
s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(3, s));
ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0);
ScopedShapedBuffer sb(s, &allocator, /*device_ordinal=*/0);
sb.buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
TF_ASSERT_OK_AND_ASSIGN(
@@ -156,8 +156,7 @@ TEST(ScopedShapedBufferTest, TestSubShapeTree) {
Shape tuple_shape =
xla::ShapeUtil::MakeTupleShape({array_shape, array_shape});
TestAllocator allocator;
ScopedShapedBuffer sb(tuple_shape, tuple_shape, &allocator,
/*device_ordinal=*/0);
ScopedShapedBuffer sb(tuple_shape, &allocator, /*device_ordinal=*/0);
sb.buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
TF_ASSERT_OK_AND_ASSIGN(
@@ -182,7 +181,7 @@ void BM_TakeSubTree(int iters, int depth, int fan_out) {
std::vector<xla::Shape> shapes(fan_out, shape);
shape = xla::ShapeUtil::MakeTupleShape(shapes);
}
xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator,
xla::ScopedShapedBuffer shaped_buffer(shape, /*allocator=*/&allocator,
/*device_ordinal=*/0);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {


@@ -169,8 +169,7 @@ Status TransferManager::TransferArrayToDeviceAsync(
"%d < %d",
dest.size(), GetByteSizeRequirement(on_device_shape));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
stream->parent()->platform(),
ShapedBuffer shaped_buffer(on_device_shape, stream->parent()->platform(),
stream->parent()->device_ordinal());
shaped_buffer.set_buffer(dest, /*index=*/{});
return TransferLiteralToDevice(stream, literal, shaped_buffer,
@@ -194,8 +193,7 @@ void TransferManager::TransferArrayFromDevice(
"%d < %d",
source.size(), GetByteSizeRequirement(shape)));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
stream->parent()->platform(),
ShapedBuffer shaped_buffer(shape, stream->parent()->platform(),
stream->parent()->device_ordinal());
shaped_buffer.set_buffer(source, /*index=*/{});
return TransferLiteralFromDevice(stream, shaped_buffer, literal,
@@ -406,8 +404,8 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
ScopedShapedBuffer shaped_buffer(on_host_shape, std::move(on_device_shape),
allocator, device_ordinal);
ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
device_ordinal);
// Allocate an appropriate sized buffer for each element in the shape
// including the tuple pointer arrays.


@@ -193,6 +193,7 @@ class TransferManager {
// shapes, and returns static shapes with dynamic shapes updated.
// The shape of the buffer also has to be compatible with the host shape and
// device shape.
// TODO(b/170310047): remove host_shape.
virtual Status ReadDynamicShapes(se::Stream* stream,
ShapedBuffer* device_buffer,
Shape* host_shape, Shape* device_shape);


@@ -119,8 +119,7 @@ class BufferDonationTest : public HloTestBase {
}
});
args.emplace_back(
ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
args.emplace_back(ExecutionInput(std::move(owned_buffers)));
}
StatusOr<ExecutionOutput> output_status =