Enabling lazy variable tensor copy.
PiperOrigin-RevId: 356301869 Change-Id: I1b8e2b5192d08294d2ee685f1a101690e2929983
This commit is contained in:
parent
e2b5e921b8
commit
7516f6ee75
@ -440,7 +440,7 @@ class TensorTieFactory {
|
|||||||
std::unique_ptr<TensorObjectConverterBuilder> converter_builder_;
|
std::unique_ptr<TensorObjectConverterBuilder> converter_builder_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class InferenceRunnerImpl : public InferenceRunner {
|
class InferenceRunnerImpl : public CLInferenceRunner {
|
||||||
public:
|
public:
|
||||||
InferenceRunnerImpl(Environment* environment,
|
InferenceRunnerImpl(Environment* environment,
|
||||||
std::unique_ptr<InferenceContext> context
|
std::unique_ptr<InferenceContext> context
|
||||||
@ -503,19 +503,36 @@ class InferenceRunnerImpl : public InferenceRunner {
|
|||||||
return outputs_[index]->SetExternalObject(object);
|
return outputs_[index]->SetExternalObject(object);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::Status CopyFromExternalInput(int index) override {
|
||||||
|
if (index > inputs_.size()) {
|
||||||
|
return absl::NotFoundError(
|
||||||
|
absl::StrCat("Input id ", index, " is an invalid input index."));
|
||||||
|
}
|
||||||
|
return inputs_[index]->CopyFromExternalObject();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status CopyToExternalOutput(int index) override {
|
||||||
|
if (index > outputs_.size()) {
|
||||||
|
return absl::NotFoundError(
|
||||||
|
absl::StrCat("Output id ", index, " is an invalid output index"));
|
||||||
|
}
|
||||||
|
return outputs_[index]->CopyToExternalObject();
|
||||||
|
}
|
||||||
|
|
||||||
absl::Status Run() override {
|
absl::Status Run() override {
|
||||||
#ifdef CL_DELEGATE_ALLOW_GL
|
#ifdef CL_DELEGATE_ALLOW_GL
|
||||||
if (gl_interop_fabric_) {
|
if (gl_interop_fabric_) {
|
||||||
RETURN_IF_ERROR(gl_interop_fabric_->Start());
|
RETURN_IF_ERROR(gl_interop_fabric_->Start());
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
for (auto& obj : inputs_) {
|
for (int i = 0; i < inputs_.size(); i++) {
|
||||||
RETURN_IF_ERROR(obj->CopyFromExternalObject());
|
RETURN_IF_ERROR(CopyFromExternalInput(i));
|
||||||
}
|
}
|
||||||
RETURN_IF_ERROR(context_->AddToQueue(queue_));
|
|
||||||
clFlush(queue_->queue());
|
RETURN_IF_ERROR(RunWithoutExternalBufferCopy());
|
||||||
for (auto& obj : outputs_) {
|
|
||||||
RETURN_IF_ERROR(obj->CopyToExternalObject());
|
for (int i = 0; i < outputs_.size(); i++) {
|
||||||
|
RETURN_IF_ERROR(CopyToExternalOutput(i));
|
||||||
}
|
}
|
||||||
#ifdef CL_DELEGATE_ALLOW_GL
|
#ifdef CL_DELEGATE_ALLOW_GL
|
||||||
if (gl_interop_fabric_) {
|
if (gl_interop_fabric_) {
|
||||||
@ -525,6 +542,13 @@ class InferenceRunnerImpl : public InferenceRunner {
|
|||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::Status RunWithoutExternalBufferCopy() override {
|
||||||
|
RETURN_IF_ERROR(context_->AddToQueue(queue_));
|
||||||
|
clFlush(queue_->queue());
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static absl::Status LinkTensors(
|
static absl::Status LinkTensors(
|
||||||
const std::vector<TensorTieDef>& defs, TensorTieFactory* factory,
|
const std::vector<TensorTieDef>& defs, TensorTieFactory* factory,
|
||||||
|
@ -140,6 +140,28 @@ absl::Status NewInferenceEnvironment(
|
|||||||
std::unique_ptr<InferenceEnvironment>* environment,
|
std::unique_ptr<InferenceEnvironment>* environment,
|
||||||
InferenceEnvironmentProperties* properties /* optional */);
|
InferenceEnvironmentProperties* properties /* optional */);
|
||||||
|
|
||||||
|
class CLInferenceRunner : public ::tflite::gpu::InferenceRunner {
|
||||||
|
public:
|
||||||
|
// The RunWithoutExternalBufferCopy provides a contract where the user of this
|
||||||
|
// interface does not need
|
||||||
|
// a. Inputs to be copied to the internal GPU buffer from the external CPU
|
||||||
|
// input buffer
|
||||||
|
// b. Outputs to be copied from the internal GPU buffer to the
|
||||||
|
// external CPU buffer
|
||||||
|
//
|
||||||
|
// The user of this interface is responsible for copying the inputs prior to
|
||||||
|
// running the GPU kernels and outputs post running with the other interfaces
|
||||||
|
// provided here.
|
||||||
|
virtual absl::Status RunWithoutExternalBufferCopy() = 0;
|
||||||
|
|
||||||
|
// Copies from the external input tensor (normally CPU buffer) to the internal
|
||||||
|
// OpenCL buffer.
|
||||||
|
virtual absl::Status CopyFromExternalInput(int index) = 0;
|
||||||
|
|
||||||
|
// Copies from the internal output OpenCL buffer to the external output tensor
|
||||||
|
virtual absl::Status CopyToExternalOutput(int index) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
Loading…
x
Reference in New Issue
Block a user