[XLA] Consolidate Executable::ExecuteOnStream and ExecuteAsyncOnStream.
Remove ExecuteOnStream virtual method, make ExecuteOnStream a non-virtual wrapper around ExecuteAsyncOnStream. This means that backend authors have one method to implement (ExecuteAsyncOnStream) rather than two, and reduces the number of code paths to running an executable. Comment that ExecuteAsyncOnStream may in fact not be async. While undesirable, this is a quality of implementation issue not a bug. Future changes can make implementations of ExecuteAsyncOnStream truly async. PiperOrigin-RevId: 261922907
This commit is contained in:
parent
605ff0fecf
commit
1762bef938
@ -213,7 +213,8 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ScopedShapedBuffer outputs,
|
||||
executable_->ExecuteAsyncOnStream(&options_and_stream.first, arguments));
|
||||
executable_->ExecuteAsyncOnStream(&options_and_stream.first, arguments,
|
||||
/*hlo_execution_profile=*/nullptr));
|
||||
|
||||
// Transfer the outputs and save the snapshot to disk.
|
||||
if (snapshot) {
|
||||
|
@ -268,29 +268,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
|
||||
return std::move(result_buffer);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto result,
|
||||
ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile));
|
||||
TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments) {
|
||||
if (hlo_profiling_enabled()) {
|
||||
return Unimplemented(
|
||||
"Asynchronous execution on stream with hlo profiling is not yet "
|
||||
"supported on CPU.");
|
||||
}
|
||||
return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
|
@ -55,15 +55,11 @@ class CpuExecutable : public Executable {
|
||||
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
|
||||
~CpuExecutable() override {}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
|
||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) override;
|
||||
|
||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments) override;
|
||||
|
||||
// This should be called after set_ir_module_string.
|
||||
const string& ir_module_string() const { return ir_module_string_; }
|
||||
|
||||
@ -86,16 +82,6 @@ class CpuExecutable : public Executable {
|
||||
const BufferAssignment& buffer_assignment() const { return *assignment_; }
|
||||
|
||||
private:
|
||||
// This is for sharing the code between ExecuteOnStream and
|
||||
// ExecuteAsyncOnStream.
|
||||
//
|
||||
// Notice that it's tricky to use correctly, as the profile object (when it
|
||||
// exists) must out-live the task.
|
||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
|
||||
// Creates an array suitable for passing as the "buffer_table" argument to the
|
||||
// JIT compiled function pointer.
|
||||
//
|
||||
|
@ -29,6 +29,38 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
StatusOr<ScopedShapedBuffer> result =
|
||||
ExecuteAsyncOnStream(run_options, arguments, hlo_execution_profile);
|
||||
Status blocking_status = run_options->stream()->BlockHostUntilDone();
|
||||
TF_RETURN_IF_ERROR(result.status());
|
||||
TF_RETURN_IF_ERROR(blocking_status);
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
|
||||
run_options, std::move(arguments), hlo_execution_profile);
|
||||
Status blocking_status = run_options->stream()->BlockHostUntilDone();
|
||||
TF_RETURN_IF_ERROR(result.status());
|
||||
TF_RETURN_IF_ERROR(blocking_status);
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* /*run_options*/,
|
||||
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> /*arguments*/,
|
||||
HloExecutionProfile* /*hlo_execution_profile*/) {
|
||||
return Unimplemented(
|
||||
"MaybeOwningDeviceMemory version of overload is not implemented ");
|
||||
}
|
||||
|
||||
StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
|
||||
absl::Span<const ServiceExecutableRunOptions> run_options,
|
||||
absl::Span<const absl::Span<const ShapedBuffer* const>> arguments) {
|
||||
@ -49,8 +81,9 @@ StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
|
||||
// We cannot BlockHostUntilDone() on the already-launched executions in case
|
||||
// of error, since if the executions communicate, the initially launched
|
||||
// executions may never complete if not all executions are running.
|
||||
TF_ASSIGN_OR_RETURN(auto rv,
|
||||
ExecuteAsyncOnStream(&run_options[i], arguments[i]));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto rv, ExecuteAsyncOnStream(&run_options[i], arguments[i],
|
||||
/*hlo_execution_profile=*/nullptr));
|
||||
return_values.push_back(std::move(rv));
|
||||
}
|
||||
for (const auto& options : run_options) {
|
||||
|
@ -123,16 +123,10 @@ class Executable {
|
||||
// enabled.
|
||||
//
|
||||
// Returns a shaped buffer containing the result of the computation.
|
||||
virtual StatusOr<ScopedShapedBuffer> ExecuteOnStream(
|
||||
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) = 0;
|
||||
|
||||
// Same as ExecuteOnStream(), but this call is non-blocking and returns as
|
||||
// soon as all of the operations are enqueued for launch on the stream.
|
||||
virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments) = 0;
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
|
||||
// Starts the given program executing on the given stream/executor.
|
||||
//
|
||||
@ -143,20 +137,31 @@ class Executable {
|
||||
//
|
||||
// If an input is donated to XLA but is not reused as output, it is returned
|
||||
// as an leftover buffer for the caller to release.
|
||||
virtual StatusOr<ExecutionOutput> ExecuteOnStream(
|
||||
//
|
||||
// This call should be non-blocking and may return as soon as all of the
|
||||
// operations are enqueued for launch on the stream. Note that some
|
||||
// implementations may in fact block or may block in some circumstances (e.g.,
|
||||
// when profiling); i.e., asynchronous is a "may" not a "must".
|
||||
//
|
||||
// If the hlo_execution_profile is provided as non-nullptr, profiling will be
|
||||
// enabled. Note that profiling is tricky to use correctly, as the profiling
|
||||
// objects (when they exist) must out-live the task.
|
||||
virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) = 0;
|
||||
|
||||
// Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to
|
||||
// complete.
|
||||
StatusOr<ExecutionOutput> ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
return Unimplemented(
|
||||
"MaybeOwningDeviceMemory version of overload is not implemented ");
|
||||
}
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
|
||||
virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments) {
|
||||
return Unimplemented(
|
||||
"MaybeOwningDeviceMemory version of overload is not implemented ");
|
||||
}
|
||||
std::vector<ShapeTree<xla::MaybeOwningDeviceMemory>> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
|
||||
// Same as ExecuteOnStream(), but runs this executable on multiple
|
||||
// streams. arguments[i] contains the arguments to the execution on
|
||||
|
@ -405,25 +405,16 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::Execute(
|
||||
return std::move(shaped_buffer);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
|
||||
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
// TODO(b/134086343): ExecuteOnStream should not be async according to the
|
||||
// documentation, instead ExecuteAsyncOnStream should be used.
|
||||
return Execute(run_options, arguments, hlo_execution_profile,
|
||||
/*block_host_until_done=*/
|
||||
!run_options->allocator()->AllowsAsynchronousDeallocation());
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments) {
|
||||
se::DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||
// Force synchronous execution if the allocator requires it.
|
||||
bool block_host_until_done =
|
||||
!memory_allocator->AllowsAsynchronousDeallocation();
|
||||
return Execute(run_options, arguments, nullptr, block_host_until_done);
|
||||
return Execute(run_options, arguments, hlo_execution_profile,
|
||||
block_host_until_done);
|
||||
}
|
||||
|
||||
const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
|
||||
|
@ -80,17 +80,13 @@ class GpuExecutable : public Executable {
|
||||
// compilation is left up to the GPU driver.
|
||||
const std::vector<uint8>& binary() const { return binary_; }
|
||||
|
||||
// ExecuteOnStream will fail if the compute capability of the stream doesn't
|
||||
// match the compute capability passed to this object's constructor.
|
||||
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
|
||||
// ExecuteAsyncOnStream will fail if the compute capability of the stream
|
||||
// doesn't match the compute capability passed to this object's constructor.
|
||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) override;
|
||||
|
||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments) override;
|
||||
|
||||
std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
|
||||
return assignment_;
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ InterpreterExecutable::InterpreterExecutable(
|
||||
|
||||
InterpreterExecutable::~InterpreterExecutable() {}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
|
||||
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
@ -122,13 +122,6 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
"ExecuteAsyncOnStream is not yet supported on Interpreter.");
|
||||
}
|
||||
|
||||
/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
|
||||
if (shape.IsOpaque()) {
|
||||
return sizeof(void*);
|
||||
|
@ -46,16 +46,12 @@ class InterpreterExecutable : public Executable {
|
||||
std::unique_ptr<HloEvaluator> evaluator);
|
||||
~InterpreterExecutable() override;
|
||||
|
||||
StatusOr<ScopedShapedBuffer> ExecuteOnStream(
|
||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) override
|
||||
LOCKS_EXCLUDED(evaluator_lock_);
|
||||
|
||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments) override;
|
||||
|
||||
static int64 ShapeSizeBytes(const Shape& shape);
|
||||
|
||||
protected:
|
||||
|
@ -462,7 +462,8 @@ Service::ExecuteParallelAndRegisterResult(
|
||||
// Asynchronously launch the computation.
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
|
||||
executables[i]->ExecuteAsyncOnStream(
|
||||
&run_options, arguments[i][replica]));
|
||||
&run_options, arguments[i][replica],
|
||||
/*hlo_execution_profile=*/nullptr));
|
||||
|
||||
if (replica == 0 && profile != nullptr) {
|
||||
streams.back()->ThenStopTimer(timers.back().get());
|
||||
|
Loading…
Reference in New Issue
Block a user