[XLA] Support aliasing in HloRunner
Fixes module execution with `run_hlo_module` in cases where the module has input/output aliasing.

PiperOrigin-RevId: 315799506
Change-Id: I8e8e956f0742e7a72b539147c9c7131ee964626a
This commit is contained in:
parent
686606eb26
commit
1936a8120d
|
@ -124,6 +124,17 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<ExecutionOutput> Executable::ExecuteOnStreamWrapper(
|
||||||
|
const ServiceExecutableRunOptions* run_options,
|
||||||
|
std::vector<ExecutionInput> arguments) {
|
||||||
|
StatusOr<ExecutionOutput> result =
|
||||||
|
ExecuteAsyncOnStreamWrapper(run_options, std::move(arguments));
|
||||||
|
Status block_status = run_options->stream()->BlockHostUntilDone();
|
||||||
|
TF_RETURN_IF_ERROR(result.status());
|
||||||
|
TF_RETURN_IF_ERROR(block_status);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
struct ExecuteAsyncOnStreamWrapperState {
|
struct ExecuteAsyncOnStreamWrapperState {
|
||||||
ExecutionProfile* profile;
|
ExecutionProfile* profile;
|
||||||
std::shared_ptr<se::Timer> timer;
|
std::shared_ptr<se::Timer> timer;
|
||||||
|
|
|
@ -276,6 +276,10 @@ class Executable {
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
absl::Span<const ShapedBuffer* const> arguments);
|
absl::Span<const ShapedBuffer* const> arguments);
|
||||||
|
|
||||||
|
StatusOr<ExecutionOutput> ExecuteOnStreamWrapper(
|
||||||
|
const ServiceExecutableRunOptions* run_options,
|
||||||
|
std::vector<ExecutionInput> arguments);
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamWrapper(
|
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamWrapper(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
absl::Span<const ShapedBuffer* const> arguments);
|
absl::Span<const ShapedBuffer* const> arguments);
|
||||||
|
|
|
@ -22,9 +22,11 @@ limitations under the License.
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/compiler/xla/layout_util.h"
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/executable.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
|
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
@ -144,13 +146,13 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
|
||||||
ExecutionProfile* profile) {
|
ExecutionProfile* profile) {
|
||||||
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
|
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
|
||||||
TransferLiteralsToDevice(arguments));
|
TransferLiteralsToDevice(arguments));
|
||||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
|
TF_ASSIGN_OR_RETURN(ExecutionOutput result,
|
||||||
ExecuteWithDeviceBuffers(
|
ExecuteWithDeviceBuffers(
|
||||||
/*module=*/std::move(module),
|
/*module=*/std::move(module),
|
||||||
/*arguments=*/argument_buffers,
|
/*arguments=*/argument_buffers,
|
||||||
/*run_hlo_passes=*/run_hlo_passes,
|
/*run_hlo_passes=*/run_hlo_passes,
|
||||||
/*profile=*/profile));
|
/*profile=*/profile));
|
||||||
return TransferLiteralFromDevice(result);
|
return TransferLiteralFromDevice(result.Result());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
|
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
|
||||||
|
@ -171,72 +173,60 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
|
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
|
||||||
absl::Span<const Literal* const> arguments,
|
absl::Span<const Literal> arguments,
|
||||||
ExecutionProfile* profile) {
|
ExecutionProfile* profile) {
|
||||||
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
|
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
|
||||||
TransferLiteralsToDevice(arguments));
|
TransferLiteralsToDevice(arguments));
|
||||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
|
TF_ASSIGN_OR_RETURN(ExecutionOutput result,
|
||||||
ExecuteWithDeviceBuffers(
|
ExecuteWithDeviceBuffers(
|
||||||
/*executable=*/executable.get(),
|
/*executable=*/executable.get(),
|
||||||
/*arguments=*/argument_buffers,
|
/*arguments=*/argument_buffers,
|
||||||
/*profile=*/profile));
|
/*profile=*/profile));
|
||||||
return TransferLiteralFromDevice(result);
|
return TransferLiteralFromDevice(result.Result());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
|
// Convert the owning buffer of inputs into a (partially) owning vector of
|
||||||
absl::Span<const Literal> arguments,
|
// ExecutionInputs, and an owning vector of `OwningDeviceMemory`'s.
|
||||||
ExecutionProfile* profile) {
|
static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
|
||||||
// Construct a vector of plain pointers for the arguments.
|
absl::Span<ScopedShapedBuffer const> inputs,
|
||||||
std::vector<const Literal*> argument_pointers;
|
HloInputOutputAliasConfig alias_config, int device_ordinal,
|
||||||
argument_pointers.reserve(arguments.size());
|
se::DeviceMemoryAllocator* allocator) {
|
||||||
for (const auto& argument : arguments) {
|
std::vector<ExecutionInput> execution_inputs;
|
||||||
argument_pointers.push_back(&argument);
|
std::vector<se::OwningDeviceMemory> owned_args;
|
||||||
|
|
||||||
|
for (int param_num = 0; param_num < inputs.size(); param_num++) {
|
||||||
|
const ScopedShapedBuffer& input_buffer = inputs[param_num];
|
||||||
|
ShapeTree<MaybeOwningDeviceMemory> buffer_tree(
|
||||||
|
input_buffer.on_device_shape());
|
||||||
|
|
||||||
|
input_buffer.buffers().ForEachElement(
|
||||||
|
[&](const ShapeIndex& index,
|
||||||
|
const se::DeviceMemoryBase& execution_input_buffer) {
|
||||||
|
if (alias_config.ParameterHasAlias(param_num, index)) {
|
||||||
|
// Store owned.
|
||||||
|
*buffer_tree.mutable_element(index) = se::OwningDeviceMemory{
|
||||||
|
execution_input_buffer, device_ordinal, allocator};
|
||||||
|
} else {
|
||||||
|
// Store unowned.
|
||||||
|
*buffer_tree.mutable_element(index) = execution_input_buffer;
|
||||||
}
|
}
|
||||||
return Execute(
|
});
|
||||||
/*module=*/std::move(executable),
|
execution_inputs.emplace_back(std::move(buffer_tree));
|
||||||
/*arguments=*/argument_pointers,
|
}
|
||||||
/*profile=*/profile);
|
return execution_inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
|
StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
|
||||||
std::unique_ptr<HloModule> module,
|
std::unique_ptr<HloModule> module,
|
||||||
absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
|
absl::Span<ScopedShapedBuffer const> arguments, bool run_hlo_passes,
|
||||||
ExecutionProfile* profile) {
|
ExecutionProfile* profile) {
|
||||||
// Get service run options.
|
|
||||||
se::Stream stream(backend().default_stream_executor());
|
|
||||||
stream.Init();
|
|
||||||
ServiceExecutableRunOptions service_run_options =
|
|
||||||
GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
|
|
||||||
nullptr, RunId());
|
|
||||||
service_run_options.mutable_run_options()->set_execution_profile(profile);
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||||
CreateExecutable(std::move(module), run_hlo_passes));
|
CreateExecutable(std::move(module), run_hlo_passes));
|
||||||
TF_ASSIGN_OR_RETURN(
|
return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
|
||||||
ScopedShapedBuffer retval,
|
|
||||||
executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
|
|
||||||
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
|
|
||||||
return std::move(retval);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
|
StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
|
||||||
std::unique_ptr<HloModule> module,
|
Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
|
||||||
absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
|
|
||||||
ExecutionProfile* profile) {
|
|
||||||
std::vector<const ShapedBuffer*> argument_pointers;
|
|
||||||
argument_pointers.reserve(arguments.size());
|
|
||||||
for (const auto& argument : arguments) {
|
|
||||||
argument_pointers.push_back(&argument);
|
|
||||||
}
|
|
||||||
return ExecuteWithDeviceBuffers(
|
|
||||||
/*module=*/std::move(module),
|
|
||||||
/*arguments=*/argument_pointers,
|
|
||||||
/*run_hlo_passes=*/run_hlo_passes,
|
|
||||||
/*profile=*/profile);
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
|
|
||||||
Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
|
|
||||||
ExecutionProfile* profile) {
|
ExecutionProfile* profile) {
|
||||||
// Get service run options.
|
// Get service run options.
|
||||||
se::Stream stream(backend().default_stream_executor());
|
se::Stream stream(backend().default_stream_executor());
|
||||||
|
@ -246,27 +236,19 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
|
||||||
nullptr, RunId());
|
nullptr, RunId());
|
||||||
service_run_options.mutable_run_options()->set_execution_profile(profile);
|
service_run_options.mutable_run_options()->set_execution_profile(profile);
|
||||||
|
|
||||||
|
std::vector<ExecutionInput> execution_arguments =
|
||||||
|
ExecutionInputsFromScopedShapedBuffers(
|
||||||
|
arguments, executable->module().input_output_alias_config(),
|
||||||
|
stream.parent()->device_ordinal(), stream.parent()->GetAllocator());
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
ScopedShapedBuffer retval,
|
ExecutionOutput retval,
|
||||||
executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
|
executable->ExecuteOnStreamWrapper(&service_run_options,
|
||||||
|
std::move(execution_arguments)));
|
||||||
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
|
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
|
||||||
return std::move(retval);
|
return std::move(retval);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
|
|
||||||
Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
|
|
||||||
ExecutionProfile* profile) {
|
|
||||||
std::vector<const ShapedBuffer*> argument_pointers;
|
|
||||||
argument_pointers.reserve(arguments.size());
|
|
||||||
for (const auto& argument : arguments) {
|
|
||||||
argument_pointers.push_back(&argument);
|
|
||||||
}
|
|
||||||
return ExecuteWithDeviceBuffers(
|
|
||||||
/*executable=*/std::move(executable),
|
|
||||||
/*arguments=*/argument_pointers,
|
|
||||||
/*profile=*/profile);
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
|
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
|
||||||
std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
|
std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
|
||||||
DeviceAssignment* device_assignment) {
|
DeviceAssignment* device_assignment) {
|
||||||
|
|
|
@ -134,35 +134,19 @@ class HloRunner {
|
||||||
bool run_hlo_passes = true,
|
bool run_hlo_passes = true,
|
||||||
ExecutionProfile* profile = nullptr);
|
ExecutionProfile* profile = nullptr);
|
||||||
|
|
||||||
StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
|
|
||||||
absl::Span<const Literal* const> arguments,
|
|
||||||
ExecutionProfile* profile = nullptr);
|
|
||||||
|
|
||||||
StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
|
StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
|
||||||
absl::Span<const Literal> arguments,
|
absl::Span<const Literal> arguments,
|
||||||
ExecutionProfile* profile = nullptr);
|
ExecutionProfile* profile = nullptr);
|
||||||
|
|
||||||
// As Execute(), but accepts and returns device buffers instead of host
|
// As Execute(), but accepts and returns device buffers instead of host
|
||||||
// buffers.
|
// buffers.
|
||||||
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
|
StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
|
||||||
std::unique_ptr<HloModule> module,
|
std::unique_ptr<HloModule> module,
|
||||||
absl::Span<const ShapedBuffer* const> arguments,
|
absl::Span<ScopedShapedBuffer const> arguments,
|
||||||
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
|
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
|
StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
|
||||||
std::unique_ptr<HloModule> module,
|
Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
|
||||||
absl::Span<const ScopedShapedBuffer> arguments,
|
|
||||||
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
|
|
||||||
|
|
||||||
// In the following two calls, "executable" is not a unique_ptr to allow
|
|
||||||
// reuse of the Executable. This call may update the profile information in
|
|
||||||
// *executable.
|
|
||||||
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
|
|
||||||
Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
|
|
||||||
ExecutionProfile* profile = nullptr);
|
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
|
|
||||||
Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
|
|
||||||
ExecutionProfile* profile = nullptr);
|
ExecutionProfile* profile = nullptr);
|
||||||
|
|
||||||
// Creates an executable object given an HLO module. If run_hlo_passes is
|
// Creates an executable object given an HLO module. If run_hlo_passes is
|
||||||
|
|
|
@ -421,9 +421,6 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
|
||||||
std::move(module_or_status.ValueOrDie());
|
std::move(module_or_status.ValueOrDie());
|
||||||
|
|
||||||
fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
|
fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
|
||||||
absl::c_transform(
|
|
||||||
fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
|
|
||||||
[](const Literal& literal) { return const_cast<Literal*>(&literal); });
|
|
||||||
|
|
||||||
if (profiles != nullptr) {
|
if (profiles != nullptr) {
|
||||||
// We have to enable HLO profiling since otherwise currently the
|
// We have to enable HLO profiling since otherwise currently the
|
||||||
|
@ -457,7 +454,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
|
||||||
absl::optional<Literal> canonical_output;
|
absl::optional<Literal> canonical_output;
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
StatusOr<Literal> output =
|
StatusOr<Literal> output =
|
||||||
test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i],
|
test_runner_.Execute(std::move(executables[i]), fake_arguments[i],
|
||||||
/*profile=*/&((*profiles)[i]));
|
/*profile=*/&((*profiles)[i]));
|
||||||
if (!output.ok()) {
|
if (!output.ok()) {
|
||||||
return ::testing::AssertionFailure() << output.status().error_message();
|
return ::testing::AssertionFailure() << output.status().error_message();
|
||||||
|
|
Loading…
Reference in New Issue