Rectifies delegation of DEQUANTIZE nodes in fp16 graphs. Only those dequant nodes which are consumed in the first delegated partition are now selected.

PiperOrigin-RevId: 314627542
Change-Id: I801922eef3a95d3d4b05d53f5e8d27a0842c3be6
This commit is contained in:
Sachin Joglekar 2020-06-03 16:32:37 -07:00 committed by TensorFlower Gardener
parent 70fd126d3a
commit c0b6b669e2
4 changed files with 216 additions and 257 deletions

View File

@ -250,7 +250,7 @@ class InterpreterFp16 : public DelegatedInterpreter {
InterpreterFp16* interpreter_fp16_add_op = InterpreterFp16* interpreter_fp16_add_op =
new InterpreterFp16(kTfLiteBuiltinAdd); new InterpreterFp16(kTfLiteBuiltinAdd);
TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
// Before pruning, the graph has three nodes: // Before pruning, the graph has three nodes:
// //
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4 // t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
@ -283,14 +283,15 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
context->PreviewDelegatePartitioning = context->PreviewDelegatePartitioning =
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegateParams** partition_params_array, int* num_partitions) { TfLiteDelegateParams** partition_params_array, int* num_partitions) {
// The partitioner should accept only the Add op initially.
EXPECT_EQ(nodes_to_replace->size, 1);
// Single partition output.
auto params = interpreter_fp16_add_op->add_delegate_params(); auto params = interpreter_fp16_add_op->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(3); params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 0; params->nodes_to_replace->data[0] = 2;
params->nodes_to_replace->data[1] = 1;
params->nodes_to_replace->data[2] = 2;
params->input_tensors = TfLiteIntArrayCreate(2); params->input_tensors = TfLiteIntArrayCreate(2);
params->input_tensors->data[0] = 0; params->input_tensors->data[0] = 1;
params->input_tensors->data[1] = 2; params->input_tensors->data[1] = 3;
params->output_tensors = TfLiteIntArrayCreate(1); params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 4; params->output_tensors->data[0] = 4;
@ -301,11 +302,13 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
// Replace all nodes. // The Dequant nodes are added to ops_to_replace as a post-processing step by
// the FP16GraphPartitioner. ADD is delegated with its inputs pointing to the
// FP16 inputs.
EXPECT_EQ(ops_to_replace->size, 3); EXPECT_EQ(ops_to_replace->size, 3);
TfLiteNode* node = nullptr; TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr; TfLiteRegistration* registration = nullptr;
context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node, context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
&registration); &registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type, EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16); TfLiteType::kTfLiteFloat16);
@ -317,14 +320,14 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
InterpreterFp16* interpreter_fp16_gt_op = InterpreterFp16* interpreter_fp16_gt_op =
new InterpreterFp16(kTfLiteBuiltinGreater); new InterpreterFp16(kTfLiteBuiltinGreater);
TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) { TEST(ModelBuilderTest, GetOpsToReplaceRejectsFp16DequantizeNodes) {
// Before pruning, the graph has three nodes: // Before pruning, the graph has three nodes:
// //
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Greater Op -> t4 // t0 (FP16) -> DequantNode -> t1 (FP32) -> Greater Op -> t4
// t2 (FP16) -> DequantNode -> t3 (FP32) --/ // t2 (FP16) -> DequantNode -> t3 (FP32) --/
// //
// Because there is no GPU equivalent for the Greater op, we don't prune // Because there is no GPU equivalent for the Greater op, we don't choose any
// the Dequantize nodes. // nodes.
TfLiteContext* context = interpreter_fp16_gt_op->context(); TfLiteContext* context = interpreter_fp16_gt_op->context();
// These functions are meant to be called inside delegates. Swap out // These functions are meant to be called inside delegates. Swap out
@ -346,26 +349,10 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
context->PreviewDelegatePartitioning = context->PreviewDelegatePartitioning =
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegateParams** partition_params_array, int* num_partitions) { TfLiteDelegateParams** partition_params_array, int* num_partitions) {
auto params = interpreter_fp16_gt_op->add_delegate_params(); // No selected nodes.
// First partition for DequantNode (t0->t1) EXPECT_EQ(nodes_to_replace->size, 0);
params->nodes_to_replace = TfLiteIntArrayCreate(1); *partition_params_array = nullptr;
params->nodes_to_replace->data[0] = 0; *num_partitions = 0;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 0;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 1;
// Second partition for DequantNode (t2->t3)
params = interpreter_fp16_gt_op->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 0;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 0;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 1;
*partition_params_array = interpreter_fp16_gt_op->delegate_params();
*num_partitions = interpreter_fp16_gt_op->num_delegate_params();
return kTfLiteOk; return kTfLiteOk;
}; };
@ -685,7 +672,7 @@ TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) {
class InterpreterMultiNode : public DelegatedInterpreter { class InterpreterMultiNode : public DelegatedInterpreter {
public: public:
explicit InterpreterMultiNode(bool add_op_first = true) explicit InterpreterMultiNode(bool both_ops_supported = true)
: DelegatedInterpreter(5) { : DelegatedInterpreter(5) {
void* builtin_data = malloc(sizeof(int)); void* builtin_data = malloc(sizeof(int));
EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk); EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
@ -707,8 +694,8 @@ class InterpreterMultiNode : public DelegatedInterpreter {
kTfLiteOk); kTfLiteOk);
} }
if (add_op_first) { if (both_ops_supported) {
// Add the ADD op node that GPU delegate supports. // Add 2 ADD ops.
const TfLiteRegistration reg_add0 = { const TfLiteRegistration reg_add0 = {
[](TfLiteContext* context, const char* buffer, size_t length) { [](TfLiteContext* context, const char* buffer, size_t length) {
return reinterpret_cast<void*>(new int(1)); return reinterpret_cast<void*>(new int(1));
@ -727,8 +714,7 @@ class InterpreterMultiNode : public DelegatedInterpreter {
/*registration=*/&reg_add0), /*registration=*/&reg_add0),
kTfLiteOk); kTfLiteOk);
// Add the GREATER op node that GPU delegate doesn't support. const TfLiteRegistration reg_add1 = {
const TfLiteRegistration reg_greater = {
[](TfLiteContext* context, const char* buffer, size_t length) { [](TfLiteContext* context, const char* buffer, size_t length) {
return reinterpret_cast<void*>(new int(1)); return reinterpret_cast<void*>(new int(1));
}, },
@ -738,12 +724,12 @@ class InterpreterMultiNode : public DelegatedInterpreter {
nullptr, nullptr,
nullptr, nullptr,
nullptr, nullptr,
kTfLiteBuiltinGreater}; kTfLiteBuiltinAdd};
EXPECT_EQ(interpreter_.AddNodeWithParameters( EXPECT_EQ(interpreter_.AddNodeWithParameters(
/*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr, /*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
/*init_data_size=*/0, /*init_data_size=*/0,
/*builtin_data=*/builtin_data, /*builtin_data=*/builtin_data,
/*registration=*/&reg_greater), /*registration=*/&reg_add1),
kTfLiteOk); kTfLiteOk);
} else { } else {
// Add the GREATER op node that GPU delegate doesn't support. // Add the GREATER op node that GPU delegate doesn't support.
@ -828,19 +814,19 @@ class InterpreterMultiNode : public DelegatedInterpreter {
} }
}; };
InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode(); InterpreterMultiNode* interpreter_mn =
new InterpreterMultiNode(/*both_ops_supported*/ false);
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) { TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectFp16Nodes_SinglePartition) {
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'. // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not. // 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(4) -> t6 // t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
// t1 (FP16) --> Dequant(1) --> t4 (FP32) --/ // t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
// --\ // --\
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(3) -> t7 // t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
// //
// OpsToReplace should replace the 'Add' op and the Dequant outputting // OpsToReplace should accept 'Add' & the Dequant nodes that only output to
// t5, but leave the other Dequant nodes because 'Greater' must run // it (in this case, Dequant(2)).
// on the CPU.
TfLiteContext* context = interpreter_mn->context(); TfLiteContext* context = interpreter_mn->context();
// These functions are meant to be called inside delegates. Swap out // These functions are meant to be called inside delegates. Swap out
@ -861,33 +847,16 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
context->PreviewDelegatePartitioning = context->PreviewDelegatePartitioning =
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegateParams** partition_params_array, int* num_partitions) { TfLiteDelegateParams** partition_params_array, int* num_partitions) {
// The FP16GraphPartitioner should only mark the ADD op as accepted.
EXPECT_EQ(nodes_to_replace->size, 1);
EXPECT_EQ(nodes_to_replace->data[0], 4);
// Single partition.
auto params = interpreter_mn->add_delegate_params(); auto params = interpreter_mn->add_delegate_params();
// First partition for DequantNode(0)
params->nodes_to_replace = TfLiteIntArrayCreate(1); params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 0; params->nodes_to_replace->data[0] = 4;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 0;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 3;
// Second partition for DequantNode(1)
params = interpreter_mn->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 1;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 1;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 4;
// Third partition for DequantNode(1), DequantNode(2), ADD(3)
params = interpreter_mn->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(3);
params->nodes_to_replace->data[0] = 1;
params->nodes_to_replace->data[1] = 2;
params->nodes_to_replace->data[2] = 3;
params->input_tensors = TfLiteIntArrayCreate(2); params->input_tensors = TfLiteIntArrayCreate(2);
params->input_tensors->data[0] = 1; params->input_tensors->data[0] = 1;
params->input_tensors->data[0] = 3; params->input_tensors->data[1] = 3;
params->output_tensors = TfLiteIntArrayCreate(1); params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 7; params->output_tensors->data[0] = 7;
@ -898,16 +867,16 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
// Post-PreviewDelegatePartitioning, the partitioner will add Dequant(2) to
// ops_to_replace, since it only outputs to a delegated node.
EXPECT_EQ(ops_to_replace->size, 2); EXPECT_EQ(ops_to_replace->size, 2);
// Op at index 2 is the Dequant op (t3 -> t5). // Op at index 4 is the Add op.
EXPECT_EQ(ops_to_replace->data[0], 2); EXPECT_EQ(ops_to_replace->data[0], 4);
// Op at index 3 is the Add op. EXPECT_EQ(ops_to_replace->data[1], 2);
EXPECT_EQ(ops_to_replace->data[1], 3); // Verify that Add op has fp16 inputs.
TfLiteNode* node = nullptr; TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr; TfLiteRegistration* registration = nullptr;
// Verify that Add op has fp16 inputs. context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
&registration); &registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type, EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16); TfLiteType::kTfLiteFloat16);
@ -917,21 +886,18 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
} }
InterpreterMultiNode* interpreter_mn2 = InterpreterMultiNode* interpreter_mn2 =
new InterpreterMultiNode(/*add_op_first=*/false); new InterpreterMultiNode(/*both_ops_supported*/ true);
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) { TEST(ModelBuilderTest,
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'. GetOpsToReplaceSelectsCorrectFp16Nodes_MultiplePartitions) {
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not. // A graph with three Dequant nodes feeding two Add ops.
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6 // t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Add(3) -> t6
// t1 (FP16) --> Dequant(1) --> t4 (FP32) --/ // t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
// --\ // --\
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7 // t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
// //
// Note: the graph dependency is exactly same w/ that in // In this test case, we purposely partition Add(3) & Add(4) into different
// GetOpsToReplaceSelectsCorrectDequantsAddFirst, but the unsupported // partitions, to check if Dequant nodes that output *only* to the first
// 'Greater' op appears first in the execution plan. Despite this, // partition nodes are accepted.
// OpsToReplace should still replace the 'Add' op and the Dequant outputting
// t5, but leave the other Dequant nodes because 'Greater' must run
// on the CPU.
TfLiteContext* context = interpreter_mn2->context(); TfLiteContext* context = interpreter_mn2->context();
@ -954,33 +920,29 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
context->PreviewDelegatePartitioning = context->PreviewDelegatePartitioning =
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegateParams** partition_params_array, int* num_partitions) { TfLiteDelegateParams** partition_params_array, int* num_partitions) {
// The FP16GraphPartitioner should only mark both ADD ops as accepted.
EXPECT_EQ(nodes_to_replace->size, 2);
EXPECT_EQ(nodes_to_replace->data[0], 3);
EXPECT_EQ(nodes_to_replace->data[1], 4);
// Technically, both ADD ops should end up in the same partition.
// But we put them in different partitions to test post-processing with
// DEQUANTIZE nodes.
// First partition with Add(3).
auto params = interpreter_mn2->add_delegate_params(); auto params = interpreter_mn2->add_delegate_params();
// First partition for DequantNode(0)
params->nodes_to_replace = TfLiteIntArrayCreate(1); params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 0; params->nodes_to_replace->data[0] = 3;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 0;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 3;
// Second partition for DequantNode(1)
params = interpreter_mn2->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 1;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 1;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 4;
// Third partition for DequantNode(1), DequantNode(2), ADD(4)
params = interpreter_mn2->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(3);
params->nodes_to_replace->data[0] = 1;
params->nodes_to_replace->data[1] = 2;
params->nodes_to_replace->data[2] = 4;
params->input_tensors = TfLiteIntArrayCreate(2); params->input_tensors = TfLiteIntArrayCreate(2);
params->input_tensors->data[0] = 1;
params->input_tensors->data[0] = 3; params->input_tensors->data[0] = 3;
params->input_tensors->data[1] = 4;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 6;
// Second partition with Add(4).
params = interpreter_mn2->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 4;
params->input_tensors = TfLiteIntArrayCreate(2);
params->input_tensors->data[0] = 4;
params->input_tensors->data[1] = 5;
params->output_tensors = TfLiteIntArrayCreate(1); params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 7; params->output_tensors->data[0] = 7;
@ -989,23 +951,32 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
return kTfLiteOk; return kTfLiteOk;
}; };
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); TfLiteIntArray* ops_to_replace = GetOpsToReplace(
context, /*allow_quant_ops*/ false, /*max_delegated_partitions*/ 2);
EXPECT_EQ(ops_to_replace->size, 2); // Three ops should be selected:
// Op at index 2 is the Dequant op (t3 -> t5). // Add(3), Dequant(x), Add(4)
EXPECT_EQ(ops_to_replace->data[0], 2); // Since both partitions are of size 1, either could end up as the 'first'
// Op at index 4 is the Add op. // partition with one Dequant node added for it.
EXPECT_EQ(ops_to_replace->data[1], 4); EXPECT_EQ(ops_to_replace->size, 3);
TfLiteNode* node = nullptr; TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr; TfLiteRegistration* registration = nullptr;
// Verify that Add op has fp16 inputs. // Verify that both Add ops have fp16 inputs.
context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node, context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
&registration); &registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type, EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16); TfLiteType::kTfLiteFloat16);
EXPECT_EQ(context->tensors[node->inputs->data[1]].type, EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
TfLiteType::kTfLiteFloat16); TfLiteType::kTfLiteFloat16);
context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
&registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16);
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
TfLiteType::kTfLiteFloat16);
// Verify that the op at index 1 is a Dequant outputting to a single Add.
EXPECT_TRUE(ops_to_replace->data[1] == 0 || ops_to_replace->data[1] == 2);
TfLiteIntArrayFree(ops_to_replace); TfLiteIntArrayFree(ops_to_replace);
} }

View File

@ -150,74 +150,112 @@ TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
return kTfLiteOk; return kTfLiteOk;
} }
TfLiteStatus FP16GraphPartitionHelper::Partition(
std::set<std::string>* unsupported_nodes_info) {
const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
// Clean up those partitions that have a single dequant op. Note: Those
// removed dequant ops have to be preserved in the graph and should not be
// delegated.
RemoveSingleDequantNodePartitions();
return status;
}
std::vector<int> std::vector<int>
FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl( FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
int n, int min_nodes_per_partition) { int n, int min_nodes_per_partition) {
std::vector<int> ops_to_replace = auto first_n_partitions =
GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl( GetFirstNLargestPartitions(n, min_nodes_per_partition);
n, min_nodes_per_partition); std::vector<int> ops_to_replace;
RemapInputTensors(ops_to_replace); if (first_n_partitions.empty()) return ops_to_replace;
RemoveReservedDequantsFromNodes(&ops_to_replace);
// Handle the first delegated partition specially.
// All fp16 DEQUANTIZE nodes whose consumers exist only in this partition can
// be added to the ops to delegate. Others have to be preserved in the graph,
// since the partitioning algorithm will put such nodes greedily in the first
// partition.
const auto* first_partition = first_n_partitions[0];
std::unordered_map<int, int> delegated_dequant_consumers;
for (int i = 0; i < first_partition->nodes_to_replace->size; ++i) {
const int node_id = first_partition->nodes_to_replace->data[i];
ops_to_replace.push_back(node_id);
TfLiteNode* node;
TfLiteRegistration* registration;
const auto status = context_->GetNodeAndRegistration(context_, node_id,
&node, &registration);
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context_,
"Couldn't get node and registration info for op: %d\n",
node_id);
ops_to_replace.clear();
return ops_to_replace;
}
// See if any input to the op is a (converted) fp16 value. If yes, increment
// its value in delegated_dequant_consumers.
for (int j = 0; j < node->inputs->size; ++j) {
const int input_tid = node->inputs->data[j];
if (dequant_consumers_.find(input_tid) != dequant_consumers_.end()) {
delegated_dequant_consumers[input_tid] += 1;
}
}
}
// Check all dequant nodes that have some consumers in the first partition.
// If the number of delegated consumers is same as total number of consumers,
// add the corresponding DEQUANTIZE op to the delegated nodes.
for (auto tensor_and_consumers : delegated_dequant_consumers) {
if (dequant_consumers_[tensor_and_consumers.first] ==
tensor_and_consumers.second) {
ops_to_replace.emplace_back(dequant_nodes_[tensor_and_consumers.first]);
}
}
// For all other partitions after the first one, insert all nodes into
// ops_to_replace.
for (int i = 1; i < first_n_partitions.size(); ++i) {
auto nodes = first_n_partitions[i]->nodes_to_replace;
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
nodes->data + nodes->size);
}
// Modify the inputs of relevant ops that support fp16 constants.
// TODO(b/156707497): Ensure that these inputs are remapped during the
// delegate's 'free', so that CPU fallback works for fp16 models.
RemapFp16InputTensors(ops_to_replace);
return ops_to_replace; return ops_to_replace;
} }
bool FP16GraphPartitionHelper::IsNodeSupported( bool FP16GraphPartitionHelper::IsNodeSupported(
TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration, TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
int node_id, std::string* unsupported_details) { int node_id, std::string* unsupported_details) {
// If we need to handle dequant nodes, we have to remap input tensors of if (registration->builtin_code == kTfLiteBuiltinDequantize &&
// this node if some of them come from a dequant node before testing if
// the node is supported.
std::vector<int> orig_inputs;
if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
&orig_inputs)) {
// We have a dequant op here. Note that we retrun an Ok status because a
// dequant node is first added as supported. Later, this dequant node
// will be removed if it has to be preserved in the graph which happens
// when its immediate downstream nodes cannot be supported.
return true;
}
const auto status = GraphPartitionHelper::IsNodeSupported(
context, node, registration, node_id, unsupported_details);
RestoreToOrigInputTensors(node, orig_inputs);
return status;
}
bool FP16GraphPartitionHelper::RecordAndRemapInputTensors(
int32_t op_code, int node_id, TfLiteNode* node,
std::vector<int>* orig_inputs) {
orig_inputs->clear();
// Record the dequant node.
if (op_code == kTfLiteBuiltinDequantize &&
context_->tensors[node->inputs->data[0]].type == context_->tensors[node->inputs->data[0]].type ==
TfLiteType::kTfLiteFloat16) { TfLiteType::kTfLiteFloat16) {
dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0]; // Update mappings if this node is a fp16 DEQUANTIZE node.
return true; dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
} dequant_nodes_[node->outputs->data[0]] = node_id;
// For a dequantize op, there's no need to remap its input tensors. // We do not accept these ops right now.
if (dequant_nodes_.empty()) return false; // This is done to support use-cases where a DEQUANTIZE output might be
RemapInputTensors(node, orig_inputs); // consumed by a CPU op.
return false; return false;
} }
void FP16GraphPartitionHelper::RestoreToOrigInputTensors( // To check if a (possibly) FP16 node is supported, we temporarily point the
TfLiteNode* node, const std::vector<int>& orig_inputs) { // node's inputs to the original fp16 tensors. This 'mutated' node is then
if (node->inputs->size != orig_inputs.size()) return; // passed to the base IsNodeSupported function for checking. After the check,
// we remap the original node inputs, so that the TFLite graph remains the
// same.
std::vector<int> orig_inputs;
if (!dequant_nodes_.empty()) {
RemapFp16InputTensors(node, &orig_inputs);
}
const auto is_supported = GraphPartitionHelper::IsNodeSupported(
context, node, registration, node_id, unsupported_details);
if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
// Remapping happened. Restore original inputs.
for (int j = 0; j < node->inputs->size; ++j) { for (int j = 0; j < node->inputs->size; ++j) {
node->inputs->data[j] = orig_inputs[j]; node->inputs->data[j] = orig_inputs[j];
if (dequant_nodes_.find(orig_inputs[j]) != dequant_nodes_.end()) {
// If its a fp16 tensor, increment number of consumers of the
// corresponding DEQUANTIZE.
dequant_consumers_[orig_inputs[j]] += 1;
} }
}
}
return is_supported;
} }
void FP16GraphPartitionHelper::RemapInputTensors( void FP16GraphPartitionHelper::RemapFp16InputTensors(
const std::vector<int>& nodes) const { const std::vector<int>& nodes) const {
for (int node_id : nodes) { for (int node_id : nodes) {
TfLiteNode* node; TfLiteNode* node;
@ -229,56 +267,11 @@ void FP16GraphPartitionHelper::RemapInputTensors(
"Couldn't get node and registration info for op: %d\n", "Couldn't get node and registration info for op: %d\n",
node_id); node_id);
} }
RemapInputTensors(node, nullptr /* orig_inputs*/); RemapFp16InputTensors(node, nullptr /* orig_inputs*/);
} }
} }
void FP16GraphPartitionHelper::RemoveSingleDequantNodePartitions() { void FP16GraphPartitionHelper::RemapFp16InputTensors(
auto it = partitions_.begin();
while (it != partitions_.end()) {
auto p = *it;
if (p->nodes_to_replace->size != 1) {
++it;
continue;
}
int node_id = p->nodes_to_replace->data[0];
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
TfLiteStatus status = context_->GetNodeAndRegistration(
context_, node_id, &node, &registration);
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context_,
"Couldn't get node and registration info for op: %d\n",
node_id);
}
if (registration->builtin_code != kTfLiteBuiltinDequantize ||
context_->tensors[node->inputs->data[0]].type !=
TfLiteType::kTfLiteFloat16) {
++it;
continue;
}
// Note such dequant nodes have to be preserved in the graph as dequant
// ops are not actually supported in the GPU delegate.
dequant_nodes_to_save_.insert(node_id);
it = partitions_.erase(it);
}
}
void FP16GraphPartitionHelper::RemoveReservedDequantsFromNodes(
std::vector<int>* nodes) {
if (dequant_nodes_to_save_.empty()) return;
auto it = nodes->begin();
while (it != nodes->end()) {
if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
++it;
continue;
}
it = nodes->erase(it);
}
}
void FP16GraphPartitionHelper::RemapInputTensors(
TfLiteNode* node, std::vector<int>* orig_inputs) const { TfLiteNode* node, std::vector<int>* orig_inputs) const {
TfLiteIntArray* inputs = node->inputs; TfLiteIntArray* inputs = node->inputs;
auto inputs_view = TfLiteIntArrayView(inputs); auto inputs_view = TfLiteIntArrayView(inputs);
@ -296,8 +289,8 @@ void FP16GraphPartitionHelper::RemapInputTensors(
bool is_remapped = false; bool is_remapped = false;
for (int j = 0; j < inputs->size; ++j) { for (int j = 0; j < inputs->size; ++j) {
const int input_tid = inputs->data[j]; const int input_tid = inputs->data[j];
const auto it = dequant_nodes_.find(input_tid); const auto it = dequant_map_.find(input_tid);
if (it != dequant_nodes_.end()) { if (it != dequant_map_.end()) {
inputs->data[j] = it->second; inputs->data[j] = it->second;
is_remapped = true; is_remapped = true;
} }

View File

@ -127,19 +127,23 @@ class GraphPartitionHelper {
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
}; };
// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in // Specialized partitioner for graphs that possibly contain fp16 tensors.
// addition to supported nodes for the delegate, when the DEQUANTIZE node's //
// output is an input to the kernel that supports FP16 input. // From nodes that accept fp16 inputs, this delegates the following:
// 1. All nodes (except DEQUANTIZE) that are supported with fp16 inputs by the
// delegate (in the TFLite graph, these nodes take in dequantized FP32
// outputs).
// 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
// delegated partition. This is because TFLite's partitioning algorithm
// greedily puts all such nodes in the first partition.
class FP16GraphPartitionHelper : public GraphPartitionHelper { class FP16GraphPartitionHelper : public GraphPartitionHelper {
public: public:
FP16GraphPartitionHelper(TfLiteContext* context, FP16GraphPartitionHelper(TfLiteContext* context,
IsNodeSupportedFn is_node_supported_fn) IsNodeSupportedFn is_node_supported_fn)
: GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
TfLiteStatus Partition(
std::set<std::string>* unsupported_nodes_info) override;
protected: protected:
// Specialized function to handle fp16 nodes.
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
TfLiteRegistration* registration, int node_id, TfLiteRegistration* registration, int node_id,
std::string* unsupported_details) override; std::string* unsupported_details) override;
@ -149,39 +153,25 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper {
int n, int min_nodes_per_partition) override; int n, int min_nodes_per_partition) override;
private: private:
// Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true. // This remaps fp32 inputs of the given node to their corresponding fp16
// When it's not a dequant op, remap its inputs to the inputs of the preceding // version, if applicable. Can be summarized as:
// dequant if there's a one and returns false. 'orig_inputs' records original // fp16 -> DEQUANTIZE -> fp32 -> OP -> output
// input tensor ids of this node if any input is remapped. // becomes
bool RecordAndRemapInputTensors(int32_t op_code, int node_id, // fp16 -> OP -> output
TfLiteNode* node, void RemapFp16InputTensors(TfLiteNode* node,
std::vector<int>* orig_inputs); std::vector<int>* orig_inputs) const;
// Restore inputs of 'node' to 'orig_inputs' only if two sizes match. // Performs the above remapping for all nodes in the given list, without
void RestoreToOrigInputTensors(TfLiteNode* node, // tracking the original inputs.
const std::vector<int>& orig_inputs); void RemapFp16InputTensors(const std::vector<int>& nodes) const;
// Remap input tensors of every node in 'nodes' (i.e. node indices) if some of // ('dequantize' here refers to fp16 DEQUANTIZE)
// them are from dequant ops. // Mapping of dequantize nodes' output tensor-id to its node id.
void RemapInputTensors(const std::vector<int>& nodes) const;
void RemoveSingleDequantNodePartitions();
void RemoveReservedDequantsFromNodes(std::vector<int>* nodes);
// Remap input tensors of a single 'node' if some of come from a dequant op.
// If 'orig_inputs' isn't nullptr, it records original input tensor ids of
// this node if any input is remapped.
void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const;
// A map recording dequantize nodes' input/output tensors of this selected
// graph. The key is the output tensor id, and the value is the input tensor
// id.
std::unordered_map<int, int> dequant_nodes_; std::unordered_map<int, int> dequant_nodes_;
// Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
// A set of dequant nodes as in node indices that have to be preserved in the std::unordered_map<int, int> dequant_map_;
// graph. // mapping of DEQUANTIZE output tensor-id to its number of consumers.
std::set<int> dequant_nodes_to_save_; std::unordered_map<int, int> dequant_consumers_;
}; };
} // namespace delegates } // namespace delegates

View File

@ -29,6 +29,10 @@ namespace {
// PartitionGraphIntoIndependentNodeSubsetsImpl partitioner( // PartitionGraphIntoIndependentNodeSubsetsImpl partitioner(
// info, nodes_to_part, node_subsets); // info, nodes_to_part, node_subsets);
// partitioner.Partition(); // partitioner.Partition();
//
// NOTE: Changing the partitioning logic would require a change to
// FP16GraphPartitionHelper.
// LINT.IfChange
class PartitionGraphIntoIndependentNodeSubsetsImpl { class PartitionGraphIntoIndependentNodeSubsetsImpl {
public: public:
PartitionGraphIntoIndependentNodeSubsetsImpl( PartitionGraphIntoIndependentNodeSubsetsImpl(
@ -198,6 +202,7 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl {
// negative values of kEpochNotReady if not assigned. // negative values of kEpochNotReady if not assigned.
std::vector<int> node_epochs_; std::vector<int> node_epochs_;
}; };
// LINT.ThenChange(//tensorflow/lite/delegates/utils.h)
} // namespace } // namespace