Remove superfluous Dequantize nodes in GPU delegate when executing float16 quantized models.

PiperOrigin-RevId: 259941556
This commit is contained in:
A. Unique TensorFlower 2019-07-25 07:20:43 -07:00 committed by TensorFlower Gardener
parent 53da0bc5ce
commit 130a84e59c
3 changed files with 300 additions and 25 deletions

View File

@ -77,6 +77,7 @@ cc_library(
":tensor", ":tensor",
"//tensorflow/lite:context", "//tensorflow/lite:context",
"//tensorflow/lite:kernel_api", "//tensorflow/lite:kernel_api",
"//tensorflow/lite:util",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"
namespace tflite { namespace tflite {
namespace gpu { namespace gpu {
@ -708,7 +709,6 @@ class AddOperationParser : public TFLiteOperationParser {
} }
} }
node->operation.attributes = std::move(attr); node->operation.attributes = std::move(attr);
const auto* tf_options = const auto* tf_options =
reinterpret_cast<const TfLiteAddParams*>(tflite_node->builtin_data); reinterpret_cast<const TfLiteAddParams*>(tflite_node->builtin_data);
if (!tf_options) { if (!tf_options) {
@ -2226,6 +2226,106 @@ Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
return OkStatus(); return OkStatus();
} }
// Selects ops to delegate in a graph that contains fp16 Dequantize nodes.
// Every fp16 Dequantize is provisionally added to the replacement set, and
// each consumer is tested for GPU support with the Dequantize bypassed (its
// inputs remapped to the fp16 tensor feeding the Dequantize). Dequantize
// nodes whose consumer cannot be replaced are removed from the set again at
// the end, so the CPU op still receives dequantized fp32 data.
// Returns a newly allocated TfLiteIntArray of node ids to delegate (caller
// frees it with TfLiteIntArrayFree), or nullptr on error.
TfLiteIntArray* GetOpsToReplaceFromGraphWithDequantize(TfLiteContext* context) {
TfLiteIntArray* execution_plan = nullptr;
if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
context->ReportError(context, "Unable to get graph execution plan.");
return nullptr;
}
// Error messages for ops that could not be delegated.
std::set<std::string> errors;
// Maps a Dequantize node's output tensor id to that node's index in the
// execution plan.
std::unordered_map<int, int> dequant_nodes;
std::vector<int> ops_to_replace;
// Dequantize nodes that must remain on the CPU because some consumer of
// their output could not be replaced.
std::vector<int> dequant_nodes_to_save;
// Map the output tensor of each Dequantize node to its input tensor.
std::unordered_map<int, int> node_map;
for (int i = 0; i < execution_plan->size; ++i) {
bool replace_node = false;
// Keep track of any inputs from a Dequantize node.
std::vector<int> inputs_from_dequant;
// Original input tensor ids of this node, saved so they can be restored
// if the node turns out to be unsupported.
std::vector<int> orig_inputs;
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
auto status = GetNodeAndRegistration(context, i, &node, &registration);
if (!status.ok()) {
context->ReportError(context, status.error_message().c_str());
return nullptr;
}
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
context->tensors[node->inputs->data[0]].type ==
TfLiteType::kTfLiteFloat16) {
// Record the output->input mapping for the op.
node_map[node->outputs->data[0]] = node->inputs->data[0];
// For now, add the node to the list of ops to replace.
ops_to_replace.push_back(i);
// Record the dequant node id, indexed by output id.
dequant_nodes[node->outputs->data[0]] = i;
continue;
}
TfLiteIntArray* inputs = node->inputs;
// Fix the node's inputs (i.e. prune out the preceding dequantize node)
// in order to test if it is supported on the GPU. Note: this mutates the
// node in place; the remapping is undone below if the op is unsupported.
for (int j = 0; j < inputs->size; ++j) {
orig_inputs.push_back(inputs->data[j]);
if (node_map.find(inputs->data[j]) != node_map.end()) {
inputs_from_dequant.push_back(dequant_nodes[inputs->data[j]]);
// Remap inputs of this node to the inputs of the preceding dequant.
inputs->data[j] = node_map[inputs->data[j]];
}
}
status = IsSupported(context, node, registration);
if (status.ok() &&
// TODO(eignasheva): resolve sub operation support for metal delegate
// registration->builtin_code != kTfLiteBuiltinSub &&
IsAllFloatTensors(context, node->inputs) &&
IsAllFloatTensors(context, node->outputs)) {
// Once any op has failed, stop adding further nodes (this matches the
// "First N operations will run on the GPU" message below).
if (errors.empty()) {
replace_node = true;
ops_to_replace.push_back(i);
}
} else {
// Unable to replace this node. Restore the inputs to the original
// if they were modified.
if (!inputs_from_dequant.empty()) {
TfLiteIntArray* inputs = node->inputs;
for (int j = 0; j < inputs->size; ++j) {
inputs->data[j] = orig_inputs[j];
}
}
errors.insert(GetOpNameByRegistration(registration) + ": " +
status.error_message());
}
// if any input is the output of a dequantize node AND we failed to
// replace this op, mark the corresponding dequantize node as a node to
// save.
if (!replace_node && !inputs_from_dequant.empty()) {
dequant_nodes_to_save.insert(dequant_nodes_to_save.end(),
inputs_from_dequant.begin(),
inputs_from_dequant.end());
}
}
if (!errors.empty()) {
std::string unsupported = absl::StrJoin(errors, "\n");
std::string error_message =
"Next operations are not supported by GPU delegate:\n" + unsupported +
"\nFirst " + std::to_string(ops_to_replace.size()) +
" operations will run on the GPU, and the remaining " +
std::to_string(execution_plan->size - ops_to_replace.size()) +
" on the CPU.";
context->ReportError(context, error_message.c_str());
}
// Pop all dequantize nodes that must be preserved.
for (int i = 0; i < dequant_nodes_to_save.size(); ++i) {
auto it = std::find(ops_to_replace.begin(), ops_to_replace.end(),
dequant_nodes_to_save[i]);
if (it != ops_to_replace.end()) {
ops_to_replace.erase(it);
}
}
return ConvertVectorToTfLiteIntArray(ops_to_replace);
}
// TODO(impjdi): Check number of input/output tensors and their dimensions. // TODO(impjdi): Check number of input/output tensors and their dimensions.
// TODO(impjdi): Check ops' parameters. // TODO(impjdi): Check ops' parameters.
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) { TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
@ -2234,27 +2334,34 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
context->ReportError(context, "Unable to get graph execution plan."); context->ReportError(context, "Unable to get graph execution plan.");
return nullptr; return nullptr;
} }
TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size);
subgraph->size = 0;
std::set<std::string> errors;
// Map the output tensor of a Dequantize nodes to its input tensor. // Dispatch to another function if graph has Dequantize nodes.
std::unordered_map<int, int> node_map;
for (int i = 0; i < execution_plan->size; ++i) { for (int i = 0; i < execution_plan->size; ++i) {
TfLiteNode* node = nullptr; TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr; TfLiteRegistration* registration = nullptr;
auto status = GetNodeAndRegistration(context, i, &node, &registration); auto status = GetNodeAndRegistration(context, i, &node, &registration);
if (!status.ok()) { if (!status.ok()) {
context->ReportError(context, status.error_message().c_str()); context->ReportError(context, status.error_message().c_str());
TfLiteIntArrayFree(subgraph);
return nullptr; return nullptr;
} }
if (registration->builtin_code == kTfLiteBuiltinDequantize && if (registration->builtin_code == kTfLiteBuiltinDequantize &&
context->tensors[node->inputs->data[0]].type == context->tensors[node->inputs->data[0]].type ==
TfLiteType::kTfLiteFloat16) { TfLiteType::kTfLiteFloat16) {
// Record the output->input mapping for the op. return GetOpsToReplaceFromGraphWithDequantize(context);
node_map[node->outputs->data[0]] = node->inputs->data[0]; }
continue; }
// No Dequantize nodes. Iterate through graph and find ops to replace.
TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size);
subgraph->size = 0;
std::set<std::string> errors;
for (int i = 0; i < execution_plan->size; ++i) {
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
auto status = GetNodeAndRegistration(context, i, &node, &registration);
if (!status.ok()) {
context->ReportError(context, status.error_message().c_str());
return nullptr;
} }
status = IsSupported(context, node, registration); status = IsSupported(context, node, registration);
if (status.ok() && if (status.ok() &&
@ -2262,14 +2369,6 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
// registration->builtin_code != kTfLiteBuiltinSub && // registration->builtin_code != kTfLiteBuiltinSub &&
IsAllFloatTensors(context, node->inputs) && IsAllFloatTensors(context, node->inputs) &&
IsAllFloatTensors(context, node->outputs)) { IsAllFloatTensors(context, node->outputs)) {
// Fix the node's inputs (i.e. prune out the preceding dequantize node)
// if the op is supported.
TfLiteIntArray* inputs = node->inputs;
for (int j = 0; j < inputs->size; ++j) {
if (node_map.find(inputs->data[j]) != node_map.end()) {
inputs->data[j] = node_map[inputs->data[j]];
}
}
if (errors.empty()) subgraph->data[subgraph->size++] = i; if (errors.empty()) subgraph->data[subgraph->size++] = i;
} else { } else {
errors.insert(GetOpNameByRegistration(registration) + ": " + errors.insert(GetOpNameByRegistration(registration) + ": " +
@ -2292,12 +2391,17 @@ Status BuildModel(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params, const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph) { GraphFloat32* graph) {
std::vector<std::unique_ptr<TFLiteOperationParser>> operations; std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
std::vector<int> tflite_nodes;
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) { for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
TfLiteNode* tflite_node = nullptr; TfLiteNode* tflite_node = nullptr;
TfLiteRegistration* registration = nullptr; TfLiteRegistration* registration = nullptr;
RETURN_IF_ERROR(GetNodeAndRegistration( RETURN_IF_ERROR(GetNodeAndRegistration(
context, delegate_params->nodes_to_replace->data[i], &tflite_node, context, delegate_params->nodes_to_replace->data[i], &tflite_node,
&registration)); &registration));
if (registration->builtin_code == kTfLiteBuiltinDequantize) {
// Ignore Dequantize nodes.
continue;
}
auto op_parser = NewOperationParser(registration); auto op_parser = NewOperationParser(registration);
if (!op_parser) { if (!op_parser) {
return UnimplementedError( return UnimplementedError(
@ -2306,15 +2410,16 @@ Status BuildModel(TfLiteContext* context,
") is not supported by TFLite GPU Delegate.")); ") is not supported by TFLite GPU Delegate."));
} }
operations.push_back(std::move(op_parser)); operations.push_back(std::move(op_parser));
tflite_nodes.push_back(i);
} }
std::vector<Value<TensorRef<BHWC>>*> tensor_to_value(context->tensors_size, std::vector<Value<TensorRef<BHWC>>*> tensor_to_value(context->tensors_size,
nullptr); nullptr);
for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) { for (int i = 0; i < operations.size(); ++i) {
TfLiteNode* tflite_node = nullptr; TfLiteNode* tflite_node = nullptr;
TfLiteRegistration* registration = nullptr; TfLiteRegistration* registration = nullptr;
RETURN_IF_ERROR(GetNodeAndRegistration( RETURN_IF_ERROR(GetNodeAndRegistration(
context, delegate_params->nodes_to_replace->data[i], &tflite_node, context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
&registration)); &tflite_node, &registration));
ObjectReader reader(graph, context, tflite_node, &tensor_to_value); ObjectReader reader(graph, context, tflite_node, &tensor_to_value);
RETURN_IF_ERROR( RETURN_IF_ERROR(
operations[i]->Parse(tflite_node, registration, graph, &reader)); operations[i]->Parse(tflite_node, registration, graph, &reader));

View File

@ -212,7 +212,8 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4 // t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
// t2 (FP16) -> DequantNode -> t3 (FP32) --/ // t2 (FP16) -> DequantNode -> t3 (FP32) --/
// //
// After pruning, the graph has one node: // OpsToReplace should choose all three nodes for replacement, and
// the graph on the GPU will look like this (no Dequants):
// //
// t0 (FP16) --> Add -> t4 // t0 (FP16) --> Add -> t4
// t2 (FP16) --/ // t2 (FP16) --/
@ -237,11 +238,11 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
// Just one node left. // Replace all nodes.
EXPECT_EQ(ops_to_replace->size, 1); EXPECT_EQ(ops_to_replace->size, 3);
TfLiteNode* node = nullptr; TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr; TfLiteRegistration* registration = nullptr;
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node, context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
&registration); &registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type, EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16); TfLiteType::kTfLiteFloat16);
@ -416,6 +417,174 @@ TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
TfLiteIntArrayFree(ops_to_replace); TfLiteIntArrayFree(ops_to_replace);
} }
// Test fixture: builds a 5-node graph mixing fp16 Dequantize nodes with one
// GPU-supported consumer (Add) and one unsupported consumer (Greater):
//
//   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
//   t1 (FP16) --> Dequant --> t4 (FP32) --/
//                                        --\
//   t2 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7
class InterpreterMultiNode {
 public:
  InterpreterMultiNode() {
    EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
    EXPECT_EQ(interpreter_.SetOutputs({6, 7}), kTfLiteOk);
    // Add 3 Dequantize Nodes with float16 input.
    for (int i = 0; i < 3; ++i) {
      const TfLiteRegistration reg_dequant = {/*init=*/nullptr,
                                              /*free=*/nullptr,
                                              /*prepare=*/nullptr,
                                              /*invoke=*/nullptr,
                                              /*profiling_string=*/nullptr,
                                              kTfLiteBuiltinDequantize};
      EXPECT_EQ(interpreter_.AddNodeWithParameters(
                    /*inputs=*/{i}, /*outputs=*/{i + 3}, /*init_data=*/nullptr,
                    /*init_data_size=*/0, /*builtin_data=*/nullptr,
                    /*registration=*/&reg_dequant),
                kTfLiteOk);
    }
    // Add the ADD op node that GPU delegate supports.
    const TfLiteRegistration reg_add0 = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinAdd};
    // NOTE(review): each node gets its own builtin_data allocation. The
    // previous version shared a single malloc'd pointer between the Add and
    // Greater nodes, which risks a double free if the interpreter releases
    // builtin_data per node — confirm against Interpreter ownership rules.
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{4, 5}, /*outputs=*/{7}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/malloc(sizeof(int)),
                  /*registration=*/&reg_add0),
              kTfLiteOk);
    // Add the GreaterThan op node that GPU delegate doesn't support.
    const TfLiteRegistration reg_greater = {
        [](TfLiteContext* context, const char* buffer, size_t length) {
          return reinterpret_cast<void*>(new int(1));
        },
        [](TfLiteContext* context, void* buffer) {
          delete reinterpret_cast<int*>(buffer);
        },
        nullptr,
        nullptr,
        nullptr,
        kTfLiteBuiltinGreater};
    EXPECT_EQ(interpreter_.AddNodeWithParameters(
                  /*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
                  /*init_data_size=*/0,
                  /*builtin_data=*/malloc(sizeof(int)),
                  /*registration=*/&reg_greater),
              kTfLiteOk);
    const std::vector<int> dims = {1};
    TfLiteQuantization quantization;
    quantization.type = kTfLiteNoQuantization;
    // t0..t2 are fp16 graph inputs; t3..t5 are the dequantized fp32 copies;
    // t6/t7 are the op outputs. (Names t6/t7 fixed: they were all "t5".)
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            0, TfLiteType::kTfLiteFloat16, "t0", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            1, TfLiteType::kTfLiteFloat16, "t1", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            4, TfLiteType::kTfLiteFloat32, "t4", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            5, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            6, TfLiteType::kTfLiteFloat32, "t6", dims, quantization, false),
        kTfLiteOk);
    EXPECT_EQ(
        interpreter_.SetTensorParametersReadWrite(
            7, TfLiteType::kTfLiteFloat32, "t7", dims, quantization, false),
        kTfLiteOk);
    // Execution plan covers all five nodes: Dequants 0-2, Add 3, Greater 4.
    exec_plan_ = TfLiteIntArrayCreate(5);
    for (int i = 0; i < 5; ++i) exec_plan_->data[i] = i;
  }
  ~InterpreterMultiNode() { TfLiteIntArrayFree(exec_plan_); }
  Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
  TfLiteIntArray* exec_plan() const { return exec_plan_; }

 private:
  Interpreter interpreter_;
  TfLiteIntArray* exec_plan_;
};
// Global fixture referenced by the capture-free lambdas installed on the
// TfLiteContext in the test below; intentionally lives for the whole process.
InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode();
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequants) {
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
//   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
//   t1 (FP16) --> Dequant --> t4 (FP32) --/
//                                        --\
//   t2 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7
//
// OpsToReplace should replace the 'Add' op and the Dequant outputting
// t5, but leave the other Dequant nodes because 'Greater' must run
// on the CPU.
TfLiteContext* context = interpreter_mn->GetSubgraph()->context();
// These functions are meant to be called inside delegates. Swap out
// for similar functions to permit direct calling of GetOpsToReplace.
context->GetExecutionPlan = [](struct TfLiteContext* context,
TfLiteIntArray** execution_plan) {
*execution_plan = interpreter_mn->exec_plan();
return kTfLiteOk;
};
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
TfLiteNode** node,
TfLiteRegistration** registration) {
auto& node_and_reg =
interpreter_mn->GetSubgraph()->nodes_and_registration()[node_index];
*node = &node_and_reg.first;
*registration = &node_and_reg.second;
return kTfLiteOk;
};
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
// Exactly two ops delegated: one Dequant plus the Add.
EXPECT_EQ(ops_to_replace->size, 2);
// Op at index 2 is the Dequant op (t2 -> t5).
EXPECT_EQ(ops_to_replace->data[0], 2);
// Op at index 3 is the Add op.
EXPECT_EQ(ops_to_replace->data[1], 3);
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
// Verify that Add op has fp16 inputs (its Dequant was pruned out).
context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
&registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16);
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
TfLiteType::kTfLiteFloat16);
TfLiteIntArrayFree(ops_to_replace);
}
} // namespace } // namespace
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite