Internal build rule change.
PiperOrigin-RevId: 345134366 Change-Id: Idfa2b6983b3a9aaeaaa2db4a8e62b73c2533bf0c
This commit is contained in:
parent
5e0f094382
commit
c0141706a4
@ -4481,21 +4481,16 @@ cc_library(
|
||||
srcs = ["hlo_runner_interface.cc"],
|
||||
hdrs = ["hlo_runner_interface.h"],
|
||||
deps = [
|
||||
":compiler",
|
||||
":computation_placer",
|
||||
":executable",
|
||||
":hlo",
|
||||
":hlo_module_group",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -104,8 +104,8 @@ StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
|
||||
}
|
||||
|
||||
StatusOr<Literal> HloRunner::ExecuteWithExecutable(
|
||||
std::unique_ptr<Executable> executable, absl::Span<const Literal> arguments,
|
||||
ExecutionProfile* profile) {
|
||||
std::unique_ptr<Executable> executable,
|
||||
absl::Span<const Literal* const> arguments, ExecutionProfile* profile) {
|
||||
TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> argument_buffers,
|
||||
TransferLiteralsToDevice(arguments));
|
||||
TF_ASSIGN_OR_RETURN(ExecutionOutput result,
|
||||
|
@ -76,9 +76,12 @@ class HloRunner : public HloRunnerInterface {
|
||||
bool run_hlo_passes,
|
||||
ExecutionProfile* profile) override;
|
||||
|
||||
using HloRunnerInterface::ExecuteWithExecutable;
|
||||
|
||||
StatusOr<Literal> ExecuteWithExecutable(
|
||||
std::unique_ptr<Executable> executable,
|
||||
absl::Span<const Literal> arguments, ExecutionProfile* profile = nullptr);
|
||||
absl::Span<const Literal* const> arguments,
|
||||
ExecutionProfile* profile) override;
|
||||
|
||||
// As Execute(), but accepts and returns device buffers instead of host
|
||||
// buffers.
|
||||
|
@ -87,4 +87,17 @@ StatusOr<Literal> HloRunnerInterface::Execute(
|
||||
/*profile=*/profile);
|
||||
}
|
||||
|
||||
StatusOr<Literal> HloRunnerInterface::ExecuteWithExecutable(
|
||||
std::unique_ptr<Executable> executable, absl::Span<const Literal> arguments,
|
||||
ExecutionProfile* profile) {
|
||||
// Construct a vector of plain pointers for the arguments.
|
||||
std::vector<const Literal*> argument_pointers;
|
||||
argument_pointers.reserve(arguments.size());
|
||||
for (const auto& argument : arguments) {
|
||||
argument_pointers.push_back(&argument);
|
||||
}
|
||||
return ExecuteWithExecutable(std::move(executable), argument_pointers,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -122,6 +123,22 @@ class HloRunnerInterface {
|
||||
bool run_hlo_passes,
|
||||
ExecutionProfile* profile) = 0;
|
||||
|
||||
// Same as above, but with Executable as input.
|
||||
StatusOr<Literal> ExecuteWithExecutable(
|
||||
std::unique_ptr<Executable> executable,
|
||||
absl::Span<const Literal> arguments, ExecutionProfile* profile = nullptr);
|
||||
|
||||
StatusOr<Literal> ExecuteWithExecutable(
|
||||
std::unique_ptr<Executable> executable,
|
||||
absl::Span<const Literal* const> arguments) {
|
||||
return ExecuteWithExecutable(std::move(executable), arguments, nullptr);
|
||||
}
|
||||
|
||||
virtual StatusOr<Literal> ExecuteWithExecutable(
|
||||
std::unique_ptr<Executable> executable,
|
||||
absl::Span<const Literal* const> arguments,
|
||||
ExecutionProfile* profile) = 0;
|
||||
|
||||
// Executes a given HLO module into a set of replicas, and returns a map
|
||||
// with the replica number as key, and the corresponding returned literal as
|
||||
// value.
|
||||
|
Loading…
x
Reference in New Issue
Block a user