diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index aa252067e19..5fc9909fa2a 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -320,6 +320,16 @@ StatusOr LocalExecutable::RunAsync( return std::move(outputs); } +StatusOr LocalExecutable::RunAsync( + std::vector arguments, ExecutableRunOptions run_options) { + std::vector argument_shapes; + argument_shapes.reserve(arguments.size()); + for (const ExecutionInput& arg : arguments) { + argument_shapes.push_back(&arg.shape()); + } + return RunAsync(argument_shapes, std::move(arguments), run_options); +} + se::Platform* LocalClient::platform() const { return local_service_->backend().platform(); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 3241ac73d54..8b91f4a1739 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -68,6 +68,9 @@ class LocalExecutable { absl::Span argument_host_shapes, std::vector arguments, ExecutableRunOptions run_options); + StatusOr RunAsync(std::vector arguments, + ExecutableRunOptions run_options); + // Return the options used to build the executable. const ExecutableBuildOptions& build_options() const { return build_options_; }