[XLA] [client] Implement a RunAsync overload which does not need a vector of shapes

PiperOrigin-RevId: 317406952
Change-Id: I69d8cc8a68ffdfbf70e2969f5df5e6adba7d2e1d
This commit is contained in:
George Karpenkov 2020-06-19 17:06:02 -07:00 committed by TensorFlower Gardener
parent f51b649394
commit 6116b7f911
2 changed files with 13 additions and 0 deletions

View File

@ -320,6 +320,16 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
return std::move(outputs); return std::move(outputs);
} }
StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ExecutionInput& arg : arguments) {
argument_shapes.push_back(&arg.shape());
}
return RunAsync(argument_shapes, std::move(arguments), run_options);
}
se::Platform* LocalClient::platform() const { se::Platform* LocalClient::platform() const {
return local_service_->backend().platform(); return local_service_->backend().platform();
} }

View File

@ -68,6 +68,9 @@ class LocalExecutable {
absl::Span<Shape const* const> argument_host_shapes, absl::Span<Shape const* const> argument_host_shapes,
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options); std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
StatusOr<ExecutionOutput> RunAsync(std::vector<ExecutionInput> arguments,
ExecutableRunOptions run_options);
// Return the options used to build the executable. // Return the options used to build the executable.
const ExecutableBuildOptions& build_options() const { return build_options_; } const ExecutableBuildOptions& build_options() const { return build_options_; }