OpenGL API: Add explicit wait when data output goes to CPU.

PiperOrigin-RevId: 270301271
This commit is contained in:
A. Unique TensorFlower 2019-09-20 10:36:54 -07:00 committed by TensorFlower Gardener
parent 400c1801bc
commit f380896421
4 changed files with 32 additions and 1 deletions

View File

@ -372,7 +372,12 @@ class InferenceRunnerImpl : public InferenceRunner {
const std::vector<TensorTieDef>& outputs,
TensorTieFactory* tie_factory) {
RETURN_IF_ERROR(LinkTensors(inputs, tie_factory, &inputs_));
return LinkTensors(outputs, tie_factory, &outputs_);
RETURN_IF_ERROR(LinkTensors(outputs, tie_factory, &outputs_));
for (const auto& def : outputs) {
output_to_cpu_ |= def.external_def.object_def.object_type ==
gpu::ObjectType::CPU_MEMORY;
}
return OkStatus();
}
std::vector<TensorObjectDef> inputs() const override {
@ -421,6 +426,10 @@ class InferenceRunnerImpl : public InferenceRunner {
for (auto& obj : outputs_) {
RETURN_IF_ERROR(obj->CopyToExternalObject());
}
RETURN_IF_ERROR(runtime_->command_queue()->Flush());
if (output_to_cpu_) {
RETURN_IF_ERROR(runtime_->command_queue()->WaitForCompletion());
}
return OkStatus();
}
@ -451,6 +460,7 @@ class InferenceRunnerImpl : public InferenceRunner {
std::unique_ptr<ObjectManager> objects_;
std::vector<std::unique_ptr<TensorTie>> inputs_;
std::vector<std::unique_ptr<TensorTie>> outputs_;
bool output_to_cpu_ = false;
};
class InferenceBuilderImpl : public InferenceBuilder {

View File

@ -39,6 +39,8 @@ class DefaultCommandQueue : public CommandQueue {
// TODO(akulik): Maybe let the user choose which wait method to use.
return GlActiveSyncWait();
}
Status Flush() override { return OkStatus(); }
};
// On Adreno do flush periodically as this affects performance. Command queue
@ -60,6 +62,20 @@ class AdrenoCommandQueue : public DefaultCommandQueue {
return OkStatus();
}
Status WaitForCompletion() override {
program_counter_ = 0;
return DefaultCommandQueue::WaitForCompletion();
}
Status Flush() final {
// Flush exactly once after the last dispatch.
if (program_counter_ != 0) {
program_counter_ = 0;
glFlush();
}
return OkStatus();
}
private:
const int flush_every_n_;
int program_counter_ = 0;

View File

@ -38,6 +38,9 @@ class CommandQueue {
virtual Status Dispatch(const GlProgram& program,
const uint3& workgroups) = 0;
// Called at the end of dispatching of all programs.
virtual Status Flush() = 0;
// Waits until all programs dispatched prior this call are completed.
virtual Status WaitForCompletion() = 0;
};

View File

@ -59,6 +59,8 @@ class Runtime {
// Gets access to objects created while executing generated code.
const ObjectManager* internal_objects() const { return &internal_objects_; }
CommandQueue* command_queue() { return command_queue_; }
RuntimeStats stats() const {
RuntimeStats stats;
stats.const_objects = const_objects_.stats();