Integer quantized model support for Metal delegate
The new feature is guarded by enable_quantization flag. PiperOrigin-RevId: 315139125 Change-Id: I944008fcb914e31bd9e54634f85c6b2e0be0b8fe
This commit is contained in:
parent
735bb0fc23
commit
115ca425e4
@ -90,6 +90,7 @@ objc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_builder",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
"//tensorflow/lite/delegates/gpu/common:quantization_util",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
|
@ -58,6 +58,8 @@ typedef struct {
|
||||
// Allows quantizing tensors, downcasting values, processing in float16, etc.
|
||||
bool allow_precision_loss;
|
||||
TFLGpuDelegateWaitType wait_type;
|
||||
// Allows execution of integer quantized models
|
||||
bool enable_quantization;
|
||||
} TFLGpuDelegateOptions;
|
||||
|
||||
// Creates a new delegate instance that needs to be destroyed with
|
||||
|
@ -29,10 +29,12 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/lite/builtin_ops.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/convert.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
||||
@ -40,10 +42,11 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/metal/api.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/common.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -176,6 +179,7 @@ class Delegate {
|
||||
} else {
|
||||
// Default options.
|
||||
options_.allow_precision_loss = false;
|
||||
options_.enable_quantization = false;
|
||||
options_.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive;
|
||||
}
|
||||
metal_device_ = MTLCreateSystemDefaultDevice();
|
||||
@ -227,16 +231,38 @@ class Delegate {
|
||||
external_command_encoder_ = encoder;
|
||||
}
|
||||
|
||||
// This directs the runtime to allocate memory for input/output temporary
|
||||
// tensors that require dequantization/quantization.
|
||||
// Tells the runtime which tensors must be allocated as temporaries of this
// delegate kernel node: every input/output tensor id recorded in
// input_tensor_ids_ / output_tensor_ids_ that also has an entry in
// quant_conversion_map_.
//
// context, node:          unused here; kept so the signature matches callers.
// temporaries_array_ptr:  in/out. On success with a non-empty conversion map
//                         it is replaced by a freshly created TfLiteIntArray
//                         holding the collected ids. When the conversion map
//                         is empty it is left untouched.
//
// Always returns absl::OkStatus().
absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node,
                                    TfLiteIntArray** temporaries_array_ptr) {
  // Nothing was (de)quantized, so no extra tensors are needed.
  if (quant_conversion_map_.empty()) return absl::OkStatus();

  std::vector<int> temporary_tensor_ids;
  temporary_tensor_ids.reserve(input_tensor_ids_.size() + output_tensor_ids_.size());
  for (const auto index : input_tensor_ids_) {
    if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
      // Ids are stored as int64_t but TfLiteIntArray holds int; make the
      // narrowing explicit.
      temporary_tensor_ids.push_back(static_cast<int>(index));
    }
  }
  for (const auto index : output_tensor_ids_) {
    if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
      temporary_tensor_ids.push_back(static_cast<int>(index));
    }
  }

  // The .prepare callback passes &node->temporaries here. Release whatever
  // array is currently installed before replacing it, otherwise it is leaked
  // on every prepare pass. TfLiteIntArrayFree forwards to free(), so a null
  // pointer is fine.
  // NOTE(review): assumes *temporaries_array_ptr is either null or owned by
  // the node (the TFLite runtime normally installs an empty array) - confirm.
  TfLiteIntArrayFree(*temporaries_array_ptr);
  *temporaries_array_ptr = TfLiteIntArrayCreate(static_cast<int>(temporary_tensor_ids.size()));
  for (size_t i = 0; i < temporary_tensor_ids.size(); ++i) {
    (*temporaries_array_ptr)->data[i] = temporary_tensor_ids[i];
  }
  return absl::OkStatus();
}
|
||||
|
||||
absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) {
|
||||
// Extract TFLite delegate execution plan from the context and convert it into GraphFloat32.
|
||||
GraphFloat32 graph;
|
||||
RETURN_IF_ERROR(BuildModel(context, delegate_params, &graph));
|
||||
|
||||
// Apply general transformations on the graph.
|
||||
NullTransformationReporter reporter;
|
||||
ModelTransformer transformer(&graph, &reporter);
|
||||
if (!ApplyGeneralTransformations(&transformer)) {
|
||||
return absl::InternalError("Graph general transformations failed");
|
||||
quant_conversion_map_.clear();
|
||||
if (options_.enable_quantization) {
|
||||
RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph, &quant_conversion_map_));
|
||||
} else {
|
||||
RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph));
|
||||
}
|
||||
|
||||
// TODO(impjdi): Remove code duplication.
|
||||
@ -260,17 +286,23 @@ class Delegate {
|
||||
//
|
||||
// Note that graph.inputs() cannot be used directly, as the notion of graph input has a
|
||||
// different meaning in public API and GPU-internal API.
|
||||
inputs_.reserve(delegate_params->input_tensors->size);
|
||||
for (int i = 0; i < delegate_params->input_tensors->size; ++i) {
|
||||
const int tensor_index = delegate_params->input_tensors->data[i];
|
||||
auto* tensor = context->tensors + tensor_index;
|
||||
if (tensor->allocation_type == TfLiteAllocationType::kTfLiteMmapRo) continue;
|
||||
for (int tensor_index : TfLiteIntArrayView(delegate_params->input_tensors)) {
|
||||
auto* tensor = &context->tensors[tensor_index];
|
||||
if (IsConstantTensor(tensor)) continue;
|
||||
// For quantized models, actual inputs of GPU graph are float tensors, so the 8-bit inputs
|
||||
// to the delegate kernel need to be dequantized before feeding to the GPU graph.
|
||||
if (options_.enable_quantization &&
|
||||
quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
|
||||
tensor_index = quant_conversion_map_[tensor_index];
|
||||
tensor = &context->tensors[tensor_index];
|
||||
}
|
||||
const auto* input = find_value(tensor_index);
|
||||
if (!input || tensor->type != TfLiteType::kTfLiteFloat32) {
|
||||
return absl::NotFoundError("Input tensor is not found in the graph.");
|
||||
}
|
||||
|
||||
inputs_.push_back(input->id);
|
||||
input_tensor_ids_.push_back(tensor_index);
|
||||
tensor->buffer_handle = input->id;
|
||||
tensor->delegate = &delegate_;
|
||||
}
|
||||
@ -279,16 +311,23 @@ class Delegate {
|
||||
//
|
||||
// Note that graph.outputs() cannot be used directly, as the notion of graph output has a
|
||||
// different meaning in public API and GPU-internal API.
|
||||
outputs_.reserve(delegate_params->output_tensors->size);
|
||||
for (int i = 0; i < delegate_params->output_tensors->size; ++i) {
|
||||
const int tensor_index = delegate_params->output_tensors->data[i];
|
||||
auto* tensor = context->tensors + tensor_index;
|
||||
for (int tensor_index : TfLiteIntArrayView(delegate_params->output_tensors)) {
|
||||
auto* tensor = &context->tensors[tensor_index];
|
||||
if (IsConstantTensor(tensor)) continue;
|
||||
// For quantized models, actual outputs of GPU graph are float tensors, so they should be
|
||||
// quantized to be the 8-bit outputs of delegate.
|
||||
if (options_.enable_quantization &&
|
||||
quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
|
||||
tensor_index = quant_conversion_map_[tensor_index];
|
||||
tensor = &context->tensors[tensor_index];
|
||||
}
|
||||
const auto* output = find_value(tensor_index);
|
||||
if (!output || tensor->type != TfLiteType::kTfLiteFloat32) {
|
||||
return absl::NotFoundError("Output tensor is not found in the graph.");
|
||||
}
|
||||
|
||||
outputs_.push_back(output->id);
|
||||
output_tensor_ids_.push_back(tensor_index);
|
||||
tensor->buffer_handle = output->id;
|
||||
tensor->delegate = &delegate_;
|
||||
}
|
||||
@ -422,12 +461,17 @@ class Delegate {
|
||||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
const bool is_quantized_model = !quant_conversion_map_.empty();
|
||||
if (is_quantized_model) {
|
||||
RETURN_IF_ERROR(DequantizeInputs(context, input_tensor_ids_, quant_conversion_map_));
|
||||
}
|
||||
|
||||
// CPU HWC input data conversion to PHWC4 and fill the GPU buffer
|
||||
for (const auto& input : graph_inputs_) {
|
||||
if (input.set_externally) continue;
|
||||
// A user provides data on CPU memory for this buffer - need to copy to MTLBuffer
|
||||
|
||||
TfLiteTensor* tensor = context->tensors + input.tensor_id;
|
||||
TfLiteTensor* tensor = &context->tensors[input.tensor_id];
|
||||
void* gpu_ptr = [input_output_buffers_[input.id] contents];
|
||||
std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float));
|
||||
if (input_output_buffers_[input.id] == bphwc4_buffers_[input.id]) continue;
|
||||
@ -529,9 +573,14 @@ class Delegate {
|
||||
const void* gpu_ptr = [input_output_buffers_[output.id] contents];
|
||||
std::memcpy(tensor->data.f, gpu_ptr, output.shape.DimensionsProduct() * sizeof(float));
|
||||
}
|
||||
if (is_quantized_model) {
|
||||
RETURN_IF_ERROR(QuantizeOutputs(context, output_tensor_ids_, quant_conversion_map_));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const TFLGpuDelegateOptions options() const { return options_; }
|
||||
|
||||
TfLiteDelegate* tflite_delegate() { return &delegate_; }
|
||||
|
||||
private:
|
||||
@ -551,6 +600,12 @@ class Delegate {
|
||||
std::vector<ValueRef> tensors_; // indexed by ValueId
|
||||
std::vector<ValueId> inputs_;
|
||||
std::vector<ValueId> outputs_;
|
||||
std::vector<int64_t> input_tensor_ids_;
|
||||
std::vector<int64_t> output_tensor_ids_;
|
||||
// Whenever quantized inference is enabled, this maps the tensor index of each
|
||||
// originally quantized (8-bit) tensor to its float version added in
|
||||
// model_builder - and vice versa.
|
||||
std::unordered_map<int, int> quant_conversion_map_;
|
||||
|
||||
TFLInferenceContext* inference_context_;
|
||||
// input and output buffers are passed into Metal inference engine
|
||||
@ -595,22 +650,34 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
// forbids that.
|
||||
const auto status = metal_delegate->Prepare(context, params);
|
||||
if (status.ok()) return metal_delegate;
|
||||
context->ReportError(context, "TfLiteGpuDelegate Prepare: %s",
|
||||
std::string(status.message()).c_str());
|
||||
TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
|
||||
std::string(status.message()).c_str());
|
||||
return nullptr;
|
||||
},
|
||||
// .free
|
||||
[](TfLiteContext*, void* buffer) -> void {},
|
||||
// .prepare
|
||||
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
|
||||
if (!node->user_data) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
auto* gpu_delegate_kernel = GetMetalDelegate(node);
|
||||
const auto status =
|
||||
gpu_delegate_kernel->GetRequiredTemporaries(context, node, &node->temporaries);
|
||||
if (!status.ok()) {
|
||||
TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
|
||||
std::string(status.message()).c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
return node->user_data ? kTfLiteOk : kTfLiteError;
|
||||
},
|
||||
// .invoke
|
||||
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
|
||||
const auto status = GetMetalDelegate(node)->Invoke(context);
|
||||
if (status.ok()) return kTfLiteOk;
|
||||
context->ReportError(context, "TfLiteMetalDelegate Invoke: %s",
|
||||
std::string(status.message()).c_str());
|
||||
TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Invoke: %s",
|
||||
std::string(status.message()).c_str());
|
||||
return kTfLiteError;
|
||||
},
|
||||
nullptr, // .profiling_string
|
||||
@ -618,7 +685,8 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
"TfLiteMetalDelegate", // .custom_name
|
||||
1, // .version
|
||||
};
|
||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
||||
TfLiteIntArray* ops_to_replace =
|
||||
GetOpsToReplace(context, GetMetalDelegate(delegate)->options().enable_quantization);
|
||||
const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(context, kRegistration,
|
||||
ops_to_replace, delegate);
|
||||
TfLiteIntArrayFree(ops_to_replace);
|
||||
|
Loading…
x
Reference in New Issue
Block a user