Mark Execute methods non-const allow more flexibility, e.g., lazily load programs on first execution.

PiperOrigin-RevId: 347817207
Change-Id: I81480d76c497c7ee24a1d6baefaef61561b0e40e
This commit is contained in:
Qiao Zhang 2020-12-16 07:21:05 -08:00 committed by TensorFlower Gardener
parent decac0e638
commit 26f4c529e7
3 changed files with 9 additions and 9 deletions

View File

@ -438,21 +438,21 @@ class PjRtExecutable {
// by the client.
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const = 0;
const ExecuteOptions& options) = 0;
// Execute the assigned replica/partition on a given `device`. Requires
// executable has a device_assignment, `device` is present in the
// device_assignment and addressable by the client.
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const = 0;
const ExecuteOptions& options) = 0;
// Execute on a given `device`. Requires `device` to be addressable by client.
// Requires executable has exactly 1 replica and 1 partition and no
// device_assignment (thus portable).
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const = 0;
const ExecuteOptions& options) = 0;
// Asynchronously free resources after the last execution completes.
virtual void Delete() = 0;

View File

@ -1936,7 +1936,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
PjRtStreamExecutorExecutable::Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const {
const ExecuteOptions& options) {
if (device_assignment_ == nullptr) {
return InvalidArgument("Execute expects a non-null device_assignment");
}
@ -2047,7 +2047,7 @@ PjRtStreamExecutorExecutable::Execute(
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const {
const ExecuteOptions& options) {
if (device_assignment_ == nullptr) {
return InvalidArgument("ExecuteShard expects a non-null device_assignment");
}
@ -2070,7 +2070,7 @@ PjRtStreamExecutorExecutable::ExecuteSharded(
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
PjRtStreamExecutorExecutable::ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const {
const ExecuteOptions& options) {
if (device_assignment_ != nullptr) {
return InvalidArgument("ExecutePortable gets a non-portable executable");
}

View File

@ -686,15 +686,15 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const override;
const ExecuteOptions& options) override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const override;
const ExecuteOptions& options) override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const override;
const ExecuteOptions& options) override;
void Delete() override { executables_.clear(); }