[XLA] Store host shape in ExecutionInput
Simplify the APIs that explicitly pass the host shape.

PiperOrigin-RevId: 321083080
Change-Id: I9e124dd4465ee4037f2d0cdbd33f04a43f35abc2
parent 07daafc869
commit 1a9b57d729

tensorflow/compiler
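Taken together, the hunks below make ExecutionInput remember the host shape it was constructed with, and then strip the now-redundant host-shape plumbing from its callers (LocalExecutable::Run/RunAsync, PjRtExecutable::EnqueueExecution, and the HLO-runner, buffer-donation-test, and XRT call sites). File names are not preserved in this view. A minimal mock of the resulting API shape, using stand-in types rather than the real XLA classes:

    #include <utility>
    #include <vector>

    // Mock types for illustration only; the real classes live under
    // tensorflow/compiler/xla.
    struct Shape {};

    struct ExecutionInput {
      // The input now carries its host shape alongside the device shape.
      ExecutionInput(Shape device_shape, Shape host_shape)
          : device_shape(device_shape), host_shape(host_shape) {}
      Shape device_shape;
      Shape host_shape;
    };

    struct LocalExecutable {
      // Before: RunAsync(absl::Span<Shape const* const> argument_host_shapes,
      //                  std::vector<ExecutionInput> arguments, ...).
      // After: the host shapes are read off the arguments themselves.
      int RunAsync(std::vector<ExecutionInput> arguments) {
        return static_cast<int>(arguments.size());
      }
    };

    int main() {
      std::vector<ExecutionInput> arguments;
      arguments.emplace_back(Shape{}, Shape{});
      return LocalExecutable{}.RunAsync(std::move(arguments)) == 1 ? 0 : 1;
    }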
@@ -187,7 +187,7 @@ StatusOr<ExecutionOutput> LocalExecutable::Run(
   std::vector<const Shape*> argument_shapes;
   argument_shapes.reserve(arguments.size());
   for (const ExecutionInput& arg : arguments) {
-    argument_shapes.push_back(&arg.shape());
+    argument_shapes.push_back(&arg.host_shape());
   }
   return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
       argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
@@ -325,7 +325,7 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
   std::vector<const Shape*> argument_shapes;
   argument_shapes.reserve(arguments.size());
   for (const ExecutionInput& arg : arguments) {
-    argument_shapes.push_back(&arg.shape());
+    argument_shapes.push_back(&arg.host_shape());
   }
   return RunAsync(argument_shapes, std::move(arguments), run_options);
 }
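These two hunks are the read side of the change: Run() and RunAsync() now collect each argument's host shape from the input itself (arg.host_shape()) rather than from its on-device shape(), which presumably only matters when the host and device shapes actually differ.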
@@ -64,10 +64,6 @@ class LocalExecutable {

   // Similar to RunAsync(), but allows for donating argument buffers to the
   // executable.
-  StatusOr<ExecutionOutput> RunAsync(
-      absl::Span<Shape const* const> argument_host_shapes,
-      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
-
   StatusOr<ExecutionOutput> RunAsync(std::vector<ExecutionInput> arguments,
                                      ExecutableRunOptions run_options);

@@ -78,6 +74,10 @@ class LocalExecutable {
   Executable* executable() const { return executable_.get(); }

  private:
+  StatusOr<ExecutionOutput> RunAsync(
+      absl::Span<Shape const* const> argument_host_shapes,
+      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
+
   // Validates that the given arguments and options satisfy various constraints
   // of the computation.
   //
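Note that the overload taking argument_host_shapes does not disappear: the two header hunks move it from the public section to the private one, so external callers are left with the two-argument RunAsync, and the host shapes are supplied internally from the ExecutionInputs.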
@@ -1383,7 +1383,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
       local_device->compute_stream()->parent(), root_table_memory.cref()));
   }

-  ExecutionInput execution_input(on_device_shape);
+  ExecutionInput execution_input(on_device_shape, on_host_shape);
   ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
       execution_input.MutableBuffers()->begin();
   ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
@@ -1521,7 +1521,6 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
           << " mapped to device ordinal for execution: " << device_ordinal;

   absl::flat_hash_set<BufferSequencingEvent*> events;
-  std::vector<const Shape*> argument_host_shapes;
   std::vector<ExecutionInput> execution_inputs;
   device_buffers->reserve(argument_handles.size());
   const absl::flat_hash_set<int>& parameters_that_must_be_donated =
@@ -1570,24 +1569,22 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
   }

   LocalDeviceState* device_state = &client_->device_state(device_ordinal);
-  TupleHandle tuple_handle;
+  absl::optional<TupleHandle> tuple_handle;
   if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) {
     TF_ASSIGN_OR_RETURN(tuple_handle,
                         MakeTupleHelper(client_, device_state, argument_handles,
                                         *device_buffers, device_ordinal));
-    events.insert(tuple_handle.event.get());
-    execution_inputs.emplace_back(std::move(tuple_handle.execution_input));
-    argument_host_shapes.push_back(&tuple_handle.on_host_shape);
+    events.insert(tuple_handle->event.get());
+    execution_inputs.emplace_back(std::move(tuple_handle->execution_input));
   } else {
-    argument_host_shapes.reserve(argument_handles.size());
     execution_inputs.reserve(argument_handles.size());
     for (int i = 0; i < argument_handles.size(); ++i) {
       PjRtBuffer* handle = argument_handles[i];
-      argument_host_shapes.push_back(&handle->on_host_shape());

       const PjRtBuffer::ScopedHold& device_buffer = (*device_buffers)[i];
       // Make an ExecutionInput from the device buffer.
-      execution_inputs.emplace_back(handle->on_device_shape());
+      execution_inputs.emplace_back(handle->on_device_shape(),
+                                    handle->on_host_shape());
       ExecutionInput& execution_input = execution_inputs.back();
       ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
           execution_input.MutableBuffers()->begin();
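The change from a plain TupleHandle local to absl::optional<TupleHandle> is most likely forced by the constructor changes further down: TupleHandle contains an ExecutionInput, which loses its default constructor in this commit, so the local can no longer be default-constructed and filled in later. A standalone sketch of the pattern, with std::optional standing in for absl::optional:

    #include <optional>

    // A type without a default constructor, like the new ExecutionInput.
    struct Handle {
      explicit Handle(int id) : id(id) {}
      int id;
    };

    int main() {
      // Handle handle;              // would not compile: no default ctor
      std::optional<Handle> handle;  // empty until a value is produced
      bool need_tuple = true;
      if (need_tuple) handle.emplace(7);
      return handle.has_value() && handle->id == 7 ? 0 : 1;
    }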
@@ -1623,8 +1620,8 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
       device_state->compute_semaphore().ScopedAcquire(1));

   StatusOr<ExecutionOutput> result_buffer_or_status =
-      executables_[executable_idx]->RunAsync(
-          argument_host_shapes, std::move(execution_inputs), run_options);
+      executables_[executable_idx]->RunAsync(std::move(execution_inputs),
+                                             run_options);

   VLOG(1) << "Replica " << replica << " partition " << partition
           << " completed; ok=" << result_buffer_or_status.ok();
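With the earlier hunk having deleted the argument_host_shapes vector entirely, this call site now hands RunAsync only the ExecutionInputs and the run options; no host-shape bookkeeping remains in EnqueueExecution.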
@@ -93,7 +93,8 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(

 static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
     const ShapedBuffer& shaped_buffer) {
-  ExecutionInput result(shaped_buffer.on_device_shape());
+  ExecutionInput result(shaped_buffer.on_device_shape(),
+                        shaped_buffer.on_host_shape());
   shaped_buffer.buffers().ForEachElement(
       [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
         result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
@@ -105,10 +106,10 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
     const ServiceExecutableRunOptions* run_options,
     absl::Span<const ShapedBuffer* const> arguments,
     HloExecutionProfile* hlo_execution_profile) {
-  std::vector<ExecutionInput> args(arguments.size());
-  auto out_it = args.begin();
+  std::vector<ExecutionInput> args;
+  args.reserve(arguments.size());
   for (const ShapedBuffer* arg : arguments) {
-    *out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
+    args.emplace_back(MakeMaybeOwningDeviceMemoryTree(*arg));
   }
   TF_ASSIGN_OR_RETURN(ExecutionOutput out,
                       ExecuteAsyncOnStream(run_options, std::move(args),
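The second hunk's switch from a sized vector plus output iterator to reserve() and emplace_back() reads as fallout from the same constructor change: std::vector<T>(n) value-initializes n elements and therefore needs T to be default-constructible, which ExecutionInput no longer is. A minimal standalone illustration:

    #include <vector>

    struct NoDefault {
      explicit NoDefault(int v) : value(v) {}
      int value;
    };

    int main() {
      // std::vector<NoDefault> xs(3);  // would not compile: no default ctor
      std::vector<NoDefault> xs;
      xs.reserve(3);                    // one allocation, as in the diff
      for (int v : {1, 2, 3}) xs.emplace_back(v);
      return xs.size() == 3 ? 0 : 1;
    }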
@@ -60,10 +60,17 @@ namespace xla {
 // with their indices absent from unowned_indices_.
 class ExecutionInput {
  public:
-  ExecutionInput() = default;
-  explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {}
-  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
-      : buffers_(std::move(buffers)) {}
+  explicit ExecutionInput(xla::Shape shape, xla::Shape host_shape)
+      : buffers_(std::move(shape)) {
+    SetHostShape(std::move(host_shape));
+  }
+
+  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
+                          xla::Shape host_shape)
+      : buffers_(std::move(buffers)) {
+    SetHostShape(std::move(host_shape));
+  }
+
   ExecutionInput(ExecutionInput&&) = default;

   ~ExecutionInput();
@@ -74,6 +81,10 @@ class ExecutionInput {
     return dynamic_shape_ != nullptr ? *dynamic_shape_ : buffers_.shape();
   }

+  const Shape& host_shape() const {
+    return host_shape_ != nullptr ? *host_shape_ : shape();
+  }
+
   Status SetDynamicShape(Shape dynamic_shape);

   xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer(
@@ -107,11 +118,18 @@ class ExecutionInput {
   }

  private:
+  void SetHostShape(xla::Shape host_shape) {
+    if (shape() != host_shape) {
+      host_shape_ = absl::make_unique<Shape>(std::move(host_shape));
+    }
+  }
+
   ShapeTree<MaybeOwningDeviceMemory> buffers_;
   // Set of indices of buffers that should be returned to the caller if an error
   // occurs when enqueuing the computation.
   std::set<ShapeIndex> unowned_indices_;
   std::unique_ptr<Shape> dynamic_shape_;
+  std::unique_ptr<Shape> host_shape_;
 };

 // ExecutionOutput encapsulates the output buffers of a execution and the
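host_shape() and SetHostShape() together implement a small space optimization: a separate host shape is heap-allocated only when it differs from the buffers' on-device shape, and the accessor falls back to shape() otherwise. A self-contained model of just this mechanism, with a mock Shape in place of xla::Shape:

    #include <memory>
    #include <string>
    #include <utility>

    // Mock shape; the real xla::Shape compares layout and dimensions too.
    struct Shape {
      std::string repr;
      bool operator!=(const Shape& o) const { return repr != o.repr; }
    };

    class Input {
     public:
      Input(Shape shape, Shape host_shape) : shape_(std::move(shape)) {
        // Store the host shape only when it differs from the device shape.
        if (shape_ != host_shape) {
          host_shape_ = std::make_unique<Shape>(std::move(host_shape));
        }
      }
      const Shape& shape() const { return shape_; }
      const Shape& host_shape() const {
        // Fall back to the device shape when no distinct host shape was kept.
        return host_shape_ != nullptr ? *host_shape_ : shape_;
      }

     private:
      Shape shape_;
      std::unique_ptr<Shape> host_shape_;  // null when identical to shape_
    };

    int main() {
      Input same({"f32[4]"}, {"f32[4]"});
      Input diff({"f32[4]{0:T(256)}"}, {"f32[4]"});
      return same.host_shape().repr == "f32[4]" &&
                     diff.host_shape().repr == "f32[4]"
                 ? 0
                 : 1;
    }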
@@ -211,7 +211,8 @@ static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
           *buffer_tree.mutable_element(index) = execution_input_buffer;
         }
       });
-    execution_inputs.emplace_back(std::move(buffer_tree));
+    execution_inputs.emplace_back(std::move(buffer_tree),
+                                  input_buffer.on_host_shape());
   }
   return execution_inputs;
 }
@@ -119,7 +119,8 @@ class BufferDonationTest : public HloTestBase {
       }
     });

-    args.emplace_back(ExecutionInput(std::move(owned_buffers)));
+    args.emplace_back(
+        ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
   }

   TF_ASSERT_OK_AND_ASSIGN(
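Call sites adapt by passing whatever host shape they have at hand: this test uses the argument literal's shape, while the HLO-runner hunk above and the XRT hunk below pass on_host_shape() from the buffer they are wrapping.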
@@ -650,7 +650,7 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
 xla::StatusOr<xla::ExecutionInput> XRTTupleAllocation::ToExecutionInput(
     const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
         alias_checker) {
-  xla::ExecutionInput result(on_device_shape());
+  xla::ExecutionInput result(on_device_shape(), on_host_shape());
   for (const auto& index_buffer : buffers_) {
     if (index_buffer.second == nullptr ||
         (index_buffer.second->allocation().is_null() &&