[XLA] Support aliasing in HloRunner

Fixes the module execution with `run_hlo_module` in cases where that module has
aliasing.

PiperOrigin-RevId: 315799506
Change-Id: I8e8e956f0742e7a72b539147c9c7131ee964626a
This commit is contained in:
George Karpenkov 2020-06-10 17:13:29 -07:00 committed by TensorFlower Gardener
parent 686606eb26
commit 1936a8120d
5 changed files with 68 additions and 90 deletions

View File

@ -124,6 +124,17 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
return result; return result;
} }
StatusOr<ExecutionOutput> Executable::ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments) {
StatusOr<ExecutionOutput> result =
ExecuteAsyncOnStreamWrapper(run_options, std::move(arguments));
Status block_status = run_options->stream()->BlockHostUntilDone();
TF_RETURN_IF_ERROR(result.status());
TF_RETURN_IF_ERROR(block_status);
return result;
}
struct ExecuteAsyncOnStreamWrapperState { struct ExecuteAsyncOnStreamWrapperState {
ExecutionProfile* profile; ExecutionProfile* profile;
std::shared_ptr<se::Timer> timer; std::shared_ptr<se::Timer> timer;

View File

@ -276,6 +276,10 @@ class Executable {
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
absl::Span<const ShapedBuffer* const> arguments); absl::Span<const ShapedBuffer* const> arguments);
StatusOr<ExecutionOutput> ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments);
StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamWrapper( StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, const ServiceExecutableRunOptions* run_options,
absl::Span<const ShapedBuffer* const> arguments); absl::Span<const ShapedBuffer* const> arguments);

View File

@ -22,9 +22,11 @@ limitations under the License.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -144,13 +146,13 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
ExecutionProfile* profile) { ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers, TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
TransferLiteralsToDevice(arguments)); TransferLiteralsToDevice(arguments));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, TF_ASSIGN_OR_RETURN(ExecutionOutput result,
ExecuteWithDeviceBuffers( ExecuteWithDeviceBuffers(
/*module=*/std::move(module), /*module=*/std::move(module),
/*arguments=*/argument_buffers, /*arguments=*/argument_buffers,
/*run_hlo_passes=*/run_hlo_passes, /*run_hlo_passes=*/run_hlo_passes,
/*profile=*/profile)); /*profile=*/profile));
return TransferLiteralFromDevice(result); return TransferLiteralFromDevice(result.Result());
} }
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module, StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
@ -171,72 +173,60 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
} }
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable, StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
absl::Span<const Literal* const> arguments, absl::Span<const Literal> arguments,
ExecutionProfile* profile) { ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers, TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
TransferLiteralsToDevice(arguments)); TransferLiteralsToDevice(arguments));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, TF_ASSIGN_OR_RETURN(ExecutionOutput result,
ExecuteWithDeviceBuffers( ExecuteWithDeviceBuffers(
/*executable=*/executable.get(), /*executable=*/executable.get(),
/*arguments=*/argument_buffers, /*arguments=*/argument_buffers,
/*profile=*/profile)); /*profile=*/profile));
return TransferLiteralFromDevice(result); return TransferLiteralFromDevice(result.Result());
} }
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable, // Convert the owning buffer of inputs into a (partially) owning vector of
absl::Span<const Literal> arguments, // ExecutionInputs, and an owning vector of `OwningDeviceMemory`'s.
ExecutionProfile* profile) { static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
// Construct a vector of plain pointers for the arguments. absl::Span<ScopedShapedBuffer const> inputs,
std::vector<const Literal*> argument_pointers; HloInputOutputAliasConfig alias_config, int device_ordinal,
argument_pointers.reserve(arguments.size()); se::DeviceMemoryAllocator* allocator) {
for (const auto& argument : arguments) { std::vector<ExecutionInput> execution_inputs;
argument_pointers.push_back(&argument); std::vector<se::OwningDeviceMemory> owned_args;
for (int param_num = 0; param_num < inputs.size(); param_num++) {
const ScopedShapedBuffer& input_buffer = inputs[param_num];
ShapeTree<MaybeOwningDeviceMemory> buffer_tree(
input_buffer.on_device_shape());
input_buffer.buffers().ForEachElement(
[&](const ShapeIndex& index,
const se::DeviceMemoryBase& execution_input_buffer) {
if (alias_config.ParameterHasAlias(param_num, index)) {
// Store owned.
*buffer_tree.mutable_element(index) = se::OwningDeviceMemory{
execution_input_buffer, device_ordinal, allocator};
} else {
// Store unowned.
*buffer_tree.mutable_element(index) = execution_input_buffer;
}
});
execution_inputs.emplace_back(std::move(buffer_tree));
} }
return Execute( return execution_inputs;
/*module=*/std::move(executable),
/*arguments=*/argument_pointers,
/*profile=*/profile);
} }
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module, std::unique_ptr<HloModule> module,
absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes, absl::Span<ScopedShapedBuffer const> arguments, bool run_hlo_passes,
ExecutionProfile* profile) { ExecutionProfile* profile) {
// Get service run options.
se::Stream stream(backend().default_stream_executor());
stream.Init();
ServiceExecutableRunOptions service_run_options =
GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
nullptr, RunId());
service_run_options.mutable_run_options()->set_execution_profile(profile);
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable, TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
CreateExecutable(std::move(module), run_hlo_passes)); CreateExecutable(std::move(module), run_hlo_passes));
TF_ASSIGN_OR_RETURN( return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
ScopedShapedBuffer retval,
executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
return std::move(retval);
} }
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers( StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module, Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
ExecutionProfile* profile) {
std::vector<const ShapedBuffer*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
argument_pointers.push_back(&argument);
}
return ExecuteWithDeviceBuffers(
/*module=*/std::move(module),
/*arguments=*/argument_pointers,
/*run_hlo_passes=*/run_hlo_passes,
/*profile=*/profile);
}
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
ExecutionProfile* profile) { ExecutionProfile* profile) {
// Get service run options. // Get service run options.
se::Stream stream(backend().default_stream_executor()); se::Stream stream(backend().default_stream_executor());
@ -246,27 +236,19 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
nullptr, RunId()); nullptr, RunId());
service_run_options.mutable_run_options()->set_execution_profile(profile); service_run_options.mutable_run_options()->set_execution_profile(profile);
std::vector<ExecutionInput> execution_arguments =
ExecutionInputsFromScopedShapedBuffers(
arguments, executable->module().input_output_alias_config(),
stream.parent()->device_ordinal(), stream.parent()->GetAllocator());
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer retval, ExecutionOutput retval,
executable->ExecuteOnStreamWrapper(&service_run_options, arguments)); executable->ExecuteOnStreamWrapper(&service_run_options,
std::move(execution_arguments)));
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
return std::move(retval); return std::move(retval);
} }
StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
ExecutionProfile* profile) {
std::vector<const ShapedBuffer*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
argument_pointers.push_back(&argument);
}
return ExecuteWithDeviceBuffers(
/*executable=*/std::move(executable),
/*arguments=*/argument_pointers,
/*profile=*/profile);
}
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated( StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options, std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment) { DeviceAssignment* device_assignment) {

View File

@ -134,35 +134,19 @@ class HloRunner {
bool run_hlo_passes = true, bool run_hlo_passes = true,
ExecutionProfile* profile = nullptr); ExecutionProfile* profile = nullptr);
StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
absl::Span<const Literal* const> arguments,
ExecutionProfile* profile = nullptr);
StatusOr<Literal> Execute(std::unique_ptr<Executable> executable, StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
absl::Span<const Literal> arguments, absl::Span<const Literal> arguments,
ExecutionProfile* profile = nullptr); ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host // As Execute(), but accepts and returns device buffers instead of host
// buffers. // buffers.
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers( StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module, std::unique_ptr<HloModule> module,
absl::Span<const ShapedBuffer* const> arguments, absl::Span<ScopedShapedBuffer const> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers( StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
std::unique_ptr<HloModule> module, Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
absl::Span<const ScopedShapedBuffer> arguments,
bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
// In the following two calls, "executable" is not a unique_ptr to allow
// reuse of the Executable. This call may update the profile information in
// *executable.
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
ExecutionProfile* profile = nullptr);
StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
ExecutionProfile* profile = nullptr); ExecutionProfile* profile = nullptr);
// Creates an executable object given an HLO module. If run_hlo_passes is // Creates an executable object given an HLO module. If run_hlo_passes is

View File

@ -421,9 +421,6 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::move(module_or_status.ValueOrDie()); std::move(module_or_status.ValueOrDie());
fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie(); fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
absl::c_transform(
fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
[](const Literal& literal) { return const_cast<Literal*>(&literal); });
if (profiles != nullptr) { if (profiles != nullptr) {
// We have to enable HLO profiling since otherwise currently the // We have to enable HLO profiling since otherwise currently the
@ -457,7 +454,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
absl::optional<Literal> canonical_output; absl::optional<Literal> canonical_output;
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
StatusOr<Literal> output = StatusOr<Literal> output =
test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i], test_runner_.Execute(std::move(executables[i]), fake_arguments[i],
/*profile=*/&((*profiles)[i])); /*profile=*/&((*profiles)[i]));
if (!output.ok()) { if (!output.ok()) {
return ::testing::AssertionFailure() << output.status().error_message(); return ::testing::AssertionFailure() << output.status().error_message();