[XLA] Support aliasing in HloRunner
Fixes module execution with `run_hlo_module` for modules that use input/output aliasing.

PiperOrigin-RevId: 315799506
Change-Id: I8e8e956f0742e7a72b539147c9c7131ee964626a

parent 686606eb26
commit 1936a8120d
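For context, input/output aliasing lets an HLO module declare that an output buffer reuses a parameter's device buffer, which requires the runtime to take ownership of (donate) that input. Below is a minimal sketch of a module and runner call that exercise this path; the function name, module text, and literal values are illustrative only, and the alias-annotation syntax may differ slightly across XLA versions.

    #include <memory>
    #include <utility>
    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/service/hlo_parser.h"
    #include "tensorflow/compiler/xla/service/hlo_runner.h"

    xla::StatusOr<xla::Literal> RunAliasedModule(xla::HloRunner* runner) {
      // Output index {} aliases parameter 0: the result may be written in
      // place into p0's buffer, so the runtime must own that buffer.
      constexpr char kHlo[] = R"(
    HloModule add_in_place, input_output_alias={ {}: (0, {}, may-alias) }

    ENTRY main {
      p0 = f32[4] parameter(0)
      p1 = f32[4] parameter(1)
      ROOT add = f32[4] add(p0, p1)
    })";
      auto module_or = xla::ParseAndReturnUnverifiedModule(kHlo);
      if (!module_or.ok()) return module_or.status();
      xla::Literal a = xla::LiteralUtil::CreateR1<float>({1, 2, 3, 4});
      xla::Literal b = xla::LiteralUtil::CreateR1<float>({5, 6, 7, 8});
      // Before this change the runner passed unowned device buffers, so an
      // aliased parameter could not be donated and execution failed.
      return runner->Execute(module_or.ConsumeValueOrDie(), {&a, &b});
    }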
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -124,6 +124,17 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
   return result;
 }
 
+StatusOr<ExecutionOutput> Executable::ExecuteOnStreamWrapper(
+    const ServiceExecutableRunOptions* run_options,
+    std::vector<ExecutionInput> arguments) {
+  StatusOr<ExecutionOutput> result =
+      ExecuteAsyncOnStreamWrapper(run_options, std::move(arguments));
+  Status block_status = run_options->stream()->BlockHostUntilDone();
+  TF_RETURN_IF_ERROR(result.status());
+  TF_RETURN_IF_ERROR(block_status);
+  return result;
+}
+
 struct ExecuteAsyncOnStreamWrapperState {
   ExecutionProfile* profile;
   std::shared_ptr<se::Timer> timer;
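Note the error ordering in the new synchronous wrapper: `BlockHostUntilDone()` always runs, keeping host and stream in sync, but the async launch result is checked first, so an execution error is reported in preference to a later stream error.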
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -276,6 +276,10 @@ class Executable {
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments);
 
+  StatusOr<ExecutionOutput> ExecuteOnStreamWrapper(
+      const ServiceExecutableRunOptions* run_options,
+      std::vector<ExecutionInput> arguments);
+
   StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamWrapper(
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments);
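The new overload takes its `ExecutionInput` vector by value so the executable can consume individual buffers. Each leaf of an `ExecutionInput` is a `MaybeOwningDeviceMemory`: a slot holding either a borrowed `se::DeviceMemoryBase` or an owned `se::OwningDeviceMemory`. A rough stand-in for that idea, using toy types rather than XLA's real ones:

    #include <utility>
    #include <variant>

    // Toy stand-ins for se::DeviceMemoryBase (borrowed) and
    // se::OwningDeviceMemory (owned, freed through its allocator).
    struct BorrowedMem { void* ptr = nullptr; };
    struct OwnedMem { void* ptr = nullptr; };

    // A maybe-owning slot: the runtime may steal the OwnedMem alternative to
    // donate the buffer to an aliased output; BorrowedMem is only referenced.
    class MaybeOwning {
     public:
      explicit MaybeOwning(BorrowedMem m) : mem_(m) {}
      explicit MaybeOwning(OwnedMem m) : mem_(std::move(m)) {}

      bool HasOwnership() const {
        return std::holds_alternative<OwnedMem>(mem_);
      }

     private:
      std::variant<BorrowedMem, OwnedMem> mem_;
    };

XLA's real `MaybeOwningDeviceMemory` follows the same shape and exposes a similar ownership query, which is what lets the runtime tell donatable leaves apart from borrowed ones.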
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -22,9 +22,11 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/platform/logging.h"
@@ -144,13 +146,13 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                      ExecutionProfile* profile) {
   TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                       TransferLiteralsToDevice(arguments));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+  TF_ASSIGN_OR_RETURN(ExecutionOutput result,
                       ExecuteWithDeviceBuffers(
                           /*module=*/std::move(module),
                           /*arguments=*/argument_buffers,
                           /*run_hlo_passes=*/run_hlo_passes,
                           /*profile=*/profile));
-  return TransferLiteralFromDevice(result);
+  return TransferLiteralFromDevice(result.Result());
 }
 
 StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
@@ -171,72 +173,60 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
 }
 
 StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
-                                     absl::Span<const Literal* const> arguments,
+                                     absl::Span<const Literal> arguments,
                                      ExecutionProfile* profile) {
   TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                       TransferLiteralsToDevice(arguments));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+  TF_ASSIGN_OR_RETURN(ExecutionOutput result,
                       ExecuteWithDeviceBuffers(
                           /*executable=*/executable.get(),
                           /*arguments=*/argument_buffers,
                           /*profile=*/profile));
-  return TransferLiteralFromDevice(result);
+  return TransferLiteralFromDevice(result.Result());
 }
 
-StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
-                                     absl::Span<const Literal> arguments,
-                                     ExecutionProfile* profile) {
-  // Construct a vector of plain pointers for the arguments.
-  std::vector<const Literal*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
+// Convert the owning buffer of inputs into a (partially) owning vector of
+// ExecutionInputs, and an owning vector of `OwningDeviceMemory`'s.
+static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
+    absl::Span<ScopedShapedBuffer const> inputs,
+    HloInputOutputAliasConfig alias_config, int device_ordinal,
+    se::DeviceMemoryAllocator* allocator) {
+  std::vector<ExecutionInput> execution_inputs;
+  std::vector<se::OwningDeviceMemory> owned_args;
+
+  for (int param_num = 0; param_num < inputs.size(); param_num++) {
+    const ScopedShapedBuffer& input_buffer = inputs[param_num];
+    ShapeTree<MaybeOwningDeviceMemory> buffer_tree(
+        input_buffer.on_device_shape());
+
+    input_buffer.buffers().ForEachElement(
+        [&](const ShapeIndex& index,
+            const se::DeviceMemoryBase& execution_input_buffer) {
+          if (alias_config.ParameterHasAlias(param_num, index)) {
+            // Store owned.
+            *buffer_tree.mutable_element(index) = se::OwningDeviceMemory{
+                execution_input_buffer, device_ordinal, allocator};
+          } else {
+            // Store unowned.
+            *buffer_tree.mutable_element(index) = execution_input_buffer;
+          }
+        });
+    execution_inputs.emplace_back(std::move(buffer_tree));
   }
-  return Execute(
-      /*module=*/std::move(executable),
-      /*arguments=*/argument_pointers,
-      /*profile=*/profile);
+  return execution_inputs;
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
+StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
     std::unique_ptr<HloModule> module,
-    absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
+    absl::Span<ScopedShapedBuffer const> arguments, bool run_hlo_passes,
     ExecutionProfile* profile) {
-  // Get service run options.
-  se::Stream stream(backend().default_stream_executor());
-  stream.Init();
-  ServiceExecutableRunOptions service_run_options =
-      GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
-                                    nullptr, RunId());
-  service_run_options.mutable_run_options()->set_execution_profile(profile);
-
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       CreateExecutable(std::move(module), run_hlo_passes));
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
-  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
-  return std::move(retval);
+  return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    std::unique_ptr<HloModule> module,
-    absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
-    ExecutionProfile* profile) {
-  std::vector<const ShapedBuffer*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
-  }
-  return ExecuteWithDeviceBuffers(
-      /*module=*/std::move(module),
-      /*arguments=*/argument_pointers,
-      /*run_hlo_passes=*/run_hlo_passes,
-      /*profile=*/profile);
-}
-
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
+StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
+    Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
     ExecutionProfile* profile) {
   // Get service run options.
   se::Stream stream(backend().default_stream_executor());
@@ -246,27 +236,19 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
                                     nullptr, RunId());
   service_run_options.mutable_run_options()->set_execution_profile(profile);
 
+  std::vector<ExecutionInput> execution_arguments =
+      ExecutionInputsFromScopedShapedBuffers(
+          arguments, executable->module().input_output_alias_config(),
+          stream.parent()->device_ordinal(), stream.parent()->GetAllocator());
+
   TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
+      ExecutionOutput retval,
+      executable->ExecuteOnStreamWrapper(&service_run_options,
+                                         std::move(execution_arguments)));
   TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
   return std::move(retval);
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
-    ExecutionProfile* profile) {
-  std::vector<const ShapedBuffer*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
-  }
-  return ExecuteWithDeviceBuffers(
-      /*executable=*/std::move(executable),
-      /*arguments=*/argument_pointers,
-      /*profile=*/profile);
-}
-
 StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
     std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
     DeviceAssignment* device_assignment) {
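This is the heart of the fix: parameters that appear in the module's `HloInputOutputAliasConfig` are now wrapped as owned leaves, so the runtime may donate them to the output. A hedged sketch of driving the new device-buffer path end to end; the helper name is hypothetical, and it assumes XLA's `TF_ASSIGN_OR_RETURN` macro is in scope:

    xla::StatusOr<xla::Literal> RunOnDevice(
        xla::HloRunner& runner, xla::Executable* executable,
        absl::Span<const xla::Literal* const> host_args) {
      // Upload host literals; ownership stays in the ScopedShapedBuffers.
      TF_ASSIGN_OR_RETURN(std::vector<xla::ScopedShapedBuffer> device_args,
                          runner.TransferLiteralsToDevice(host_args));
      // Aliased parameter leaves become owned ExecutionInput entries
      // internally (ExecutionInputsFromScopedShapedBuffers above) and may
      // be donated to the result.
      TF_ASSIGN_OR_RETURN(xla::ExecutionOutput result,
                          runner.ExecuteWithDeviceBuffers(
                              executable, device_args, /*profile=*/nullptr));
      return runner.TransferLiteralFromDevice(result.Result());
    }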
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -134,35 +134,19 @@ class HloRunner {
                             bool run_hlo_passes = true,
                             ExecutionProfile* profile = nullptr);
 
   StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
-                            absl::Span<const Literal* const> arguments,
-                            ExecutionProfile* profile = nullptr);
-
-  StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
                             absl::Span<const Literal> arguments,
                             ExecutionProfile* profile = nullptr);
 
   // As Execute(), but accepts and returns device buffers instead of host
   // buffers.
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
+  StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
       std::unique_ptr<HloModule> module,
-      absl::Span<const ShapedBuffer* const> arguments,
+      absl::Span<ScopedShapedBuffer const> arguments,
       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
 
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      std::unique_ptr<HloModule> module,
-      absl::Span<const ScopedShapedBuffer> arguments,
-      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
-
   // In the following two calls, "executable" is not a unique_ptr to allow
   // reuse of the Executable. This call may update the profile information in
   // *executable.
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
-      ExecutionProfile* profile = nullptr);
-
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
+  StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
+      Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
       ExecutionProfile* profile = nullptr);
 
   // Creates an executable object given an HLO module. If run_hlo_passes is
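For callers, the visible change is that `Execute(std::unique_ptr<Executable>, ...)` now takes `absl::Span<const Literal>` directly rather than a span of pointers, and the device-buffer entry points consume `ScopedShapedBuffer`s and return an `ExecutionOutput`. A hypothetical call-site sketch mirroring the `hlo_test_base.cc` change below:

    // Before: an extra vector of Literal* had to be built first, e.g.
    //   std::vector<Literal*> ptrs;
    //   for (auto& l : args) ptrs.push_back(&l);
    //   runner.Execute(std::move(executable), ptrs);
    // After: pass the literals themselves.
    xla::StatusOr<xla::Literal> RunWithLiterals(
        xla::HloRunner& runner, std::unique_ptr<xla::Executable> executable,
        const std::vector<xla::Literal>& args) {
      return runner.Execute(std::move(executable), args, /*profile=*/nullptr);
    }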
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -421,9 +421,6 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
         std::move(module_or_status.ValueOrDie());
 
     fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
-    absl::c_transform(
-        fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
-        [](const Literal& literal) { return const_cast<Literal*>(&literal); });
 
     if (profiles != nullptr) {
       // We have to enable HLO profiling since otherwise currently the
@@ -457,7 +454,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
   absl::optional<Literal> canonical_output;
   for (int i = 0; i < n; ++i) {
     StatusOr<Literal> output =
-        test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i],
+        test_runner_.Execute(std::move(executables[i]), fake_arguments[i],
                              /*profile=*/&((*profiles)[i]));
     if (!output.ok()) {
       return ::testing::AssertionFailure() << output.status().error_message();