Enabling lazy variable tensor copy.

PiperOrigin-RevId: 356301869
Change-Id: I1b8e2b5192d08294d2ee685f1a101690e2929983
This commit is contained in:
A. Unique TensorFlower 2021-02-08 10:48:20 -08:00 committed by TensorFlower Gardener
parent e2b5e921b8
commit 7516f6ee75
2 changed files with 53 additions and 7 deletions

View File

@ -440,7 +440,7 @@ class TensorTieFactory {
std::unique_ptr<TensorObjectConverterBuilder> converter_builder_;
};
class InferenceRunnerImpl : public InferenceRunner {
class InferenceRunnerImpl : public CLInferenceRunner {
public:
InferenceRunnerImpl(Environment* environment,
std::unique_ptr<InferenceContext> context
@ -503,19 +503,36 @@ class InferenceRunnerImpl : public InferenceRunner {
return outputs_[index]->SetExternalObject(object);
}
// Copies the external object bound to input `index` into the runner by
// calling CopyFromExternalObject() on that input.
//
// Returns NotFoundError when `index` is outside [0, inputs_.size());
// otherwise returns whatever CopyFromExternalObject() reports.
absl::Status CopyFromExternalInput(int index) override {
  // Valid indices are [0, size). The earlier `index > size` test accepted
  // index == size, which then indexed one past the end of inputs_. Negative
  // values are rejected explicitly rather than relying on the implicit
  // signed-to-unsigned conversion in the comparison.
  if (index < 0 || index >= static_cast<int>(inputs_.size())) {
    return absl::NotFoundError(
        absl::StrCat("Input id ", index, " is an invalid input index."));
  }
  return inputs_[index]->CopyFromExternalObject();
}
// Copies output `index` out of the runner into its bound external object by
// calling CopyToExternalObject() on that output.
//
// Returns NotFoundError when `index` is outside [0, outputs_.size());
// otherwise returns whatever CopyToExternalObject() reports.
absl::Status CopyToExternalOutput(int index) override {
  // Valid indices are [0, size). The earlier `index > size` test accepted
  // index == size, which then indexed one past the end of outputs_. Negative
  // values are rejected explicitly rather than relying on the implicit
  // signed-to-unsigned conversion in the comparison.
  if (index < 0 || index >= static_cast<int>(outputs_.size())) {
    return absl::NotFoundError(
        absl::StrCat("Output id ", index, " is an invalid output index"));
  }
  return outputs_[index]->CopyToExternalObject();
}
absl::Status Run() override {
#ifdef CL_DELEGATE_ALLOW_GL
if (gl_interop_fabric_) {
RETURN_IF_ERROR(gl_interop_fabric_->Start());
}
#endif
for (auto& obj : inputs_) {
RETURN_IF_ERROR(obj->CopyFromExternalObject());
for (int i = 0; i < inputs_.size(); i++) {
RETURN_IF_ERROR(CopyFromExternalInput(i));
}
RETURN_IF_ERROR(context_->AddToQueue(queue_));
clFlush(queue_->queue());
for (auto& obj : outputs_) {
RETURN_IF_ERROR(obj->CopyToExternalObject());
RETURN_IF_ERROR(RunWithoutExternalBufferCopy());
for (int i = 0; i < outputs_.size(); i++) {
RETURN_IF_ERROR(CopyToExternalOutput(i));
}
#ifdef CL_DELEGATE_ALLOW_GL
if (gl_interop_fabric_) {
@ -525,6 +542,13 @@ class InferenceRunnerImpl : public InferenceRunner {
return absl::OkStatus();
}
// Enqueues the inference context on queue_ and flushes the OpenCL command
// queue. No external input or output objects are copied here.
absl::Status RunWithoutExternalBufferCopy() override {
  // Stop at the first enqueue failure and hand its status to the caller.
  auto enqueue_status = context_->AddToQueue(queue_);
  if (!enqueue_status.ok()) {
    return enqueue_status;
  }
  // The value returned by clFlush is not inspected.
  clFlush(queue_->queue());
  return absl::OkStatus();
}
private:
static absl::Status LinkTensors(
const std::vector<TensorTieDef>& defs, TensorTieFactory* factory,

View File

@ -140,6 +140,28 @@ absl::Status NewInferenceEnvironment(
std::unique_ptr<InferenceEnvironment>* environment,
InferenceEnvironmentProperties* properties /* optional */);
// Extension of ::tflite::gpu::InferenceRunner that declares separate entry
// points for running inference and for copying individual tensors between
// external objects and internal OpenCL buffers.
class CLInferenceRunner : public ::tflite::gpu::InferenceRunner {
public:
// Runs inference under a contract in which this call does NOT:
// a. copy inputs from the external (normally CPU) input buffer to the
// internal GPU buffer, or
// b. copy outputs from the internal GPU buffer to the external (normally
// CPU) buffer.
//
// The caller is responsible for copying the inputs before running the GPU
// kernels and for copying the outputs afterwards, using
// CopyFromExternalInput() and CopyToExternalOutput().
virtual absl::Status RunWithoutExternalBufferCopy() = 0;
// Copies input `index` from the external input tensor (normally a CPU
// buffer) to the internal OpenCL buffer.
virtual absl::Status CopyFromExternalInput(int index) = 0;
// Copies output `index` from the internal OpenCL output buffer to the
// external output tensor.
virtual absl::Status CopyToExternalOutput(int index) = 0;
};
} // namespace cl
} // namespace gpu
} // namespace tflite