Mark Execute methods non-const allow more flexibility, e.g., lazily load programs on first execution.

PiperOrigin-RevId: 347848574
Change-Id: Id22a727f1c0d491ecb3e2eadb72fe734d0670dbc
This commit is contained in:
A. Unique TensorFlower 2020-12-16 10:25:47 -08:00 committed by TensorFlower Gardener
parent 40bd5a4d99
commit e8f5ba82c9
3 changed files with 9 additions and 9 deletions

View File

@ -438,21 +438,21 @@ class PjRtExecutable {
// by the client.
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) = 0;
const ExecuteOptions& options) const = 0;
// Execute the assigned replica/partition on a given `device`. Requires
// executable has a device_assignment, `device` is present in the
// device_assignment and addressable by the client.
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) = 0;
const ExecuteOptions& options) const = 0;
// Execute on a given `device`. Requires `device` to be addressable by client.
// Requires executable has exactly 1 replica and 1 partition and no
// device_assignment (thus portable).
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) = 0;
const ExecuteOptions& options) const = 0;
// Asynchronously free resources after the last execution completes.
virtual void Delete() = 0;

View File

@ -1936,7 +1936,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
PjRtStreamExecutorExecutable::Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) {
const ExecuteOptions& options) const {
if (device_assignment_ == nullptr) {
return InvalidArgument("Execute expects a non-null device_assignment");
}
@ -2047,7 +2047,7 @@ PjRtStreamExecutorExecutable::Execute(
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) {
const ExecuteOptions& options) const {
if (device_assignment_ == nullptr) {
return InvalidArgument("ExecuteShard expects a non-null device_assignment");
}
@ -2070,7 +2070,7 @@ PjRtStreamExecutorExecutable::ExecuteSharded(
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) {
const ExecuteOptions& options) const {
if (device_assignment_ != nullptr) {
return InvalidArgument("ExecutePortable gets a non-portable executable");
}

View File

@ -686,15 +686,15 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) override;
const ExecuteOptions& options) const override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) override;
const ExecuteOptions& options) const override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) override;
const ExecuteOptions& options) const override;
void Delete() override { executables_.clear(); }