From 3c7a1d88277833dfa9b76f111f19569e3951174c Mon Sep 17 00:00:00 2001 From: Taehee Jeong <taeheej@google.com> Date: Tue, 30 Jun 2020 01:36:33 -0700 Subject: [PATCH] Support binding of Metal buffer to quantized models When an input or output of quantized model is read from Metal buffer directly, correctly link the internal FP32 tensor to the buffer and disable dequantization/quantization for that input/output. PiperOrigin-RevId: 318978466 Change-Id: I052bf6c1daab6425e6ff66166221f36b075326e8 --- tensorflow/lite/delegates/gpu/metal_delegate.mm | 8 ++++++++ tensorflow/lite/delegates/gpu/metal_delegate_internal.h | 6 ++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.mm b/tensorflow/lite/delegates/gpu/metal_delegate.mm index 01fa9dd7679..45bfe1f3b2f 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.mm +++ b/tensorflow/lite/delegates/gpu/metal_delegate.mm @@ -205,6 +205,14 @@ class Delegate { } absl::Status BindBufferToTensor(id<MTLBuffer> buffer, int tensor_index) { + // The tensor index is expected to be an input or output tensor of the interpreter. + // For quantized model, the buffer should be linked with their dequantized counterpart. + if (quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) { + tensor_index = quant_conversion_map_[tensor_index]; + // remove [dequantized tensor ID] -> [quantized tensor ID] mapping, to prevent extra + // dequant/quant on in/outputs. + quant_conversion_map_.erase(tensor_index); + } for (auto& input : graph_inputs_) { if (input.tensor_id == tensor_index) { input_output_buffers_[input.id] = buffer; diff --git a/tensorflow/lite/delegates/gpu/metal_delegate_internal.h b/tensorflow/lite/delegates/gpu/metal_delegate_internal.h index a479b5c6e28..82bc720844e 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate_internal.h +++ b/tensorflow/lite/delegates/gpu/metal_delegate_internal.h @@ -24,9 +24,11 @@ struct TfLiteDelegate; // Binds Metal buffer to an input or an output tensor in the initialized // delegate. Bound buffer should have sufficient storage to accommodate all -// elements of a tensor. Returns non-zero on success, or zero otherwise. +// elements of a tensor. For quantized model, the buffer is bound to internal +// dequantized float32 tensor. +// Returns non-zero on success, or zero otherwise. // -// *** Must be called *before* `Interpreter::ModifyGraphWithDelegate`. *** +// *** Must be called *after* `Interpreter::ModifyGraphWithDelegate`. *** bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index, id<MTLBuffer> metal_buffer);