diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index b54c93ba214..9545dbdb031 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -438,21 +438,21 @@ class PjRtExecutable { // by the client. virtual StatusOr>>> Execute(absl::Span> argument_handles, - const ExecuteOptions& options) = 0; + const ExecuteOptions& options) const = 0; // Execute the assigned replica/partition on a given `device`. Requires // executable has a device_assignment, `device` is present in the // device_assignment and addressable by the client. virtual StatusOr>> ExecuteSharded( absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options) = 0; + const ExecuteOptions& options) const = 0; // Execute on a given `device`. Requires `device` to be addressable by client. // Requires executable has exactly 1 replica and 1 partition and no // device_assignment (thus portable). virtual StatusOr>> ExecutePortable( absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options) = 0; + const ExecuteOptions& options) const = 0; // Asynchronously free resources after the last execution completes. virtual void Delete() = 0; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index e31db159b04..5f3edf9b67e 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -1936,7 +1936,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper( StatusOr>>> PjRtStreamExecutorExecutable::Execute( absl::Span> argument_handles, - const ExecuteOptions& options) { + const ExecuteOptions& options) const { if (device_assignment_ == nullptr) { return InvalidArgument("Execute expects a non-null device_assignment"); } @@ -2047,7 +2047,7 @@ PjRtStreamExecutorExecutable::Execute( StatusOr>> PjRtStreamExecutorExecutable::ExecuteSharded( absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options) { + const ExecuteOptions& options) const { if (device_assignment_ == nullptr) { return InvalidArgument("ExecuteShard expects a non-null device_assignment"); } @@ -2070,7 +2070,7 @@ PjRtStreamExecutorExecutable::ExecuteSharded( StatusOr>> PjRtStreamExecutorExecutable::ExecutePortable( absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options) { + const ExecuteOptions& options) const { if (device_assignment_ != nullptr) { return InvalidArgument("ExecutePortable gets a non-portable executable"); } diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h index 2f55a71a564..f26fb12e8a8 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h @@ -686,15 +686,15 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable { StatusOr>>> Execute( absl::Span> argument_handles, - const ExecuteOptions& options) override; + const ExecuteOptions& options) const override; StatusOr>> ExecuteSharded( absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options) override; + const ExecuteOptions& options) const override; StatusOr>> ExecutePortable( absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options) override; + const ExecuteOptions& options) const override; void Delete() override { executables_.clear(); }