Enabling lazy variable tensor copy.
PiperOrigin-RevId: 356301869 Change-Id: I1b8e2b5192d08294d2ee685f1a101690e2929983
This commit is contained in:
parent
e2b5e921b8
commit
7516f6ee75
@ -440,7 +440,7 @@ class TensorTieFactory {
|
||||
std::unique_ptr<TensorObjectConverterBuilder> converter_builder_;
|
||||
};
|
||||
|
||||
class InferenceRunnerImpl : public InferenceRunner {
|
||||
class InferenceRunnerImpl : public CLInferenceRunner {
|
||||
public:
|
||||
InferenceRunnerImpl(Environment* environment,
|
||||
std::unique_ptr<InferenceContext> context
|
||||
@ -503,19 +503,36 @@ class InferenceRunnerImpl : public InferenceRunner {
|
||||
return outputs_[index]->SetExternalObject(object);
|
||||
}
|
||||
|
||||
absl::Status CopyFromExternalInput(int index) override {
|
||||
if (index > inputs_.size()) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("Input id ", index, " is an invalid input index."));
|
||||
}
|
||||
return inputs_[index]->CopyFromExternalObject();
|
||||
}
|
||||
|
||||
absl::Status CopyToExternalOutput(int index) override {
|
||||
if (index > outputs_.size()) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("Output id ", index, " is an invalid output index"));
|
||||
}
|
||||
return outputs_[index]->CopyToExternalObject();
|
||||
}
|
||||
|
||||
absl::Status Run() override {
|
||||
#ifdef CL_DELEGATE_ALLOW_GL
|
||||
if (gl_interop_fabric_) {
|
||||
RETURN_IF_ERROR(gl_interop_fabric_->Start());
|
||||
}
|
||||
#endif
|
||||
for (auto& obj : inputs_) {
|
||||
RETURN_IF_ERROR(obj->CopyFromExternalObject());
|
||||
for (int i = 0; i < inputs_.size(); i++) {
|
||||
RETURN_IF_ERROR(CopyFromExternalInput(i));
|
||||
}
|
||||
RETURN_IF_ERROR(context_->AddToQueue(queue_));
|
||||
clFlush(queue_->queue());
|
||||
for (auto& obj : outputs_) {
|
||||
RETURN_IF_ERROR(obj->CopyToExternalObject());
|
||||
|
||||
RETURN_IF_ERROR(RunWithoutExternalBufferCopy());
|
||||
|
||||
for (int i = 0; i < outputs_.size(); i++) {
|
||||
RETURN_IF_ERROR(CopyToExternalOutput(i));
|
||||
}
|
||||
#ifdef CL_DELEGATE_ALLOW_GL
|
||||
if (gl_interop_fabric_) {
|
||||
@ -525,6 +542,13 @@ class InferenceRunnerImpl : public InferenceRunner {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status RunWithoutExternalBufferCopy() override {
|
||||
RETURN_IF_ERROR(context_->AddToQueue(queue_));
|
||||
clFlush(queue_->queue());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
static absl::Status LinkTensors(
|
||||
const std::vector<TensorTieDef>& defs, TensorTieFactory* factory,
|
||||
|
@ -140,6 +140,28 @@ absl::Status NewInferenceEnvironment(
|
||||
std::unique_ptr<InferenceEnvironment>* environment,
|
||||
InferenceEnvironmentProperties* properties /* optional */);
|
||||
|
||||
class CLInferenceRunner : public ::tflite::gpu::InferenceRunner {
|
||||
public:
|
||||
// The RunWithoutExternalBufferCopy provides a contract where the user of this
|
||||
// interface does not need
|
||||
// a. Inputs to be copied to the internal GPU buffer from the external CPU
|
||||
// input buffer
|
||||
// b. Outputs to be copied from the internal GPU buffer to the
|
||||
// external CPU buffer
|
||||
//
|
||||
// The user of this interface is responsible for copying the inputs prior to
|
||||
// running the GPU kernels and outputs post running with the other interfaces
|
||||
// provided here.
|
||||
virtual absl::Status RunWithoutExternalBufferCopy() = 0;
|
||||
|
||||
// Copies from the external input tensor (normally CPU buffer) to the internal
|
||||
// OpenCL buffer.
|
||||
virtual absl::Status CopyFromExternalInput(int index) = 0;
|
||||
|
||||
// Copies from the internal output OpenCL buffer to the external output tensor
|
||||
virtual absl::Status CopyToExternalOutput(int index) = 0;
|
||||
};
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
Loading…
x
Reference in New Issue
Block a user