OpenGL API: Add explicit wait when data output goes to CPU.
PiperOrigin-RevId: 270301271
This commit is contained in:
parent
400c1801bc
commit
f380896421
@ -372,7 +372,12 @@ class InferenceRunnerImpl : public InferenceRunner {
|
||||
const std::vector<TensorTieDef>& outputs,
|
||||
TensorTieFactory* tie_factory) {
|
||||
RETURN_IF_ERROR(LinkTensors(inputs, tie_factory, &inputs_));
|
||||
return LinkTensors(outputs, tie_factory, &outputs_);
|
||||
RETURN_IF_ERROR(LinkTensors(outputs, tie_factory, &outputs_));
|
||||
for (const auto& def : outputs) {
|
||||
output_to_cpu_ |= def.external_def.object_def.object_type ==
|
||||
gpu::ObjectType::CPU_MEMORY;
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
std::vector<TensorObjectDef> inputs() const override {
|
||||
@ -421,6 +426,10 @@ class InferenceRunnerImpl : public InferenceRunner {
|
||||
for (auto& obj : outputs_) {
|
||||
RETURN_IF_ERROR(obj->CopyToExternalObject());
|
||||
}
|
||||
RETURN_IF_ERROR(runtime_->command_queue()->Flush());
|
||||
if (output_to_cpu_) {
|
||||
RETURN_IF_ERROR(runtime_->command_queue()->WaitForCompletion());
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
@ -451,6 +460,7 @@ class InferenceRunnerImpl : public InferenceRunner {
|
||||
std::unique_ptr<ObjectManager> objects_;
|
||||
std::vector<std::unique_ptr<TensorTie>> inputs_;
|
||||
std::vector<std::unique_ptr<TensorTie>> outputs_;
|
||||
bool output_to_cpu_ = false;
|
||||
};
|
||||
|
||||
class InferenceBuilderImpl : public InferenceBuilder {
|
||||
|
@ -39,6 +39,8 @@ class DefaultCommandQueue : public CommandQueue {
|
||||
// TODO(akulik): Maybe let the user choose which wait method to use.
|
||||
return GlActiveSyncWait();
|
||||
}
|
||||
|
||||
// The default queue performs no work on flush and always reports success.
Status Flush() override {
  return OkStatus();
}
|
||||
};
|
||||
|
||||
// On Adreno do flush periodically as this affects performance. Command queue
|
||||
@ -60,6 +62,20 @@ class AdrenoCommandQueue : public DefaultCommandQueue {
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
// Resets the dispatch counter and then blocks using the base-class wait.
// NOTE(review): the counter is presumably the number of dispatches since the
// last flush; clearing it makes a following Flush() skip its glFlush() call.
Status WaitForCompletion() override {
  program_counter_ = 0;
  Status wait_status = DefaultCommandQueue::WaitForCompletion();
  return wait_status;
}
|
||||
|
||||
// Issues a single glFlush() after the last dispatch, but only when the
// dispatch counter is non-zero; otherwise this is a no-op. Always returns
// OkStatus().
Status Flush() final {
  if (program_counter_ == 0) {
    // Counter already cleared, so skip the redundant flush.
    return OkStatus();
  }
  program_counter_ = 0;
  glFlush();
  return OkStatus();
}
|
||||
|
||||
private:
|
||||
const int flush_every_n_;
|
||||
int program_counter_ = 0;
|
||||
|
@ -38,6 +38,9 @@ class CommandQueue {
|
||||
virtual Status Dispatch(const GlProgram& program,
|
||||
const uint3& workgroups) = 0;
|
||||
|
||||
// Called after all programs have been dispatched.
|
||||
virtual Status Flush() = 0;
|
||||
|
||||
// Waits until all programs dispatched prior to this call are completed.
|
||||
virtual Status WaitForCompletion() = 0;
|
||||
};
|
||||
|
@ -59,6 +59,8 @@ class Runtime {
|
||||
// Gets access to objects created while executing generated code.
|
||||
const ObjectManager* internal_objects() const { return &internal_objects_; }
|
||||
|
||||
// Returns the command queue pointer held by this runtime.
CommandQueue* command_queue() {
  return command_queue_;
}
|
||||
|
||||
RuntimeStats stats() const {
|
||||
RuntimeStats stats;
|
||||
stats.const_objects = const_objects_.stats();
|
||||
|
Loading…
Reference in New Issue
Block a user