diff --git a/tensorflow/lite/delegates/gpu/gl/api2.cc b/tensorflow/lite/delegates/gpu/gl/api2.cc
index 38951c5e08e..921bb5736b0 100644
--- a/tensorflow/lite/delegates/gpu/gl/api2.cc
+++ b/tensorflow/lite/delegates/gpu/gl/api2.cc
@@ -372,7 +372,12 @@ class InferenceRunnerImpl : public InferenceRunner {
                     const std::vector<TensorObjectDef>& outputs,
                     TensorTieFactory* tie_factory) {
     RETURN_IF_ERROR(LinkTensors(inputs, tie_factory, &inputs_));
-    return LinkTensors(outputs, tie_factory, &outputs_);
+    RETURN_IF_ERROR(LinkTensors(outputs, tie_factory, &outputs_));
+    for (const auto& def : outputs) {
+      output_to_cpu_ |= def.external_def.object_def.object_type ==
+                        gpu::ObjectType::CPU_MEMORY;
+    }
+    return OkStatus();
   }
 
   std::vector<TensorObjectDef> inputs() const override {
@@ -421,6 +426,10 @@
     for (auto& obj : outputs_) {
       RETURN_IF_ERROR(obj->CopyToExternalObject());
     }
+    RETURN_IF_ERROR(runtime_->command_queue()->Flush());
+    if (output_to_cpu_) {
+      RETURN_IF_ERROR(runtime_->command_queue()->WaitForCompletion());
+    }
     return OkStatus();
   }
 
@@ -451,6 +460,7 @@
   std::unique_ptr<ObjectManager> objects_;
   std::vector<std::unique_ptr<TensorTie>> inputs_;
   std::vector<std::unique_ptr<TensorTie>> outputs_;
+  bool output_to_cpu_ = false;
 };
 
 class InferenceBuilderImpl : public InferenceBuilder {
diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.cc b/tensorflow/lite/delegates/gpu/gl/command_queue.cc
index 462d52e1258..87823761127 100644
--- a/tensorflow/lite/delegates/gpu/gl/command_queue.cc
+++ b/tensorflow/lite/delegates/gpu/gl/command_queue.cc
@@ -39,6 +39,8 @@ class DefaultCommandQueue : public CommandQueue {
     // TODO(akulik): Maybe let the user choose which wait method to use.
     return GlActiveSyncWait();
   }
+
+  Status Flush() override { return OkStatus(); }
 };
 
 // On Adreno do flush periodically as this affects performance. Command queue
@@ -60,6 +62,20 @@
     return OkStatus();
   }
 
+  Status WaitForCompletion() override {
+    program_counter_ = 0;
+    return DefaultCommandQueue::WaitForCompletion();
+  }
+
+  Status Flush() final {
+    // Flush exactly once after the last dispatch.
+    if (program_counter_ != 0) {
+      program_counter_ = 0;
+      glFlush();
+    }
+    return OkStatus();
+  }
+
  private:
   const int flush_every_n_;
   int program_counter_ = 0;
diff --git a/tensorflow/lite/delegates/gpu/gl/command_queue.h b/tensorflow/lite/delegates/gpu/gl/command_queue.h
index a4c21001cf2..6695852fc86 100644
--- a/tensorflow/lite/delegates/gpu/gl/command_queue.h
+++ b/tensorflow/lite/delegates/gpu/gl/command_queue.h
@@ -38,6 +38,9 @@ class CommandQueue {
   virtual Status Dispatch(const GlProgram& program,
                           const uint3& workgroups) = 0;
 
+  // Called at the end of dispatching of all programs.
+  virtual Status Flush() = 0;
+
   // Waits until all programs dispatched prior this call are completed.
   virtual Status WaitForCompletion() = 0;
 };
diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.h b/tensorflow/lite/delegates/gpu/gl/runtime.h
index 46e0732cd32..b66a7fdfaa4 100644
--- a/tensorflow/lite/delegates/gpu/gl/runtime.h
+++ b/tensorflow/lite/delegates/gpu/gl/runtime.h
@@ -59,6 +59,8 @@ class Runtime {
   // Gets access to objects created while executing generated code.
   const ObjectManager* internal_objects() const { return &internal_objects_; }
 
+  CommandQueue* command_queue() { return command_queue_; }
+
   RuntimeStats stats() const {
     RuntimeStats stats;
     stats.const_objects = const_objects_.stats();
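For reviewers, here is a minimal standalone C++ sketch of the contract this change establishes: flush once after the last dispatch, and pay for a blocking wait only when an output is bound to CPU memory. Every name in it (FakeQueue, RunInference, the simplified Status) is an illustrative stand-in, not a TFLite API; the real behavior lives in AdrenoCommandQueue and InferenceRunnerImpl::Run() above.

// Standalone sketch (not TFLite code) of the flush/wait contract.
#include <cstdio>

struct Status { bool ok = true; };       // simplified stand-in status type
inline Status OkStatus() { return {}; }

class FakeQueue {
 public:
  explicit FakeQueue(int flush_every_n) : flush_every_n_(flush_every_n) {}

  Status Dispatch() {
    // Periodic flush, mirroring AdrenoCommandQueue's every-N-programs flush.
    if (++program_counter_ % flush_every_n_ == 0) {
      program_counter_ = 0;
      std::printf("glFlush (periodic)\n");
    }
    return OkStatus();
  }

  Status Flush() {
    // Flush only if dispatches are still pending since the last flush.
    if (program_counter_ != 0) {
      program_counter_ = 0;
      std::printf("glFlush (final)\n");
    }
    return OkStatus();
  }

  Status WaitForCompletion() {
    program_counter_ = 0;
    std::printf("blocking sync wait\n");
    return OkStatus();
  }

 private:
  const int flush_every_n_;
  int program_counter_ = 0;
};

// Mirrors the new tail of InferenceRunnerImpl::Run(): always flush, but wait
// only for CPU-bound outputs; GL-to-GL consumers are already ordered by the
// command stream within the context.
Status RunInference(FakeQueue* queue, int num_programs, bool output_to_cpu) {
  for (int i = 0; i < num_programs; ++i) {
    Status s = queue->Dispatch();
    if (!s.ok) return s;
  }
  Status s = queue->Flush();
  if (!s.ok) return s;
  if (output_to_cpu) {
    s = queue->WaitForCompletion();
    if (!s.ok) return s;
  }
  return OkStatus();
}

int main() {
  FakeQueue queue(/*flush_every_n=*/4);
  RunInference(&queue, /*num_programs=*/10, /*output_to_cpu=*/true);
  return 0;
}

The split matters because glFlush() merely hands queued work to the driver without stalling the CPU, while the blocking wait is only needed on the CPU read-back path.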