From 48cd0906362e6f7d2723622edf735531692aa59b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 29 Jan 2021 18:25:19 -0800 Subject: [PATCH] Add SPMD alternative implementations. PiperOrigin-RevId: 354649029 Change-Id: I60395c2f5aad44d5ebb88f45e102b95029deff37 --- tensorflow/compiler/xla/service/hlo_runner.h | 2 +- .../compiler/xla/service/hlo_runner_interface.h | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 8c0593218c1..b3ae46ae822 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -97,7 +97,7 @@ class HloRunner : public HloRunnerInterface { // Creates an executable object given an HLO module. If run_hlo_passes is // true, the HLO passes will be run as part of compilation. StatusOr> CreateExecutable( - std::unique_ptr module, bool run_hlo_passes); + std::unique_ptr module, bool run_hlo_passes) override; // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as diff --git a/tensorflow/compiler/xla/service/hlo_runner_interface.h b/tensorflow/compiler/xla/service/hlo_runner_interface.h index 886a99e0a0b..aeddd31192c 100644 --- a/tensorflow/compiler/xla/service/hlo_runner_interface.h +++ b/tensorflow/compiler/xla/service/hlo_runner_interface.h @@ -108,6 +108,11 @@ class HloRunnerInterface { static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); + // Creates an executable object given an HLO module. If run_hlo_passes is + // true, the HLO passes will be run as part of compilation. + virtual StatusOr> CreateExecutable( + std::unique_ptr module, bool run_hlo_passes) = 0; + // Executes the given module with given literals as input and returns the // result as a Literal. // @@ -158,6 +163,12 @@ class HloRunnerInterface { std::unique_ptr module, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) = 0; + + virtual StatusOr> ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options) = 0; }; } // namespace xla