Add SPMD alternative implementations.

PiperOrigin-RevId: 354649029
Change-Id: I60395c2f5aad44d5ebb88f45e102b95029deff37
This commit is contained in:
A. Unique TensorFlower 2021-01-29 18:25:19 -08:00 committed by TensorFlower Gardener
parent adc7aeab09
commit 48cd090636
2 changed files with 12 additions and 1 deletions

View File

@ -97,7 +97,7 @@ class HloRunner : public HloRunnerInterface {
// Creates an executable object given an HLO module. If run_hlo_passes is
// true, the HLO passes will be run as part of compilation.
StatusOr<std::unique_ptr<Executable>> CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes);
std::unique_ptr<HloModule> module, bool run_hlo_passes) override;
// Executes a given HLO module into a set of replicas, and returns a map
// with the replica number as key, and the corresponding returned literal as

View File

@ -108,6 +108,11 @@ class HloRunnerInterface {
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
const std::string& filename, const DebugOptions& debug_options);
// Creates an executable object given an HLO module. If run_hlo_passes is
// true, the HLO passes will be run as part of compilation.
virtual StatusOr<std::unique_ptr<Executable>> CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes) = 0;
// Executes the given module with given literals as input and returns the
// result as a Literal.
//
@ -158,6 +163,12 @@ class HloRunnerInterface {
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment) = 0;
virtual StatusOr<std::vector<Literal>> ExecuteReplicated(
std::function<Executable*(int64)> executable_provider,
std::function<int64(int64)> argument_count_provider,
std::function<const Literal*(int64, int64)> argument_provider,
const ReplicatedExecuteOptions& options) = 0;
};
} // namespace xla