Mark Execute methods non-const allow more flexibility, e.g., lazily load programs on first execution.
PiperOrigin-RevId: 347848574 Change-Id: Id22a727f1c0d491ecb3e2eadb72fe734d0670dbc
This commit is contained in:
parent
40bd5a4d99
commit
e8f5ba82c9
@ -438,21 +438,21 @@ class PjRtExecutable {
|
||||
// by the client.
|
||||
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
||||
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||
const ExecuteOptions& options) = 0;
|
||||
const ExecuteOptions& options) const = 0;
|
||||
|
||||
// Execute the assigned replica/partition on a given `device`. Requires
|
||||
// executable has a device_assignment, `device` is present in the
|
||||
// device_assignment and addressable by the client.
|
||||
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||
const ExecuteOptions& options) = 0;
|
||||
const ExecuteOptions& options) const = 0;
|
||||
|
||||
// Execute on a given `device`. Requires `device` to be addressable by client.
|
||||
// Requires executable has exactly 1 replica and 1 partition and no
|
||||
// device_assignment (thus portable).
|
||||
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||
const ExecuteOptions& options) = 0;
|
||||
const ExecuteOptions& options) const = 0;
|
||||
|
||||
// Asynchronously free resources after the last execution completes.
|
||||
virtual void Delete() = 0;
|
||||
|
@ -1936,7 +1936,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
|
||||
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
||||
PjRtStreamExecutorExecutable::Execute(
|
||||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||
const ExecuteOptions& options) {
|
||||
const ExecuteOptions& options) const {
|
||||
if (device_assignment_ == nullptr) {
|
||||
return InvalidArgument("Execute expects a non-null device_assignment");
|
||||
}
|
||||
@ -2047,7 +2047,7 @@ PjRtStreamExecutorExecutable::Execute(
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||
PjRtStreamExecutorExecutable::ExecuteSharded(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||
const ExecuteOptions& options) {
|
||||
const ExecuteOptions& options) const {
|
||||
if (device_assignment_ == nullptr) {
|
||||
return InvalidArgument("ExecuteShard expects a non-null device_assignment");
|
||||
}
|
||||
@ -2070,7 +2070,7 @@ PjRtStreamExecutorExecutable::ExecuteSharded(
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||
PjRtStreamExecutorExecutable::ExecutePortable(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||
const ExecuteOptions& options) {
|
||||
const ExecuteOptions& options) const {
|
||||
if (device_assignment_ != nullptr) {
|
||||
return InvalidArgument("ExecutePortable gets a non-portable executable");
|
||||
}
|
||||
|
@ -686,15 +686,15 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
||||
|
||||
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
|
||||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||
const ExecuteOptions& options) override;
|
||||
const ExecuteOptions& options) const override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||
const ExecuteOptions& options) override;
|
||||
const ExecuteOptions& options) const override;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||
const ExecuteOptions& options) override;
|
||||
const ExecuteOptions& options) const override;
|
||||
|
||||
void Delete() override { executables_.clear(); }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user