Make Hlo runner interface so we can add alternative implementations.

PiperOrigin-RevId: 341756589
Change-Id: I267aaa795f38d9ac1f4b4ebd9bba5efc20d2b7ef
This commit is contained in:
A. Unique TensorFlower 2020-11-10 21:07:36 -08:00 committed by TensorFlower Gardener
parent 696102807d
commit bc819d9cf3
6 changed files with 276 additions and 146 deletions

View File

@@ -4459,6 +4459,30 @@ tf_cc_test(
],
)
# Backend-agnostic base library for running HLO modules. :hlo_runner and any
# alternative runner implementations depend on this target.
cc_library(
    name = "hlo_runner_interface",
    srcs = ["hlo_runner_interface.cc"],
    hdrs = ["hlo_runner_interface.h"],
    deps = [
        ":compiler",
        ":computation_placer",
        ":hlo",
        ":hlo_module_group",
        ":hlo_parser",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:span",
    ],
)
cc_library(
name = "hlo_runner",
srcs = ["hlo_runner.cc"],
@@ -4471,6 +4495,7 @@ cc_library(
":hlo",
":hlo_module_group",
":hlo_parser",
":hlo_runner_interface",
":transfer_manager",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",

View File

@@ -34,58 +34,6 @@ limitations under the License.
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::CreateModuleFromString(const absl::string_view hlo_string,
                                  const DebugOptions& debug_options) {
  // Parses HloModule::ToString-format text; deliberately skips the HLO
  // verifier so that hand-written or partially-valid modules can be loaded.
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, config);
}
namespace {

// Creates an HloModule from the given proto.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
    const HloProto& proto, const DebugOptions& debug_options) {
  // Derive the module config (carrying debug_options) from the proto before
  // deserializing the module itself.
  TF_ASSIGN_OR_RETURN(HloModuleConfig config,
                      HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
                                                             debug_options));
  TF_ASSIGN_OR_RETURN(auto module,
                      HloModule::CreateFromProto(proto.hlo_module(), config));
  return std::move(module);
}

}  // namespace
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename,
                                         const DebugOptions& debug_options) {
  // Reads a binary-serialized xla.HloProto file and converts it to a module.
  HloProto proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                                 filename, &proto));
  return HloProtoToModule(proto, debug_options);
}
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromTextProtoFile(const std::string& filename,
                                       const DebugOptions& debug_options) {
  // Reads a text-format xla.HloProto file and converts it to a module.
  HloProto proto;
  TF_RETURN_IF_ERROR(
      tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
  return HloProtoToModule(proto, debug_options);
}
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
                                     const DebugOptions& debug_options) {
  // Reads an HloModule::ToString-format dump from disk, then parses it
  // without running the HLO verifier.
  string hlo_string;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  filename, &hlo_string));
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, config);
}
HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) {
BackendOptions backend_options;
backend_options.set_platform(platform);
@@ -155,25 +103,8 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
return TransferLiteralFromDevice(result.Result());
}
// Adapter overload: accepts literals by value and forwards to the
// pointer-based Execute() overload.
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
                                     absl::Span<const Literal> arguments,
                                     bool run_hlo_passes,
                                     ExecutionProfile* profile) {
  // Construct a vector of plain pointers for the arguments.
  std::vector<const Literal*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return Execute(
      /*module=*/std::move(module),
      /*arguments=*/argument_pointers,
      /*run_hlo_passes=*/run_hlo_passes,
      /*profile=*/profile);
}
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
absl::Span<const Literal> arguments,
StatusOr<Literal> HloRunner::ExecuteWithExecutable(
std::unique_ptr<Executable> executable, absl::Span<const Literal> arguments,
ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
TransferLiteralsToDevice(arguments));

View File

@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -42,48 +43,8 @@ namespace xla {
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto
// file), or parsed from a hlo textual IR string.
class HloRunner {
class HloRunner : public HloRunnerInterface {
public:
// The options used to configure a ExecuteReplicated() call.
struct ReplicatedExecuteOptions {
// The number of devices the HLO module should be replicated onto.
int64 num_replicas = 1;
// The arguments to be fed to each replica. Since this is used for a
// replicated execution, all the arguments are the same for all replicas.
std::vector<const Literal*> arguments;
// If the HLO module being run has an infeed instruction, this will be the
// data which will be fed to it, for as many as infeed_steps steps.
const Literal* infeed = nullptr;
// The number of times the infeed literal should be fed to the HLO module.
// For a clean exit, this should match the iterations-per-loop parameter
// used when generating the HLO module proto (that is usually the main
// while boundary counter). A value higher than iterations-per-loop would
// lead to infeed threads feeding to a gone computation, while a lower
// value would trigger a stuck ExecuteReplicated() call (the computation
// will be trying to infeed data which will never come).
int64 infeed_steps = -1;
// The shape of the outfeed operation. If empty, the HLO module does not
// generate any outfeed.
Shape outfeed_shape;
// A pointer to a vector where the outfeed values will be stored. If
// nullptr, the values will be read and discarded.
std::vector<Literal>* outfeed_values = nullptr;
// Whether the HLO passes should be run on the input module. Usually
// saved modules are coming from after the HLO pass pipeline, so triggering
// another run will likely cause errors.
bool run_hlo_passes = false;
// If true, executes on multiple threads using se::Stream::ExecuteOnStream.
// Otherwise, executes using xla::Executable::ExecuteOnStreams.
bool use_threads = false;
};
// intra_op_parallelism_threads: For the CPU backend only. It is the thread
// pool size for parallel execution of an individual operator. The default
// value of -1 will result in initializing the thread pool with the number of
@@ -92,24 +53,7 @@ class HloRunner {
explicit HloRunner(se::Platform* platform,
int intra_op_parallelism_threads = -1);
~HloRunner();
// Converts an HloModule from the given hlo textual IR string (in
// HloModule::ToString format).
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
const absl::string_view hlo_string, const DebugOptions& debug_options);
// Reads the proto file in xla.HloProto format, creates and returns the
// HloModule.
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
const std::string& filename, const DebugOptions& debug_options);
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
const std::string& filename, const DebugOptions& debug_options);
// Reads the hlo text dump file in HloModule::ToString format, creates and
// returns the HloModule.
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
const std::string& filename, const DebugOptions& debug_options);
~HloRunner() override;
// Transfers data between the host and device.
StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
@@ -124,19 +68,17 @@ class HloRunner {
//
// If run_hlo_passes is false, the module will be executed without Hlo
// optimization.
using HloRunnerInterface::Execute;
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
absl::Span<const Literal* const> arguments,
bool run_hlo_passes = true,
ExecutionProfile* profile = nullptr);
bool run_hlo_passes,
ExecutionProfile* profile) override;
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
absl::Span<const Literal> arguments,
bool run_hlo_passes = true,
ExecutionProfile* profile = nullptr);
StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
absl::Span<const Literal> arguments,
ExecutionProfile* profile = nullptr);
StatusOr<Literal> ExecuteWithExecutable(
std::unique_ptr<Executable> executable,
absl::Span<const Literal> arguments, ExecutionProfile* profile = nullptr);
// As Execute(), but accepts and returns device buffers instead of host
// buffers.
@@ -159,13 +101,13 @@ class HloRunner {
// value.
StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options);
const ReplicatedExecuteOptions& options) override;
// Same as above, but with specified device assignment.
StatusOr<std::vector<Literal>> ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment);
DeviceAssignment* device_assignment) override;
// Same as above, but with a reusable Executable. This may update the profile
// information in *executable.

View File

@@ -0,0 +1,90 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string,
                                           const DebugOptions& debug_options) {
  // Parse the textual IR under a config that carries the caller's debug
  // options; the HLO verifier is intentionally not run here.
  HloModuleConfig module_config;
  module_config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, module_config);
}
namespace {

// Deserializes `proto` into an HloModule whose configuration carries
// `debug_options`.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
    const HloProto& proto, const DebugOptions& debug_options) {
  TF_ASSIGN_OR_RETURN(
      HloModuleConfig config,
      HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
                                             debug_options));
  // CreateFromProto already returns StatusOr<unique_ptr<HloModule>>, so its
  // result (success or error) can be propagated directly.
  return HloModule::CreateFromProto(proto.hlo_module(), config);
}

}  // namespace
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromBinaryProtoFile(
    const std::string& filename, const DebugOptions& debug_options) {
  // Deserialize the binary-format xla.HloProto from disk, then lower it to an
  // HloModule.
  HloProto hlo_proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                                 filename, &hlo_proto));
  return HloProtoToModule(hlo_proto, debug_options);
}
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromTextProtoFile(
    const std::string& filename, const DebugOptions& debug_options) {
  // Parse the text-format xla.HloProto from disk, then lower it to an
  // HloModule.
  HloProto hlo_proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                               filename, &hlo_proto));
  return HloProtoToModule(hlo_proto, debug_options);
}
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromHloTextFile(
    const std::string& filename, const DebugOptions& debug_options) {
  // Read the whole HloModule::ToString-format dump into memory, then parse it
  // under a config carrying the caller's debug options (verifier not run).
  //
  // Use std::string explicitly rather than the legacy unqualified `string`
  // alias, for consistency with the rest of this file.
  std::string hlo_string;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  filename, &hlo_string));
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, config);
}
StatusOr<Literal> HloRunnerInterface::Execute(
    std::unique_ptr<HloModule> module, absl::Span<const Literal> arguments,
    bool run_hlo_passes, ExecutionProfile* profile) {
  // The pure-virtual overload takes pointers to literals; adapt the literal
  // span into a vector of plain pointers and forward everything else
  // unchanged.
  std::vector<const Literal*> argument_ptrs;
  argument_ptrs.reserve(arguments.size());
  for (const Literal& literal : arguments) {
    argument_ptrs.push_back(&literal);
  }
  return Execute(std::move(module), argument_ptrs, run_hlo_passes, profile);
}
} // namespace xla

View File

@@ -0,0 +1,142 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto
// file), or parsed from a hlo textual IR string.
class HloRunnerInterface {
 public:
  // The options used to configure an ExecuteReplicated() call.
  struct ReplicatedExecuteOptions {
    // The number of devices the HLO module should be replicated onto.
    int64 num_replicas = 1;

    // The arguments to be fed to each replica. Since this is used for a
    // replicated execution, all the arguments are the same for all replicas.
    std::vector<const Literal*> arguments;

    // If the HLO module being run has an infeed instruction, this will be the
    // data which will be fed to it, for as many as infeed_steps steps.
    const Literal* infeed = nullptr;

    // The number of times the infeed literal should be fed to the HLO module.
    // For a clean exit, this should match the iterations-per-loop parameter
    // used when generating the HLO module proto (that is usually the main
    // while boundary counter). A value higher than iterations-per-loop would
    // lead to infeed threads feeding to a gone computation, while a lower
    // value would trigger a stuck ExecuteReplicated() call (the computation
    // will be trying to infeed data which will never come).
    int64 infeed_steps = -1;

    // The shape of the outfeed operation. If empty, the HLO module does not
    // generate any outfeed.
    Shape outfeed_shape;

    // A pointer to a vector where the outfeed values will be stored. If
    // nullptr, the values will be read and discarded.
    std::vector<Literal>* outfeed_values = nullptr;

    // Whether the HLO passes should be run on the input module. Usually
    // saved modules are coming from after the HLO pass pipeline, so triggering
    // another run will likely cause errors.
    bool run_hlo_passes = false;

    // If true, executes on multiple threads using se::Stream::ExecuteOnStream.
    // Otherwise, executes using xla::Executable::ExecuteOnStreams.
    bool use_threads = false;
  };

  HloRunnerInterface() = default;

  virtual ~HloRunnerInterface() = default;

  // Converts an HloModule from the given hlo textual IR string (in
  // HloModule::ToString format).
  static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
      const absl::string_view hlo_string, const DebugOptions& debug_options);

  // Reads the proto file in xla.HloProto format, creates and returns the
  // HloModule.
  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
      const std::string& filename, const DebugOptions& debug_options);
  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
      const std::string& filename, const DebugOptions& debug_options);

  // Reads the hlo text dump file in HloModule::ToString format, creates and
  // returns the HloModule.
  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
      const std::string& filename, const DebugOptions& debug_options);

  // Executes the given module with given literals as input and returns the
  // result as a Literal.
  //
  // If run_hlo_passes is false, the module will be executed without Hlo
  // optimization.
  //
  // Convenience overload: forwards to the pure-virtual overload below with a
  // null ExecutionProfile.
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<const Literal* const> arguments,
                            bool run_hlo_passes = true) {
    return Execute(std::move(module), arguments, run_hlo_passes, nullptr);
  }

  // As above, but takes literals by value; the shared implementation adapts
  // the span into literal pointers and forwards to the virtual overload.
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<const Literal> arguments,
                            bool run_hlo_passes = true,
                            ExecutionProfile* profile = nullptr);

  // The overload concrete runner implementations must provide.
  virtual StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                                    absl::Span<const Literal* const> arguments,
                                    bool run_hlo_passes,
                                    ExecutionProfile* profile) = 0;

  // Executes a given HLO module into a set of replicas, and returns a map
  // with the replica number as key, and the corresponding returned literal as
  // value.
  // TODO(b/172931928): change to non-virtual function.
  virtual StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module,
      const ReplicatedExecuteOptions& options) = 0;

  // Same as above, but with specified device assignment.
  virtual StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module,
      const ReplicatedExecuteOptions& options,
      DeviceAssignment* device_assignment) = 0;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_

View File

@@ -507,8 +507,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
absl::optional<Literal> canonical_output;
for (int i = 0; i < n; ++i) {
StatusOr<Literal> output =
test_runner_.Execute(std::move(executables[i]), fake_arguments[i],
StatusOr<Literal> output = test_runner_.ExecuteWithExecutable(
std::move(executables[i]), fake_arguments[i],
/*profile=*/&((*profiles)[i]));
if (!output.ok()) {
return ::testing::AssertionFailure() << output.status().error_message();