Remove superfluous Dequantize nodes in GPU delegate when executing float16 quantized models.
PiperOrigin-RevId: 259941556
This commit is contained in:
parent
53da0bc5ce
commit
130a84e59c
@ -77,6 +77,7 @@ cc_library(
|
||||
":tensor",
|
||||
"//tensorflow/lite:context",
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
|
@ -43,6 +43,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
@ -708,7 +709,6 @@ class AddOperationParser : public TFLiteOperationParser {
|
||||
}
|
||||
}
|
||||
node->operation.attributes = std::move(attr);
|
||||
|
||||
const auto* tf_options =
|
||||
reinterpret_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
|
||||
if (!tf_options) {
|
||||
@ -2226,6 +2226,106 @@ Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
TfLiteIntArray* GetOpsToReplaceFromGraphWithDequantize(TfLiteContext* context) {
|
||||
TfLiteIntArray* execution_plan = nullptr;
|
||||
if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
|
||||
context->ReportError(context, "Unable to get graph execution plan.");
|
||||
return nullptr;
|
||||
}
|
||||
std::set<std::string> errors;
|
||||
std::unordered_map<int, int> dequant_nodes;
|
||||
std::vector<int> ops_to_replace;
|
||||
std::vector<int> dequant_nodes_to_save;
|
||||
|
||||
// Map the output tensor of a Dequantize nodes to its input tensor.
|
||||
std::unordered_map<int, int> node_map;
|
||||
for (int i = 0; i < execution_plan->size; ++i) {
|
||||
bool replace_node = false;
|
||||
// Keep track of any inputs from a Dequantize node.
|
||||
std::vector<int> inputs_from_dequant;
|
||||
std::vector<int> orig_inputs;
|
||||
|
||||
TfLiteNode* node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
auto status = GetNodeAndRegistration(context, i, &node, ®istration);
|
||||
if (!status.ok()) {
|
||||
context->ReportError(context, status.error_message().c_str());
|
||||
return nullptr;
|
||||
}
|
||||
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
|
||||
context->tensors[node->inputs->data[0]].type ==
|
||||
TfLiteType::kTfLiteFloat16) {
|
||||
// Record the output->input mapping for the op.
|
||||
node_map[node->outputs->data[0]] = node->inputs->data[0];
|
||||
// For now, add the node to the list of ops to replace.
|
||||
ops_to_replace.push_back(i);
|
||||
// Record the dequant node id, indexed by output id.
|
||||
dequant_nodes[node->outputs->data[0]] = i;
|
||||
continue;
|
||||
}
|
||||
TfLiteIntArray* inputs = node->inputs;
|
||||
// Fix the node's inputs (i.e. prune out the preceding dequantize node)
|
||||
// in order to test if it is supported on the GPU.
|
||||
for (int j = 0; j < inputs->size; ++j) {
|
||||
orig_inputs.push_back(inputs->data[j]);
|
||||
if (node_map.find(inputs->data[j]) != node_map.end()) {
|
||||
inputs_from_dequant.push_back(dequant_nodes[inputs->data[j]]);
|
||||
// Remap inputs of this node to the inputs of the preceding dequant.
|
||||
inputs->data[j] = node_map[inputs->data[j]];
|
||||
}
|
||||
}
|
||||
status = IsSupported(context, node, registration);
|
||||
if (status.ok() &&
|
||||
// TODO(eignasheva): resolve sub operation support for metal delegate
|
||||
// registration->builtin_code != kTfLiteBuiltinSub &&
|
||||
IsAllFloatTensors(context, node->inputs) &&
|
||||
IsAllFloatTensors(context, node->outputs)) {
|
||||
if (errors.empty()) {
|
||||
replace_node = true;
|
||||
ops_to_replace.push_back(i);
|
||||
}
|
||||
} else {
|
||||
// Unable to replace this node. Restore the inputs to the original
|
||||
// if they were modified.
|
||||
if (!inputs_from_dequant.empty()) {
|
||||
TfLiteIntArray* inputs = node->inputs;
|
||||
for (int j = 0; j < inputs->size; ++j) {
|
||||
inputs->data[j] = orig_inputs[j];
|
||||
}
|
||||
}
|
||||
errors.insert(GetOpNameByRegistration(registration) + ": " +
|
||||
status.error_message());
|
||||
}
|
||||
// if any input is the output of a dequantize node AND we failed to
|
||||
// replace this op, mark the corresponding dequantize node as a node to
|
||||
// save.
|
||||
if (!replace_node && !inputs_from_dequant.empty()) {
|
||||
dequant_nodes_to_save.insert(dequant_nodes_to_save.end(),
|
||||
inputs_from_dequant.begin(),
|
||||
inputs_from_dequant.end());
|
||||
}
|
||||
}
|
||||
if (!errors.empty()) {
|
||||
std::string unsupported = absl::StrJoin(errors, "\n");
|
||||
std::string error_message =
|
||||
"Next operations are not supported by GPU delegate:\n" + unsupported +
|
||||
"\nFirst " + std::to_string(ops_to_replace.size()) +
|
||||
" operations will run on the GPU, and the remaining " +
|
||||
std::to_string(execution_plan->size - ops_to_replace.size()) +
|
||||
" on the CPU.";
|
||||
context->ReportError(context, error_message.c_str());
|
||||
}
|
||||
// Pop all dequantize nodes that must be preserved.
|
||||
for (int i = 0; i < dequant_nodes_to_save.size(); ++i) {
|
||||
auto it = std::find(ops_to_replace.begin(), ops_to_replace.end(),
|
||||
dequant_nodes_to_save[i]);
|
||||
if (it != ops_to_replace.end()) {
|
||||
ops_to_replace.erase(it);
|
||||
}
|
||||
}
|
||||
return ConvertVectorToTfLiteIntArray(ops_to_replace);
|
||||
}
|
||||
|
||||
// TODO(impjdi): Check number of input/output tensors and their dimensions.
|
||||
// TODO(impjdi): Check ops' parameters.
|
||||
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
|
||||
@ -2234,27 +2334,34 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
|
||||
context->ReportError(context, "Unable to get graph execution plan.");
|
||||
return nullptr;
|
||||
}
|
||||
TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size);
|
||||
subgraph->size = 0;
|
||||
std::set<std::string> errors;
|
||||
|
||||
// Map the output tensor of a Dequantize nodes to its input tensor.
|
||||
std::unordered_map<int, int> node_map;
|
||||
// Dispatch to another function if graph has Dequantize nodes.
|
||||
for (int i = 0; i < execution_plan->size; ++i) {
|
||||
TfLiteNode* node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
auto status = GetNodeAndRegistration(context, i, &node, ®istration);
|
||||
if (!status.ok()) {
|
||||
context->ReportError(context, status.error_message().c_str());
|
||||
TfLiteIntArrayFree(subgraph);
|
||||
return nullptr;
|
||||
}
|
||||
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
|
||||
context->tensors[node->inputs->data[0]].type ==
|
||||
TfLiteType::kTfLiteFloat16) {
|
||||
// Record the output->input mapping for the op.
|
||||
node_map[node->outputs->data[0]] = node->inputs->data[0];
|
||||
continue;
|
||||
return GetOpsToReplaceFromGraphWithDequantize(context);
|
||||
}
|
||||
}
|
||||
|
||||
// No Dequantize nodes. Iterate through graph and find ops to replace.
|
||||
TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size);
|
||||
subgraph->size = 0;
|
||||
std::set<std::string> errors;
|
||||
for (int i = 0; i < execution_plan->size; ++i) {
|
||||
TfLiteNode* node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
auto status = GetNodeAndRegistration(context, i, &node, ®istration);
|
||||
if (!status.ok()) {
|
||||
context->ReportError(context, status.error_message().c_str());
|
||||
return nullptr;
|
||||
}
|
||||
status = IsSupported(context, node, registration);
|
||||
if (status.ok() &&
|
||||
@ -2262,14 +2369,6 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
|
||||
// registration->builtin_code != kTfLiteBuiltinSub &&
|
||||
IsAllFloatTensors(context, node->inputs) &&
|
||||
IsAllFloatTensors(context, node->outputs)) {
|
||||
// Fix the node's inputs (i.e. prune out the preceding dequantize node)
|
||||
// if the op is supported.
|
||||
TfLiteIntArray* inputs = node->inputs;
|
||||
for (int j = 0; j < inputs->size; ++j) {
|
||||
if (node_map.find(inputs->data[j]) != node_map.end()) {
|
||||
inputs->data[j] = node_map[inputs->data[j]];
|
||||
}
|
||||
}
|
||||
if (errors.empty()) subgraph->data[subgraph->size++] = i;
|
||||
} else {
|
||||
errors.insert(GetOpNameByRegistration(registration) + ": " +
|
||||
@ -2292,12 +2391,17 @@ Status BuildModel(TfLiteContext* context,
|
||||
const TfLiteDelegateParams* delegate_params,
|
||||
GraphFloat32* graph) {
|
||||
std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
|
||||
std::vector<int> tflite_nodes;
|
||||
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
|
||||
TfLiteNode* tflite_node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
RETURN_IF_ERROR(GetNodeAndRegistration(
|
||||
context, delegate_params->nodes_to_replace->data[i], &tflite_node,
|
||||
®istration));
|
||||
if (registration->builtin_code == kTfLiteBuiltinDequantize) {
|
||||
// Ignore Dequantize nodes.
|
||||
continue;
|
||||
}
|
||||
auto op_parser = NewOperationParser(registration);
|
||||
if (!op_parser) {
|
||||
return UnimplementedError(
|
||||
@ -2306,15 +2410,16 @@ Status BuildModel(TfLiteContext* context,
|
||||
") is not supported by TFLite GPU Delegate."));
|
||||
}
|
||||
operations.push_back(std::move(op_parser));
|
||||
tflite_nodes.push_back(i);
|
||||
}
|
||||
std::vector<Value<TensorRef<BHWC>>*> tensor_to_value(context->tensors_size,
|
||||
nullptr);
|
||||
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
|
||||
for (int i = 0; i < operations.size(); ++i) {
|
||||
TfLiteNode* tflite_node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
RETURN_IF_ERROR(GetNodeAndRegistration(
|
||||
context, delegate_params->nodes_to_replace->data[i], &tflite_node,
|
||||
®istration));
|
||||
context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
|
||||
&tflite_node, ®istration));
|
||||
ObjectReader reader(graph, context, tflite_node, &tensor_to_value);
|
||||
RETURN_IF_ERROR(
|
||||
operations[i]->Parse(tflite_node, registration, graph, &reader));
|
||||
|
@ -212,7 +212,8 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
|
||||
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
|
||||
// t2 (FP16) -> DequantNode -> t3 (FP32) --/
|
||||
//
|
||||
// After pruning, the graph has one node:
|
||||
// OpsToReplace should choose all three nodes for replacement, and
|
||||
// the graph on the GPU will look like this (no Dequants):
|
||||
//
|
||||
// t0 (FP16) --> Add -> t4
|
||||
// t2 (FP16) --/
|
||||
@ -237,11 +238,11 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
|
||||
|
||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
||||
|
||||
// Just one node left.
|
||||
EXPECT_EQ(ops_to_replace->size, 1);
|
||||
// Replace all nodes.
|
||||
EXPECT_EQ(ops_to_replace->size, 3);
|
||||
TfLiteNode* node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
|
||||
context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
|
||||
®istration);
|
||||
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
||||
TfLiteType::kTfLiteFloat16);
|
||||
@ -416,6 +417,174 @@ TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
|
||||
TfLiteIntArrayFree(ops_to_replace);
|
||||
}
|
||||
|
||||
class InterpreterMultiNode {
|
||||
public:
|
||||
InterpreterMultiNode() {
|
||||
void* builtin_data = malloc(sizeof(int));
|
||||
EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
|
||||
EXPECT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
|
||||
EXPECT_EQ(interpreter_.SetOutputs({6, 7}), kTfLiteOk);
|
||||
|
||||
// Add 3 Dequantize Nodes with float16 input.
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
const TfLiteRegistration reg_dequant = {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/nullptr,
|
||||
/*profiling_string=*/nullptr,
|
||||
kTfLiteBuiltinDequantize};
|
||||
EXPECT_EQ(interpreter_.AddNodeWithParameters(
|
||||
/*inputs=*/{i}, /*outputs=*/{i + 3}, /*init_data=*/nullptr,
|
||||
/*init_data_size=*/0, /*builtin_data=*/nullptr,
|
||||
/*registration=*/®_dequant),
|
||||
kTfLiteOk);
|
||||
}
|
||||
|
||||
// Add the ADD op node that GPU delegate supports.
|
||||
const TfLiteRegistration reg_add0 = {
|
||||
[](TfLiteContext* context, const char* buffer, size_t length) {
|
||||
return reinterpret_cast<void*>(new int(1));
|
||||
},
|
||||
[](TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<int*>(buffer);
|
||||
},
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
kTfLiteBuiltinAdd};
|
||||
|
||||
EXPECT_EQ(interpreter_.AddNodeWithParameters(
|
||||
/*inputs=*/{4, 5}, /*outputs=*/{7}, /*init_data=*/nullptr,
|
||||
/*init_data_size=*/0,
|
||||
/*builtin_data=*/builtin_data,
|
||||
/*registration=*/®_add0),
|
||||
kTfLiteOk);
|
||||
|
||||
// Add the GreaterThan op node that GPU delegate doesn't support.
|
||||
const TfLiteRegistration reg_greater = {
|
||||
[](TfLiteContext* context, const char* buffer, size_t length) {
|
||||
return reinterpret_cast<void*>(new int(1));
|
||||
},
|
||||
[](TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<int*>(buffer);
|
||||
},
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
kTfLiteBuiltinGreater};
|
||||
|
||||
EXPECT_EQ(interpreter_.AddNodeWithParameters(
|
||||
/*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
|
||||
/*init_data_size=*/0,
|
||||
/*builtin_data=*/builtin_data,
|
||||
/*registration=*/®_greater),
|
||||
kTfLiteOk);
|
||||
|
||||
const std::vector<int> dims = {1};
|
||||
TfLiteQuantization quantization;
|
||||
quantization.type = kTfLiteNoQuantization;
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
0, TfLiteType::kTfLiteFloat16, "t0", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
1, TfLiteType::kTfLiteFloat16, "t1", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
4, TfLiteType::kTfLiteFloat32, "t4", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
5, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
6, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
7, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
exec_plan_ = TfLiteIntArrayCreate(5);
|
||||
exec_plan_->data[0] = 0;
|
||||
exec_plan_->data[1] = 1;
|
||||
exec_plan_->data[2] = 2;
|
||||
exec_plan_->data[3] = 3;
|
||||
exec_plan_->data[4] = 4;
|
||||
}
|
||||
|
||||
~InterpreterMultiNode() { TfLiteIntArrayFree(exec_plan_); }
|
||||
|
||||
Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
|
||||
TfLiteIntArray* exec_plan() const { return exec_plan_; }
|
||||
|
||||
private:
|
||||
Interpreter interpreter_;
|
||||
TfLiteIntArray* exec_plan_;
|
||||
};
|
||||
|
||||
InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode();
|
||||
|
||||
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequants) {
|
||||
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
|
||||
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
|
||||
// t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
|
||||
// t1 (FP16) --> Dequant --> t4 (FP32) --/
|
||||
// --\
|
||||
// t3 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7
|
||||
//
|
||||
// OpsToReplace should replace the 'Add' op and the Dequant outputing
|
||||
// t5, but leave the other Dequant nodes because 'Greater' must run
|
||||
// on the CPU.
|
||||
TfLiteContext* context = interpreter_mn->GetSubgraph()->context();
|
||||
|
||||
// These functions are meant to be called inside delegates. Swap out
|
||||
// for similar functions to permit direct calling of GetOpsToReplace.
|
||||
context->GetExecutionPlan = [](struct TfLiteContext* context,
|
||||
TfLiteIntArray** execution_plan) {
|
||||
*execution_plan = interpreter_mn->exec_plan();
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
auto& node_and_reg =
|
||||
interpreter_mn->GetSubgraph()->nodes_and_registration()[node_index];
|
||||
*node = &node_and_reg.first;
|
||||
*registration = &node_and_reg.second;
|
||||
return kTfLiteOk;
|
||||
};
|
||||
|
||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
||||
|
||||
EXPECT_EQ(ops_to_replace->size, 2);
|
||||
// Op at index 2 is the Dequant op (t3 -> t5).
|
||||
EXPECT_EQ(ops_to_replace->data[0], 2);
|
||||
// Op at index 3 is the Add op.
|
||||
EXPECT_EQ(ops_to_replace->data[1], 3);
|
||||
|
||||
TfLiteNode* node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
// Verify that Add op has fp16 inputs.
|
||||
context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
|
||||
®istration);
|
||||
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
||||
TfLiteType::kTfLiteFloat16);
|
||||
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
|
||||
TfLiteType::kTfLiteFloat16);
|
||||
TfLiteIntArrayFree(ops_to_replace);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user