[XLA] Support aliasing in HloRunner

Fixes module execution in `run_hlo_module` for modules that use input/output
aliasing.
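
For context: when a module declares input/output aliasing, an output buffer is
donated by a parameter instead of being freshly allocated, so the runner must
hand the executable (partially) owning inputs rather than unowned
`ShapedBuffer`s. A minimal sketch of such a module, assuming the usual `xla`
namespace; the helper name and HLO text are illustrative, not part of this
commit:

    #include "tensorflow/compiler/xla/service/hlo_parser.h"
    #include "tensorflow/compiler/xla/service/hlo_runner.h"

    // Build a module whose (whole) output aliases (whole) parameter 0, then
    // run it through HloRunner. Before this change, run_hlo_module failed on
    // such modules.
    StatusOr<Literal> RunAliasedModule(HloRunner* runner, const Literal& arg) {
      constexpr char kHloText[] = R"(
    HloModule add_in_place
    ENTRY add_in_place {
      p0 = f32[4] parameter(0)
      ROOT add = f32[4] add(p0, p0)
    })";
      TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));
      // Donate parameter 0's device buffer to the root output.
      TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
          /*output_index=*/{}, /*param_number=*/0, /*param_index=*/{}));
      return runner->Execute(std::move(module), {&arg});
    }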

PiperOrigin-RevId: 315799506
Change-Id: I8e8e956f0742e7a72b539147c9c7131ee964626a
Author:    George Karpenkov (2020-06-10 17:13:29 -07:00)
Committer: TensorFlower Gardener
Parent:    686606eb26
Commit:    1936a8120d

5 changed files with 68 additions and 90 deletions

tensorflow/compiler/xla/service/executable.cc

@@ -124,6 +124,17 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
   return result;
 }
 
+StatusOr<ExecutionOutput> Executable::ExecuteOnStreamWrapper(
+    const ServiceExecutableRunOptions* run_options,
+    std::vector<ExecutionInput> arguments) {
+  StatusOr<ExecutionOutput> result =
+      ExecuteAsyncOnStreamWrapper(run_options, std::move(arguments));
+  Status block_status = run_options->stream()->BlockHostUntilDone();
+  TF_RETURN_IF_ERROR(result.status());
+  TF_RETURN_IF_ERROR(block_status);
+  return result;
+}
+
 struct ExecuteAsyncOnStreamWrapperState {
   ExecutionProfile* profile;
   std::shared_ptr<se::Timer> timer;
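
The new overload added above mirrors the existing ShapedBuffer-based wrapper:
run the async wrapper, block on the stream, and report the execution error in
preference to the BlockHostUntilDone() error. A hedged usage sketch (setup of
the executable, run options, and inputs is assumed; `ConsumeResult()` is the
ExecutionOutput accessor that releases the output buffers):

    // Run synchronously with owning ExecutionInputs and take ownership of the
    // result buffers, including any that were donated by aliased parameters.
    StatusOr<ScopedShapedBuffer> RunSync(
        Executable* executable, const ServiceExecutableRunOptions* run_options,
        std::vector<ExecutionInput> inputs) {
      TF_ASSIGN_OR_RETURN(
          ExecutionOutput out,
          executable->ExecuteOnStreamWrapper(run_options, std::move(inputs)));
      return out.ConsumeResult();
    }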

tensorflow/compiler/xla/service/executable.h

@@ -276,6 +276,10 @@ class Executable {
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments);
 
+  StatusOr<ExecutionOutput> ExecuteOnStreamWrapper(
+      const ServiceExecutableRunOptions* run_options,
+      std::vector<ExecutionInput> arguments);
+
   StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamWrapper(
       const ServiceExecutableRunOptions* run_options,
       absl::Span<const ShapedBuffer* const> arguments);

tensorflow/compiler/xla/service/hlo_runner.cc

@@ -22,9 +22,11 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/platform/logging.h"
@@ -144,13 +146,13 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                     ExecutionProfile* profile) {
   TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                       TransferLiteralsToDevice(arguments));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+  TF_ASSIGN_OR_RETURN(ExecutionOutput result,
                       ExecuteWithDeviceBuffers(
                           /*module=*/std::move(module),
                           /*arguments=*/argument_buffers,
                           /*run_hlo_passes=*/run_hlo_passes,
                           /*profile=*/profile));
-  return TransferLiteralFromDevice(result);
+  return TransferLiteralFromDevice(result.Result());
 }
 
 StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
@@ -171,72 +173,60 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
 }
 
 StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
-                                     absl::Span<const Literal* const> arguments,
+                                     absl::Span<const Literal> arguments,
                                      ExecutionProfile* profile) {
   TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
                       TransferLiteralsToDevice(arguments));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+  TF_ASSIGN_OR_RETURN(ExecutionOutput result,
                       ExecuteWithDeviceBuffers(
                           /*executable=*/executable.get(),
                           /*arguments=*/argument_buffers,
                           /*profile=*/profile));
-  return TransferLiteralFromDevice(result);
+  return TransferLiteralFromDevice(result.Result());
 }
 
-StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
-                                     absl::Span<const Literal> arguments,
-                                     ExecutionProfile* profile) {
-  // Construct a vector of plain pointers for the arguments.
-  std::vector<const Literal*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
+// Convert the owning buffer of inputs into a (partially) owning vector of
+// ExecutionInputs, and an owning vector of `OwningDeviceMemory`'s.
+static std::vector<ExecutionInput> ExecutionInputsFromScopedShapedBuffers(
+    absl::Span<ScopedShapedBuffer const> inputs,
+    HloInputOutputAliasConfig alias_config, int device_ordinal,
+    se::DeviceMemoryAllocator* allocator) {
+  std::vector<ExecutionInput> execution_inputs;
+  std::vector<se::OwningDeviceMemory> owned_args;
+  for (int param_num = 0; param_num < inputs.size(); param_num++) {
+    const ScopedShapedBuffer& input_buffer = inputs[param_num];
+    ShapeTree<MaybeOwningDeviceMemory> buffer_tree(
+        input_buffer.on_device_shape());
+    input_buffer.buffers().ForEachElement(
+        [&](const ShapeIndex& index,
+            const se::DeviceMemoryBase& execution_input_buffer) {
+          if (alias_config.ParameterHasAlias(param_num, index)) {
+            // Store owned.
+            *buffer_tree.mutable_element(index) = se::OwningDeviceMemory{
+                execution_input_buffer, device_ordinal, allocator};
+          } else {
+            // Store unowned.
+            *buffer_tree.mutable_element(index) = execution_input_buffer;
+          }
+        });
+    execution_inputs.emplace_back(std::move(buffer_tree));
   }
-  return Execute(
-      /*module=*/std::move(executable),
-      /*arguments=*/argument_pointers,
-      /*profile=*/profile);
+  return execution_inputs;
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
+StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
     std::unique_ptr<HloModule> module,
-    absl::Span<const ShapedBuffer* const> arguments, bool run_hlo_passes,
+    absl::Span<ScopedShapedBuffer const> arguments, bool run_hlo_passes,
     ExecutionProfile* profile) {
-  // Get service run options.
-  se::Stream stream(backend().default_stream_executor());
-  stream.Init();
-  ServiceExecutableRunOptions service_run_options =
-      GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream,
-                                    nullptr, RunId());
-  service_run_options.mutable_run_options()->set_execution_profile(profile);
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       CreateExecutable(std::move(module), run_hlo_passes));
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
-  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
-  return std::move(retval);
+  return ExecuteWithDeviceBuffers(executable.get(), arguments, profile);
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    std::unique_ptr<HloModule> module,
-    absl::Span<const ScopedShapedBuffer> arguments, bool run_hlo_passes,
-    ExecutionProfile* profile) {
-  std::vector<const ShapedBuffer*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
-  }
-  return ExecuteWithDeviceBuffers(
-      /*module=*/std::move(module),
-      /*arguments=*/argument_pointers,
-      /*run_hlo_passes=*/run_hlo_passes,
-      /*profile=*/profile);
-}
-
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
+StatusOr<ExecutionOutput> HloRunner::ExecuteWithDeviceBuffers(
+    Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
     ExecutionProfile* profile) {
   // Get service run options.
   se::Stream stream(backend().default_stream_executor());
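
The helper above is the heart of the change: for every leaf buffer of every
argument it consults the module's HloInputOutputAliasConfig and wraps the
buffer either as owned, so the executable may donate it to the output, or as a
plain unowned reference. The same decision in isolation, as a sketch (the
free-function name is hypothetical):

    // Wrap one device buffer for execution. Owned buffers may be freed or
    // donated by the executable; unowned buffers stay with the caller.
    MaybeOwningDeviceMemory WrapForExecution(
        const se::DeviceMemoryBase& buf, bool parameter_has_alias,
        int device_ordinal, se::DeviceMemoryAllocator* allocator) {
      if (parameter_has_alias) {
        return MaybeOwningDeviceMemory(
            se::OwningDeviceMemory(buf, device_ordinal, allocator));
      }
      return MaybeOwningDeviceMemory(buf);
    }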
@@ -246,27 +236,19 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
                                     nullptr, RunId());
   service_run_options.mutable_run_options()->set_execution_profile(profile);
 
+  std::vector<ExecutionInput> execution_arguments =
+      ExecutionInputsFromScopedShapedBuffers(
+          arguments, executable->module().input_output_alias_config(),
+          stream.parent()->device_ordinal(), stream.parent()->GetAllocator());
+
   TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options, arguments));
+      ExecutionOutput retval,
+      executable->ExecuteOnStreamWrapper(&service_run_options,
+                                         std::move(execution_arguments)));
   TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
   return std::move(retval);
 }
 
-StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
-    Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
-    ExecutionProfile* profile) {
-  std::vector<const ShapedBuffer*> argument_pointers;
-  argument_pointers.reserve(arguments.size());
-  for (const auto& argument : arguments) {
-    argument_pointers.push_back(&argument);
-  }
-  return ExecuteWithDeviceBuffers(
-      /*executable=*/std::move(executable),
-      /*arguments=*/argument_pointers,
-      /*profile=*/profile);
-}
-
 StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
     std::unique_ptr<HloModule> module, const ReplicatedExecuteOptions& options,
     DeviceAssignment* device_assignment) {
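
End to end, the runner's device-buffer path now is: transfer literals to the
device, wrap the resulting buffers as ExecutionInputs (owning where aliased),
execute, and read the result back via ExecutionOutput::Result(). A caller-side
sketch of that flow, mirroring what Execute() itself does above (the function
name is illustrative):

    StatusOr<Literal> RunOnDevice(HloRunner* runner,
                                  std::unique_ptr<HloModule> module,
                                  absl::Span<const Literal* const> args) {
      TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> buffers,
                          runner->TransferLiteralsToDevice(args));
      TF_ASSIGN_OR_RETURN(ExecutionOutput out,
                          runner->ExecuteWithDeviceBuffers(
                              std::move(module), buffers,
                              /*run_hlo_passes=*/true, /*profile=*/nullptr));
      // Aliased outputs reuse donated input buffers rather than fresh
      // allocations; Result() exposes the final output tree either way.
      return runner->TransferLiteralFromDevice(out.Result());
    }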

tensorflow/compiler/xla/service/hlo_runner.h

@@ -134,35 +134,19 @@ class HloRunner {
                             bool run_hlo_passes = true,
                             ExecutionProfile* profile = nullptr);
 
-  StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
-                            absl::Span<const Literal* const> arguments,
-                            ExecutionProfile* profile = nullptr);
-
   StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
                             absl::Span<const Literal> arguments,
                             ExecutionProfile* profile = nullptr);
 
   // As Execute(), but accepts and returns device buffers instead of host
   // buffers.
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
+  StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
       std::unique_ptr<HloModule> module,
-      absl::Span<const ShapedBuffer* const> arguments,
+      absl::Span<ScopedShapedBuffer const> arguments,
       bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
 
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      std::unique_ptr<HloModule> module,
-      absl::Span<const ScopedShapedBuffer> arguments,
-      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
-
   // In the following two calls, "executable" is not a unique_ptr to allow
   // reuse of the Executable. This call may update the profile information in
   // *executable.
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      Executable* executable, absl::Span<const ShapedBuffer* const> arguments,
-      ExecutionProfile* profile = nullptr);
-
-  StatusOr<ScopedShapedBuffer> ExecuteWithDeviceBuffers(
-      Executable* executable, absl::Span<const ScopedShapedBuffer> arguments,
+  StatusOr<ExecutionOutput> ExecuteWithDeviceBuffers(
+      Executable* executable, absl::Span<ScopedShapedBuffer const> arguments,
       ExecutionProfile* profile = nullptr);
 
   // Creates an executable object given an HLO module. If run_hlo_passes is
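
Because "executable" is a raw pointer in the last overload, one compiled
Executable can serve several runs; only the per-run inputs and outputs are
recreated. A sketch of that reuse (it assumes the module does not donate its
input buffers, since donation would consume `args` on the first run):

    Status RunTwice(HloRunner* runner, Executable* executable,
                    absl::Span<const ScopedShapedBuffer> args) {
      for (int i = 0; i < 2; ++i) {
        // Each call yields an independent ExecutionOutput; the Executable
        // (and any profile counters it holds) persists across calls.
        TF_ASSIGN_OR_RETURN(ExecutionOutput out,
                            runner->ExecuteWithDeviceBuffers(
                                executable, args, /*profile=*/nullptr));
        (void)out;
      }
      return Status::OK();
    }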

tensorflow/compiler/xla/tests/hlo_test_base.cc

@@ -421,9 +421,6 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
         std::move(module_or_status.ValueOrDie());
     fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie();
-    absl::c_transform(
-        fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]),
-        [](const Literal& literal) { return const_cast<Literal*>(&literal); });
 
     if (profiles != nullptr) {
       // We have to enable HLO profiling since otherwise currently the
@@ -457,7 +454,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
   absl::optional<Literal> canonical_output;
   for (int i = 0; i < n; ++i) {
     StatusOr<Literal> output =
-        test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i],
+        test_runner_.Execute(std::move(executables[i]), fake_arguments[i],
                              /*profile=*/&((*profiles)[i]));
     if (!output.ok()) {
       return ::testing::AssertionFailure() << output.status().error_message();