TFLGpuDelegateSetCommandEncoder removed.

From now on, TFLGpuDelegateSetCommandBuffer is the only way to supply an external Metal command object to the delegate.

PiperOrigin-RevId: 347776756
Change-Id: I6fb7752c09d7c5b77c8e1b0fed29607870efedb2
This commit is contained in:
Raman Sarokin 2020-12-16 01:08:23 -08:00 committed by TensorFlower Gardener
parent ac0fa2c570
commit 626a13aafa
2 changed files with 14 additions and 48 deletions

View File

@ -230,10 +230,6 @@ class Delegate {
return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index)); return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index));
} }
void SetCommandEncoder(id<MTLComputeCommandEncoder> encoder) {
external_command_encoder_ = encoder;
}
void SetCommandBuffer(id<MTLCommandBuffer> command_buffer) { void SetCommandBuffer(id<MTLCommandBuffer> command_buffer) {
external_command_buffer_ = command_buffer; external_command_buffer_ = command_buffer;
} }
@ -454,14 +450,10 @@ class Delegate {
// memory each time. // memory each time.
__block volatile bool buffer_completed = false; __block volatile bool buffer_completed = false;
id<MTLCommandBuffer> command_buffer = external_command_buffer_; id<MTLCommandBuffer> command_buffer = external_command_buffer_;
id<MTLComputeCommandEncoder> encoder = external_command_encoder_; if (external_command_buffer_ == nil) {
if (external_command_buffer_ == nil && external_command_encoder_ == nil) {
command_buffer = [command_queue_ commandBuffer]; command_buffer = [command_queue_ commandBuffer];
} }
if (external_command_encoder_ == nil) { const bool flush = external_command_buffer_ == nil &&
encoder = [command_buffer computeCommandEncoder];
}
const bool flush = external_command_encoder_ == nil && external_command_buffer_ == nil &&
(options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive || (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive ||
options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive); options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive);
const int flush_period = 8; const int flush_period = 8;
@ -480,43 +472,36 @@ class Delegate {
void* gpu_ptr = [input_output_buffers_[input.id] contents]; void* gpu_ptr = [input_output_buffers_[input.id] contents];
std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float)); std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float));
if (input_output_buffers_[input.id] == bphwc4_buffers_[input.id]) continue; if (input_output_buffers_[input.id] == bphwc4_buffers_[input.id]) continue;
[converter_to_BPHWC4_ convertWithEncoder:encoder id<MTLComputeCommandEncoder> input_encoder = [command_buffer computeCommandEncoder];
[converter_to_BPHWC4_ convertWithEncoder:input_encoder
shape:input.shape shape:input.shape
sourceBuffer:input_output_buffers_[input.id] sourceBuffer:input_output_buffers_[input.id]
convertedBuffer:bphwc4_buffers_[input.id]]; convertedBuffer:bphwc4_buffers_[input.id]];
} [input_encoder endEncoding];
if (flush) {
[encoder endEncoding];
[command_buffer commit];
} }
if (external_command_encoder_ != nil || @autoreleasepool {
options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive) {
// encoder == external_command_encoder_
inference_context_.EncodeWithEncoder(encoder, bphwc4_buffers_);
} else {
if (flush) { if (flush) {
[command_buffer commit];
inference_context_.EncodeWithCommandQueue(command_queue_, bphwc4_buffers_, flush_period); inference_context_.EncodeWithCommandQueue(command_queue_, bphwc4_buffers_, flush_period);
command_buffer = [command_queue_ commandBuffer]; command_buffer = [command_queue_ commandBuffer];
encoder = [command_buffer computeCommandEncoder];
} else { } else {
[encoder endEncoding];
inference_context_.EncodeWithCommandBuffer(command_buffer, bphwc4_buffers_); inference_context_.EncodeWithCommandBuffer(command_buffer, bphwc4_buffers_);
encoder = [command_buffer computeCommandEncoder];
} }
} }
for (const auto& output : graph_outputs_) { for (const auto& output : graph_outputs_) {
if (output.set_externally) continue; if (output.set_externally) continue;
if (bphwc4_buffers_[output.id] == input_output_buffers_[output.id]) continue; if (bphwc4_buffers_[output.id] == input_output_buffers_[output.id]) continue;
[converter_from_BPHWC4_ convertWithEncoder:encoder id<MTLComputeCommandEncoder> output_encoder = [command_buffer computeCommandEncoder];
[converter_from_BPHWC4_ convertWithEncoder:output_encoder
shape:output.shape shape:output.shape
sourceBuffer:bphwc4_buffers_[output.id] sourceBuffer:bphwc4_buffers_[output.id]
convertedBuffer:input_output_buffers_[output.id]]; convertedBuffer:input_output_buffers_[output.id]];
[output_encoder endEncoding];
} }
if (external_command_encoder_ == nil && external_command_buffer_ == nil) { if (external_command_buffer_ == nil) {
[encoder endEncoding];
if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) { if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
[command_buffer addCompletedHandler:^(id<MTLCommandBuffer>) { [command_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
buffer_completed = true; buffer_completed = true;
@ -552,14 +537,9 @@ class Delegate {
} }
} }
} else { } else {
if (external_command_buffer_ != nil) { // External command buffer must be set before every invoke call.
[encoder endEncoding]; external_command_buffer_ = nil;
// External command buffer must be set before every invoke call. // External command buffer is assigned so all output buffers are controlled by a user.
external_command_buffer_ = nil;
}
// External command encoder must be set before every invoke call.
external_command_encoder_ = nil;
// External command encoder is assigned so all output buffers are controlled by a user.
for (const auto& output : graph_outputs_) { for (const auto& output : graph_outputs_) {
if (!output.set_externally) { if (!output.set_externally) {
return absl::InternalError( return absl::InternalError(
@ -627,7 +607,6 @@ class Delegate {
std::vector<BufferDescriptor> graph_inputs_; std::vector<BufferDescriptor> graph_inputs_;
std::vector<BufferDescriptor> graph_outputs_; std::vector<BufferDescriptor> graph_outputs_;
id<MTLComputeCommandEncoder> external_command_encoder_ = nil;
id<MTLCommandBuffer> external_command_buffer_ = nil; id<MTLCommandBuffer> external_command_buffer_ = nil;
id<MTLCommandQueue> command_queue_; id<MTLCommandQueue> command_queue_;
std::unique_ptr<GpuAlarmClock> gpu_alarm_clock_; std::unique_ptr<GpuAlarmClock> gpu_alarm_clock_;
@ -720,14 +699,6 @@ bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_
// Note: This function is not exposed in `metal_delegate.h`, but it's exposed in // Note: This function is not exposed in `metal_delegate.h`, but it's exposed in
// `metal_delegate_internal.h`. // `metal_delegate_internal.h`.
bool TFLGpuDelegateSetCommandEncoder(
TfLiteDelegate* delegate, id<MTLComputeCommandEncoder> encoder) {
auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
if (!metal_delegate) return false;
metal_delegate->SetCommandEncoder(encoder);
return true;
}
bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate, bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
id<MTLCommandBuffer> command_buffer) { id<MTLCommandBuffer> command_buffer) {
auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate); auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);

View File

@ -33,11 +33,6 @@ bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate,
int tensor_index, int tensor_index,
id<MTLBuffer> metal_buffer); id<MTLBuffer> metal_buffer);
// Binds user-defined MTLComputeCommandEncoder. The delegate puts all GPU tasks
// into this encoder instead of the internal encoder.
bool TFLGpuDelegateSetCommandEncoder(TfLiteDelegate* delegate,
id<MTLComputeCommandEncoder> encoder);
// Binds user-defined MTLCommandBuffer. The delegate puts all GPU tasks // Binds user-defined MTLCommandBuffer. The delegate puts all GPU tasks
// into this buffer instead of the internal command buffer. // into this buffer instead of the internal command buffer.
bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate, bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,