Mark Execute methods non-const allow more flexibility, e.g., lazily load programs on first execution.
PiperOrigin-RevId: 347870420 Change-Id: Ie4c21280c14dd6201d24d42b9f1fd967a9617307
This commit is contained in:
parent
6106f70707
commit
2d0d477194
@ -438,21 +438,21 @@ class PjRtExecutable {
|
|||||||
// by the client.
|
// by the client.
|
||||||
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
||||||
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||||
const ExecuteOptions& options) const = 0;
|
const ExecuteOptions& options) = 0;
|
||||||
|
|
||||||
// Execute the assigned replica/partition on a given `device`. Requires
|
// Execute the assigned replica/partition on a given `device`. Requires
|
||||||
// executable has a device_assignment, `device` is present in the
|
// executable has a device_assignment, `device` is present in the
|
||||||
// device_assignment and addressable by the client.
|
// device_assignment and addressable by the client.
|
||||||
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
const ExecuteOptions& options) const = 0;
|
const ExecuteOptions& options) = 0;
|
||||||
|
|
||||||
// Execute on a given `device`. Requires `device` to be addressable by client.
|
// Execute on a given `device`. Requires `device` to be addressable by client.
|
||||||
// Requires executable has exactly 1 replica and 1 partition and no
|
// Requires executable has exactly 1 replica and 1 partition and no
|
||||||
// device_assignment (thus portable).
|
// device_assignment (thus portable).
|
||||||
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
const ExecuteOptions& options) const = 0;
|
const ExecuteOptions& options) = 0;
|
||||||
|
|
||||||
// Asynchronously free resources after the last execution completes.
|
// Asynchronously free resources after the last execution completes.
|
||||||
virtual void Delete() = 0;
|
virtual void Delete() = 0;
|
||||||
|
|||||||
@ -1936,7 +1936,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
|
|||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
||||||
PjRtStreamExecutorExecutable::Execute(
|
PjRtStreamExecutorExecutable::Execute(
|
||||||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||||
const ExecuteOptions& options) const {
|
const ExecuteOptions& options) {
|
||||||
if (device_assignment_ == nullptr) {
|
if (device_assignment_ == nullptr) {
|
||||||
return InvalidArgument("Execute expects a non-null device_assignment");
|
return InvalidArgument("Execute expects a non-null device_assignment");
|
||||||
}
|
}
|
||||||
@ -2047,7 +2047,7 @@ PjRtStreamExecutorExecutable::Execute(
|
|||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||||
PjRtStreamExecutorExecutable::ExecuteSharded(
|
PjRtStreamExecutorExecutable::ExecuteSharded(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
const ExecuteOptions& options) const {
|
const ExecuteOptions& options) {
|
||||||
if (device_assignment_ == nullptr) {
|
if (device_assignment_ == nullptr) {
|
||||||
return InvalidArgument("ExecuteShard expects a non-null device_assignment");
|
return InvalidArgument("ExecuteShard expects a non-null device_assignment");
|
||||||
}
|
}
|
||||||
@ -2070,7 +2070,7 @@ PjRtStreamExecutorExecutable::ExecuteSharded(
|
|||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||||
PjRtStreamExecutorExecutable::ExecutePortable(
|
PjRtStreamExecutorExecutable::ExecutePortable(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
const ExecuteOptions& options) const {
|
const ExecuteOptions& options) {
|
||||||
if (device_assignment_ != nullptr) {
|
if (device_assignment_ != nullptr) {
|
||||||
return InvalidArgument("ExecutePortable gets a non-portable executable");
|
return InvalidArgument("ExecutePortable gets a non-portable executable");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -686,15 +686,15 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
|||||||
|
|
||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
|
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
|
||||||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||||
const ExecuteOptions& options) const override;
|
const ExecuteOptions& options) override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
const ExecuteOptions& options) const override;
|
const ExecuteOptions& options) override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
const ExecuteOptions& options) const override;
|
const ExecuteOptions& options) override;
|
||||||
|
|
||||||
void Delete() override { executables_.clear(); }
|
void Delete() override { executables_.clear(); }
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user