Make Hlo runner interface so we can add alternative implementations.
PiperOrigin-RevId: 341756589 Change-Id: I267aaa795f38d9ac1f4b4ebd9bba5efc20d2b7ef
This commit is contained in:
parent
696102807d
commit
bc819d9cf3
@ -4459,6 +4459,30 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_runner_interface",
|
||||
srcs = ["hlo_runner_interface.cc"],
|
||||
hdrs = ["hlo_runner_interface.h"],
|
||||
deps = [
|
||||
":compiler",
|
||||
":computation_placer",
|
||||
":hlo",
|
||||
":hlo_module_group",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_runner",
|
||||
srcs = ["hlo_runner.cc"],
|
||||
@ -4471,6 +4495,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_module_group",
|
||||
":hlo_parser",
|
||||
":hlo_runner_interface",
|
||||
":transfer_manager",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
|
@ -34,58 +34,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunner::CreateModuleFromString(const absl::string_view hlo_string,
|
||||
const DebugOptions& debug_options) {
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(debug_options);
|
||||
return ParseAndReturnUnverifiedModule(hlo_string, config);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Creates an HloModule from the given proto.
|
||||
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
|
||||
const HloProto& proto, const DebugOptions& debug_options) {
|
||||
TF_ASSIGN_OR_RETURN(HloModuleConfig config,
|
||||
HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
|
||||
debug_options));
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
HloModule::CreateFromProto(proto.hlo_module(), config));
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename,
|
||||
const DebugOptions& debug_options) {
|
||||
HloProto proto;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
|
||||
filename, &proto));
|
||||
return HloProtoToModule(proto, debug_options);
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunner::ReadModuleFromTextProtoFile(const std::string& filename,
|
||||
const DebugOptions& debug_options) {
|
||||
HloProto proto;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
|
||||
return HloProtoToModule(proto, debug_options);
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
|
||||
const DebugOptions& debug_options) {
|
||||
string hlo_string;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
|
||||
filename, &hlo_string));
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(debug_options);
|
||||
return ParseAndReturnUnverifiedModule(hlo_string, config);
|
||||
}
|
||||
|
||||
HloRunner::HloRunner(se::Platform* platform, int intra_op_parallelism_threads) {
|
||||
BackendOptions backend_options;
|
||||
backend_options.set_platform(platform);
|
||||
@ -155,25 +103,8 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
|
||||
return TransferLiteralFromDevice(result.Result());
|
||||
}
|
||||
|
||||
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
|
||||
absl::Span<const Literal> arguments,
|
||||
bool run_hlo_passes,
|
||||
ExecutionProfile* profile) {
|
||||
// Construct a vector of plain pointers for the arguments.
|
||||
std::vector<const Literal*> argument_pointers;
|
||||
argument_pointers.reserve(arguments.size());
|
||||
for (const auto& argument : arguments) {
|
||||
argument_pointers.push_back(&argument);
|
||||
}
|
||||
return Execute(
|
||||
/*module=*/std::move(module),
|
||||
/*arguments=*/argument_pointers,
|
||||
/*run_hlo_passes=*/run_hlo_passes,
|
||||
/*profile=*/profile);
|
||||
}
|
||||
|
||||
StatusOr<Literal> HloRunner::Execute(std::unique_ptr<Executable> executable,
|
||||
absl::Span<const Literal> arguments,
|
||||
StatusOr<Literal> HloRunner::ExecuteWithExecutable(
|
||||
std::unique_ptr<Executable> executable, absl::Span<const Literal> arguments,
|
||||
ExecutionProfile* profile) {
|
||||
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
|
||||
TransferLiteralsToDevice(arguments));
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -42,48 +43,8 @@ namespace xla {
|
||||
// certain backend directly without using the client interface. HloModule can be
|
||||
// explicitly built, or loaded from a serialization file (e.g., hlo proto
|
||||
// file), or parsed from a hlo textual IR string.
|
||||
class HloRunner {
|
||||
class HloRunner : public HloRunnerInterface {
|
||||
public:
|
||||
// The options used to configure a ExecuteReplicated() call.
|
||||
struct ReplicatedExecuteOptions {
|
||||
// The number of devices the HLO module should be replicated onto.
|
||||
int64 num_replicas = 1;
|
||||
|
||||
// The arguments to be fed to each replica. Since this is used for a
|
||||
// replicated execution, all the arguments are the same for all replicas.
|
||||
std::vector<const Literal*> arguments;
|
||||
|
||||
// If the HLO module being run has an infeed instruction, this will be the
|
||||
// data which will be fed to it, for as many as infeed_steps steps.
|
||||
const Literal* infeed = nullptr;
|
||||
|
||||
// The number of times the infeed literal should be fed to the HLO module.
|
||||
// For a clean exit, this should match the iterations-per-loop parameter
|
||||
// used when generating the HLO module proto (that is usually the main
|
||||
// while boundary counter). A value higher then iterations-per-loop would
|
||||
// lead to infeed threads feeding to a gone computation, while a lower
|
||||
// value would trigger a stuck ExecuteReplicated() call (the computation
|
||||
// will be trying to infeed data which will never come).
|
||||
int64 infeed_steps = -1;
|
||||
|
||||
// The shape of the outfeed operation. If empty, the HLO module does not
|
||||
// generate any outfeed.
|
||||
Shape outfeed_shape;
|
||||
|
||||
// A pointer to a vector where the outfeed values will be stored. If
|
||||
// nullptr, the values will be read and discarded.
|
||||
std::vector<Literal>* outfeed_values = nullptr;
|
||||
|
||||
// Whether the HLO passes should be run on the input module. Usually
|
||||
// saved modules are coming from after the HLO pass pipeline, so triggering
|
||||
// another run will likely cause errors.
|
||||
bool run_hlo_passes = false;
|
||||
|
||||
// If true, executes on multiple threads using se::Stream::ExecuteOnStream.
|
||||
// Otherwise, executes using xla::Executable::ExecuteOnStreams.
|
||||
bool use_threads = false;
|
||||
};
|
||||
|
||||
// intra_op_parallelism_threads: For the CPU backend only. It is the thread
|
||||
// pool size for parallel execution of an individual operator. The default
|
||||
// value of -1 will result in initializing the thread pool with the number of
|
||||
@ -92,24 +53,7 @@ class HloRunner {
|
||||
explicit HloRunner(se::Platform* platform,
|
||||
int intra_op_parallelism_threads = -1);
|
||||
|
||||
~HloRunner();
|
||||
|
||||
// Converts an HloModule from the given hlo textual IR string (in
|
||||
// HloModule::ToString format).
|
||||
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
|
||||
const absl::string_view hlo_string, const DebugOptions& debug_options);
|
||||
|
||||
// Reads the proto file in xla.HloProto format, creates and returns the
|
||||
// HloModule.
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
|
||||
// Reads the hlo text dump file in HloModule::ToString format, creates and
|
||||
// returns the HloModule.
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
~HloRunner() override;
|
||||
|
||||
// Transfers data between the host and device.
|
||||
StatusOr<ScopedShapedBuffer> TransferLiteralToDevice(const Literal& literal);
|
||||
@ -124,19 +68,17 @@ class HloRunner {
|
||||
//
|
||||
// If run_hlo_passes is false, the module will be executed without Hlo
|
||||
// optimization.
|
||||
|
||||
using HloRunnerInterface::Execute;
|
||||
|
||||
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
|
||||
absl::Span<const Literal* const> arguments,
|
||||
bool run_hlo_passes = true,
|
||||
ExecutionProfile* profile = nullptr);
|
||||
bool run_hlo_passes,
|
||||
ExecutionProfile* profile) override;
|
||||
|
||||
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
|
||||
absl::Span<const Literal> arguments,
|
||||
bool run_hlo_passes = true,
|
||||
ExecutionProfile* profile = nullptr);
|
||||
|
||||
StatusOr<Literal> Execute(std::unique_ptr<Executable> executable,
|
||||
absl::Span<const Literal> arguments,
|
||||
ExecutionProfile* profile = nullptr);
|
||||
StatusOr<Literal> ExecuteWithExecutable(
|
||||
std::unique_ptr<Executable> executable,
|
||||
absl::Span<const Literal> arguments, ExecutionProfile* profile = nullptr);
|
||||
|
||||
// As Execute(), but accepts and returns device buffers instead of host
|
||||
// buffers.
|
||||
@ -159,13 +101,13 @@ class HloRunner {
|
||||
// value.
|
||||
StatusOr<std::vector<Literal>> ExecuteReplicated(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const ReplicatedExecuteOptions& options);
|
||||
const ReplicatedExecuteOptions& options) override;
|
||||
|
||||
// Same as above, but with specified device assignment.
|
||||
StatusOr<std::vector<Literal>> ExecuteReplicated(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const ReplicatedExecuteOptions& options,
|
||||
DeviceAssignment* device_assignment);
|
||||
DeviceAssignment* device_assignment) override;
|
||||
|
||||
// Same as above, but with a reusable Executable. This may update the profile
|
||||
// information in *executable.
|
||||
|
90
tensorflow/compiler/xla/service/hlo_runner_interface.cc
Normal file
90
tensorflow/compiler/xla/service/hlo_runner_interface.cc
Normal file
@ -0,0 +1,90 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string,
|
||||
const DebugOptions& debug_options) {
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(debug_options);
|
||||
return ParseAndReturnUnverifiedModule(hlo_string, config);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Creates an HloModule from the given proto.
|
||||
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
|
||||
const HloProto& proto, const DebugOptions& debug_options) {
|
||||
TF_ASSIGN_OR_RETURN(HloModuleConfig config,
|
||||
HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
|
||||
debug_options));
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
HloModule::CreateFromProto(proto.hlo_module(), config));
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunnerInterface::ReadModuleFromBinaryProtoFile(
|
||||
const std::string& filename, const DebugOptions& debug_options) {
|
||||
HloProto proto;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
|
||||
filename, &proto));
|
||||
return HloProtoToModule(proto, debug_options);
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunnerInterface::ReadModuleFromTextProtoFile(
|
||||
const std::string& filename, const DebugOptions& debug_options) {
|
||||
HloProto proto;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
|
||||
return HloProtoToModule(proto, debug_options);
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<HloModule>>
|
||||
HloRunnerInterface::ReadModuleFromHloTextFile(
|
||||
const std::string& filename, const DebugOptions& debug_options) {
|
||||
string hlo_string;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
|
||||
filename, &hlo_string));
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(debug_options);
|
||||
return ParseAndReturnUnverifiedModule(hlo_string, config);
|
||||
}
|
||||
|
||||
StatusOr<Literal> HloRunnerInterface::Execute(
|
||||
std::unique_ptr<HloModule> module, absl::Span<const Literal> arguments,
|
||||
bool run_hlo_passes, ExecutionProfile* profile) {
|
||||
// Construct a vector of plain pointers for the arguments.
|
||||
std::vector<const Literal*> argument_pointers;
|
||||
argument_pointers.reserve(arguments.size());
|
||||
for (const auto& argument : arguments) {
|
||||
argument_pointers.push_back(&argument);
|
||||
}
|
||||
return Execute(
|
||||
/*module=*/std::move(module),
|
||||
/*arguments=*/argument_pointers,
|
||||
/*run_hlo_passes=*/run_hlo_passes,
|
||||
/*profile=*/profile);
|
||||
}
|
||||
|
||||
} // namespace xla
|
142
tensorflow/compiler/xla/service/hlo_runner_interface.h
Normal file
142
tensorflow/compiler/xla/service/hlo_runner_interface.h
Normal file
@ -0,0 +1,142 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A base class for running an HloModule. This executes the given HloModule on a
|
||||
// certain backend directly without using the client interface. HloModule can be
|
||||
// explicitly built, or loaded from a serialization file (e.g., hlo proto
|
||||
// file), or parsed from a hlo textual IR string.
|
||||
class HloRunnerInterface {
|
||||
public:
|
||||
// The options used to configure an ExecuteReplicated() call.
|
||||
struct ReplicatedExecuteOptions {
|
||||
// The number of devices the HLO module should be replicated onto.
|
||||
int64 num_replicas = 1;
|
||||
|
||||
// The arguments to be fed to each replica. Since this is used for a
|
||||
// replicated execution, all the arguments are the same for all replicas.
|
||||
std::vector<const Literal*> arguments;
|
||||
|
||||
// If the HLO module being run has an infeed instruction, this will be the
|
||||
// data which will be fed to it, for as many as infeed_steps steps.
|
||||
const Literal* infeed = nullptr;
|
||||
|
||||
// The number of times the infeed literal should be fed to the HLO module.
|
||||
// For a clean exit, this should match the iterations-per-loop parameter
|
||||
// used when generating the HLO module proto (that is usually the main
|
||||
// while boundary counter). A value higher then iterations-per-loop would
|
||||
// lead to infeed threads feeding to a gone computation, while a lower
|
||||
// value would trigger a stuck ExecuteReplicated() call (the computation
|
||||
// will be trying to infeed data which will never come).
|
||||
int64 infeed_steps = -1;
|
||||
|
||||
// The shape of the outfeed operation. If empty, the HLO module does not
|
||||
// generate any outfeed.
|
||||
Shape outfeed_shape;
|
||||
|
||||
// A pointer to a vector where the outfeed values will be stored. If
|
||||
// nullptr, the values will be read and discarded.
|
||||
std::vector<Literal>* outfeed_values = nullptr;
|
||||
|
||||
// Whether the HLO passes should be run on the input module. Usually
|
||||
// saved modules are coming from after the HLO pass pipeline, so triggering
|
||||
// another run will likely cause errors.
|
||||
bool run_hlo_passes = false;
|
||||
|
||||
// If true, executes on multiple threads using se::Stream::ExecuteOnStream.
|
||||
// Otherwise, executes using xla::Executable::ExecuteOnStreams.
|
||||
bool use_threads = false;
|
||||
};
|
||||
|
||||
HloRunnerInterface() = default;
|
||||
|
||||
virtual ~HloRunnerInterface() = default;
|
||||
|
||||
// Converts an HloModule from the given hlo textual IR string (in
|
||||
// HloModule::ToString format).
|
||||
static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
|
||||
const absl::string_view hlo_string, const DebugOptions& debug_options);
|
||||
|
||||
// Reads the proto file in xla.HloProto format, creates and returns the
|
||||
// HloModule.
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
|
||||
// Reads the hlo text dump file in HloModule::ToString format, creates and
|
||||
// returns the HloModule.
|
||||
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
|
||||
const std::string& filename, const DebugOptions& debug_options);
|
||||
|
||||
// Executes the given module with given literals as input and returns the
|
||||
// result as a Literal.
|
||||
//
|
||||
// If run_hlo_passes is false, the module will be executed without Hlo
|
||||
// optimization
|
||||
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
|
||||
absl::Span<const Literal* const> arguments,
|
||||
bool run_hlo_passes = true) {
|
||||
return Execute(std::move(module), arguments, run_hlo_passes, nullptr);
|
||||
}
|
||||
|
||||
StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
|
||||
absl::Span<const Literal> arguments,
|
||||
bool run_hlo_passes = true,
|
||||
ExecutionProfile* profile = nullptr);
|
||||
|
||||
virtual StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
|
||||
absl::Span<const Literal* const> arguments,
|
||||
bool run_hlo_passes,
|
||||
ExecutionProfile* profile) = 0;
|
||||
|
||||
// Executes a given HLO module into a set of replicas, and returns a map
|
||||
// with the replica number as key, and the corresponding returned literal as
|
||||
// value.
|
||||
// TODO(b/172931928): change to non-virtual function.
|
||||
virtual StatusOr<std::vector<Literal>> ExecuteReplicated(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const ReplicatedExecuteOptions& options) = 0;
|
||||
|
||||
// Same as above, but with specified device assignment.
|
||||
virtual StatusOr<std::vector<Literal>> ExecuteReplicated(
|
||||
std::unique_ptr<HloModule> module,
|
||||
const ReplicatedExecuteOptions& options,
|
||||
DeviceAssignment* device_assignment) = 0;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_INTERFACE_H_
|
@ -507,8 +507,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
|
||||
|
||||
absl::optional<Literal> canonical_output;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
StatusOr<Literal> output =
|
||||
test_runner_.Execute(std::move(executables[i]), fake_arguments[i],
|
||||
StatusOr<Literal> output = test_runner_.ExecuteWithExecutable(
|
||||
std::move(executables[i]), fake_arguments[i],
|
||||
/*profile=*/&((*profiles)[i]));
|
||||
if (!output.ok()) {
|
||||
return ::testing::AssertionFailure() << output.status().error_message();
|
||||
|
Loading…
x
Reference in New Issue
Block a user