[XLA] Store host shape in ExecutionInput

Simplify the APIs by no longer passing the host shape explicitly.

PiperOrigin-RevId: 321083080
Change-Id: I9e124dd4465ee4037f2d0cdbd33f04a43f35abc2
George Karpenkov, 2020-07-13 19:54:15 -07:00, committed by TensorFlower Gardener
parent 07daafc869
commit 1a9b57d729
8 changed files with 46 additions and 28 deletions
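For orientation, a minimal sketch (not part of this commit) of how a caller runs a LocalExecutable after this change: the host shape now travels inside each ExecutionInput instead of being threaded through RunAsync() as a parallel span of Shape pointers. Assumes the XLA headers of this era; buffer population and error handling are elided.

#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/local_client.h"

xla::StatusOr<xla::ExecutionOutput> RunDonated(
    xla::LocalExecutable* executable,
    std::vector<xla::ExecutionInput> inputs,
    xla::ExecutableRunOptions options) {
  // No absl::Span<Shape const* const> of host shapes any more: each
  // ExecutionInput carries its own host shape (see executable.h below).
  return executable->RunAsync(std::move(inputs), std::move(options));
}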


@@ -187,7 +187,7 @@ StatusOr<ExecutionOutput> LocalExecutable::Run(
   std::vector<const Shape*> argument_shapes;
   argument_shapes.reserve(arguments.size());
   for (const ExecutionInput& arg : arguments) {
-    argument_shapes.push_back(&arg.shape());
+    argument_shapes.push_back(&arg.host_shape());
   }
   return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
       argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
@@ -325,7 +325,7 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
   std::vector<const Shape*> argument_shapes;
   argument_shapes.reserve(arguments.size());
   for (const ExecutionInput& arg : arguments) {
-    argument_shapes.push_back(&arg.shape());
+    argument_shapes.push_back(&arg.host_shape());
   }
   return RunAsync(argument_shapes, std::move(arguments), run_options);
 }
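Both Run() and RunAsync() validate arguments against the computation's parameter shapes as the host sees them, and a buffer's on-device shape can legitimately differ from its host shape, for example in layout. A hypothetical illustration of the distinction, using ShapeUtil as it existed at the time (the shapes and function name here are made up):

#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/shape_util.h"

void HostVersusDeviceShape() {
  // Row-major f32[2,3] as the host sees the argument.
  xla::Shape host = xla::ShapeUtil::MakeShapeWithLayout(
      xla::F32, /*dimensions=*/{2, 3}, /*minor_to_major=*/{1, 0});
  // The same logical value laid out column-major on the device.
  xla::Shape device = xla::ShapeUtil::MakeShapeWithLayout(
      xla::F32, /*dimensions=*/{2, 3}, /*minor_to_major=*/{0, 1});
  xla::ExecutionInput input(device, host);
  // input.host_shape() — what the loops above now collect — reports the
  // row-major host view rather than the device layout.
}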


@@ -64,10 +64,6 @@ class LocalExecutable {
   // Similar to RunAsync(), but allows for donating argument buffers to the
   // executable.
-  StatusOr<ExecutionOutput> RunAsync(
-      absl::Span<Shape const* const> argument_host_shapes,
-      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
-
   StatusOr<ExecutionOutput> RunAsync(std::vector<ExecutionInput> arguments,
                                      ExecutableRunOptions run_options);
@@ -78,6 +74,10 @@ class LocalExecutable {
   Executable* executable() const { return executable_.get(); }

  private:
+  StatusOr<ExecutionOutput> RunAsync(
+      absl::Span<Shape const* const> argument_host_shapes,
+      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
+
   // Validates that the given arguments and options satisfy various constraints
   // of the computation.
   //
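The span-taking overload does not disappear: the two hunks above move it from the public section into the private one, so the shape-threading variant becomes an implementation detail that both public entry points funnel into. Paraphrasing the resulting split (a sketch, not a verbatim header excerpt):

class LocalExecutable {
 public:
  // Callers supply only the inputs; host shapes are recovered from them.
  StatusOr<ExecutionOutput> RunAsync(std::vector<ExecutionInput> arguments,
                                     ExecutableRunOptions run_options);

 private:
  // Internal worker retained from the old public API.
  StatusOr<ExecutionOutput> RunAsync(
      absl::Span<Shape const* const> argument_host_shapes,
      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
};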


@@ -1383,7 +1383,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
         local_device->compute_stream()->parent(), root_table_memory.cref()));
   }

-  ExecutionInput execution_input(on_device_shape);
+  ExecutionInput execution_input(on_device_shape, on_host_shape);
   ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
       execution_input.MutableBuffers()->begin();
   ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
@@ -1521,7 +1521,6 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
           << " mapped to device ordinal for execution: " << device_ordinal;

   absl::flat_hash_set<BufferSequencingEvent*> events;
-  std::vector<const Shape*> argument_host_shapes;
   std::vector<ExecutionInput> execution_inputs;
   device_buffers->reserve(argument_handles.size());
   const absl::flat_hash_set<int>& parameters_that_must_be_donated =
@@ -1570,24 +1569,22 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
   }

   LocalDeviceState* device_state = &client_->device_state(device_ordinal);
-  TupleHandle tuple_handle;
+  absl::optional<TupleHandle> tuple_handle;
   if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
     TF_ASSIGN_OR_RETURN(tuple_handle,
                         MakeTupleHelper(client_, device_state, argument_handles,
                                         *device_buffers, device_ordinal));
-    events.insert(tuple_handle.event.get());
-    execution_inputs.emplace_back(std::move(tuple_handle.execution_input));
-    argument_host_shapes.push_back(&tuple_handle.on_host_shape);
+    events.insert(tuple_handle->event.get());
+    execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
   } else {
-    argument_host_shapes.reserve(argument_handles.size());
     execution_inputs.reserve(argument_handles.size());
     for (int i = 0; i < argument_handles.size(); ++i) {
       PjRtBuffer* handle = argument_handles[i];
-      argument_host_shapes.push_back(&handle->on_host_shape());
       const PjRtBuffer::ScopedHold& device_buffer = (*device_buffers)[i];

       // Make an ExecutionInput from the device buffer.
-      execution_inputs.emplace_back(handle->on_device_shape());
+      execution_inputs.emplace_back(handle->on_device_shape(),
+                                    handle->on_host_shape());
       ExecutionInput& execution_input = execution_inputs.back();
       ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
           execution_input.MutableBuffers()->begin();
@@ -1623,8 +1620,8 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
                       device_state->compute_semaphore().ScopedAcquire(1));

   StatusOr<ExecutionOutput> result_buffer_or_status =
-      executables_[executable_idx]->RunAsync(
-          argument_host_shapes, std::move(execution_inputs), run_options);
+      executables_[executable_idx]->RunAsync(std::move(execution_inputs),
+                                             run_options);

   VLOG(1) << "Replica " << replica << " partition " << partition
           << " completed; ok=" << result_buffer_or_status.ok();


@@ -93,7 +93,8 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
 static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
     const ShapedBuffer& shaped_buffer) {
-  ExecutionInput result(shaped_buffer.on_device_shape());
+  ExecutionInput result(shaped_buffer.on_device_shape(),
+                        shaped_buffer.on_host_shape());
   shaped_buffer.buffers().ForEachElement(
       [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
         result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
@@ -105,10 +106,10 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     absl::Span<const ShapedBuffer* const> arguments,
     HloExecutionProfile* hlo_execution_profile) {
-  std::vector<ExecutionInput> args(arguments.size());
-  auto out_it = args.begin();
+  std::vector<ExecutionInput> args;
+  args.reserve(arguments.size());
   for (const ShapedBuffer* arg : arguments) {
-    *out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
+    args.emplace_back(MakeMaybeOwningDeviceMemoryTree(*arg));
   }
   TF_ASSIGN_OR_RETURN(ExecutionOutput out,
                       ExecuteAsyncOnStream(run_options, std::move(args),
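The old code default-constructed arguments.size() ExecutionInputs and overwrote them through an iterator; with the default constructor removed that no longer compiles, hence reserve() plus emplace_back(). The general pattern for filling a vector of a move-only, non-default-constructible type (illustrative, not from the commit):

#include <vector>

struct NonDefault {
  explicit NonDefault(int v) : value(v) {}
  NonDefault(NonDefault&&) = default;
  int value;
};

std::vector<NonDefault> Fill(int n) {
  std::vector<NonDefault> out;
  out.reserve(n);  // One allocation up front; no default construction.
  for (int i = 0; i < n; ++i) {
    out.emplace_back(i);  // Construct each element in place.
  }
  return out;
}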


@@ -60,10 +60,17 @@ namespace xla {
 // with their indices absent from unowned_indices_.
 class ExecutionInput {
  public:
-  ExecutionInput() = default;
-  explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {}
-  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
-      : buffers_(std::move(buffers)) {}
+  explicit ExecutionInput(xla::Shape shape, xla::Shape host_shape)
+      : buffers_(std::move(shape)) {
+    SetHostShape(std::move(host_shape));
+  }
+
+  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
+                          xla::Shape host_shape)
+      : buffers_(std::move(buffers)) {
+    SetHostShape(std::move(host_shape));
+  }
+
   ExecutionInput(ExecutionInput&&) = default;

   ~ExecutionInput();
@@ -74,6 +81,10 @@ class ExecutionInput {
     return dynamic_shape_ != nullptr ? *dynamic_shape_ : buffers_.shape();
   }

+  const Shape& host_shape() const {
+    return host_shape_ != nullptr ? *host_shape_ : shape();
+  }
+
   Status SetDynamicShape(Shape dynamic_shape);

   xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer(
@@ -107,11 +118,18 @@ class ExecutionInput {
   }

  private:
+  void SetHostShape(xla::Shape host_shape) {
+    if (shape() != host_shape) {
+      host_shape_ = absl::make_unique<Shape>(std::move(host_shape));
+    }
+  }
+
   ShapeTree<MaybeOwningDeviceMemory> buffers_;
   // Set of indices of buffers that should be returned to the caller if an
   // error occurs when enqueuing the computation.
   std::set<ShapeIndex> unowned_indices_;
   std::unique_ptr<Shape> dynamic_shape_;
+  std::unique_ptr<Shape> host_shape_;
 };

 // ExecutionOutput encapsulates the output buffers of a execution and the
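Note the small space optimization in SetHostShape(): the host shape is heap-allocated only when it differs from the device shape, and host_shape() transparently falls back to shape() otherwise. A behavioral sketch under that reading (the shapes and the CHECK-based test are illustrative, not from the commit):

#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

void HostShapeFallback() {
  xla::Shape s = xla::ShapeUtil::MakeShape(xla::F32, {4});
  // Identical host and device shapes: SetHostShape() stores nothing and
  // host_shape() falls back to shape(), so no extra Shape is allocated.
  xla::ExecutionInput same_shapes(s, s);
  CHECK(same_shapes.host_shape() == same_shapes.shape());
}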


@@ -211,7 +211,8 @@ static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
           *buffer_tree.mutable_element(index) = execution_input_buffer;
         }
       });
-    execution_inputs.emplace_back(std::move(buffer_tree));
+    execution_inputs.emplace_back(std::move(buffer_tree),
+                                  input_buffer.on_host_shape());
   }
   return execution_inputs;
 }


@@ -119,7 +119,8 @@ class BufferDonationTest : public HloTestBase {
       }
     });

-    args.emplace_back(ExecutionInput(std::move(owned_buffers)));
+    args.emplace_back(
+        ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
   }

   TF_ASSERT_OK_AND_ASSIGN(


@@ -650,7 +650,7 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
 xla::StatusOr<xla::ExecutionInput> XRTTupleAllocation::ToExecutionInput(
     const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
         alias_checker) {
-  xla::ExecutionInput result(on_device_shape());
+  xla::ExecutionInput result(on_device_shape(), on_host_shape());
   for (const auto& index_buffer : buffers_) {
     if (index_buffer.second == nullptr ||
         (index_buffer.second->allocation().is_null() &&