From c0b6b669e2e83aa08531e649718ef881b46bf11d Mon Sep 17 00:00:00 2001 From: Sachin Joglekar Date: Wed, 3 Jun 2020 16:32:37 -0700 Subject: [PATCH] Rectifies delegation of DEQUANTIZE nodes in fp16 graphs. Only those dequant nodes which are consumed in the first delegated partition are now selected. PiperOrigin-RevId: 314627542 Change-Id: I801922eef3a95d3d4b05d53f5e8d27a0842c3be6 --- .../gpu/common/model_builder_test.cc | 207 ++++++++---------- tensorflow/lite/delegates/utils.cc | 199 ++++++++--------- tensorflow/lite/delegates/utils.h | 62 +++--- tensorflow/lite/graph_info.cc | 5 + 4 files changed, 216 insertions(+), 257 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc index f0525e5e2c9..c5ee71b3f3f 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_test.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_test.cc @@ -250,7 +250,7 @@ class InterpreterFp16 : public DelegatedInterpreter { InterpreterFp16* interpreter_fp16_add_op = new InterpreterFp16(kTfLiteBuiltinAdd); -TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { +TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) { // Before pruning, the graph has three nodes: // // t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4 @@ -283,14 +283,15 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { context->PreviewDelegatePartitioning = [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, TfLiteDelegateParams** partition_params_array, int* num_partitions) { + // The partitioner should accept only the Add op initially. + EXPECT_EQ(nodes_to_replace->size, 1); + // Single partition output. 
auto params = interpreter_fp16_add_op->add_delegate_params(); - params->nodes_to_replace = TfLiteIntArrayCreate(3); - params->nodes_to_replace->data[0] = 0; - params->nodes_to_replace->data[1] = 1; - params->nodes_to_replace->data[2] = 2; + params->nodes_to_replace = TfLiteIntArrayCreate(1); + params->nodes_to_replace->data[0] = 2; params->input_tensors = TfLiteIntArrayCreate(2); - params->input_tensors->data[0] = 0; - params->input_tensors->data[1] = 2; + params->input_tensors->data[0] = 1; + params->input_tensors->data[1] = 3; params->output_tensors = TfLiteIntArrayCreate(1); params->output_tensors->data[0] = 4; @@ -301,11 +302,13 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); - // Replace all nodes. + // The Dequant nodes are added to ops_to_replace as a post-processing step by + // the FP16GraphPartitioner. ADD is delegated with its inputs pointing to the + // FP16 inputs. EXPECT_EQ(ops_to_replace->size, 3); TfLiteNode* node = nullptr; TfLiteRegistration* registration = nullptr; - context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node, + context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node, ®istration); EXPECT_EQ(context->tensors[node->inputs->data[0]].type, TfLiteType::kTfLiteFloat16); @@ -317,14 +320,14 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) { InterpreterFp16* interpreter_fp16_gt_op = new InterpreterFp16(kTfLiteBuiltinGreater); -TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) { +TEST(ModelBuilderTest, GetOpsToReplaceRejectsFp16DequantizeNodes) { // Before pruning, the graph has three nodes: // // t0 (FP16) -> DequantNode -> t1 (FP32) -> Greater Op -> t4 // t2 (FP16) -> DequantNode -> t3 (FP32) --/ // - // Because there is no GPU equivalent for the Greater op, we don't prune - // the Dequantize nodes. + // Because there is no GPU equivalent for the Greater op, we don't choose any + // nodes. 
TfLiteContext* context = interpreter_fp16_gt_op->context(); // These functions are meant to be called inside delegates. Swap out @@ -346,26 +349,10 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) { context->PreviewDelegatePartitioning = [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, TfLiteDelegateParams** partition_params_array, int* num_partitions) { - auto params = interpreter_fp16_gt_op->add_delegate_params(); - // First partition for DequantNode (t0->t1) - params->nodes_to_replace = TfLiteIntArrayCreate(1); - params->nodes_to_replace->data[0] = 0; - params->input_tensors = TfLiteIntArrayCreate(1); - params->input_tensors->data[0] = 0; - params->output_tensors = TfLiteIntArrayCreate(1); - params->output_tensors->data[0] = 1; - - // Second partition for DequantNode (t2->t3) - params = interpreter_fp16_gt_op->add_delegate_params(); - params->nodes_to_replace = TfLiteIntArrayCreate(1); - params->nodes_to_replace->data[0] = 0; - params->input_tensors = TfLiteIntArrayCreate(1); - params->input_tensors->data[0] = 0; - params->output_tensors = TfLiteIntArrayCreate(1); - params->output_tensors->data[0] = 1; - - *partition_params_array = interpreter_fp16_gt_op->delegate_params(); - *num_partitions = interpreter_fp16_gt_op->num_delegate_params(); + // No selected nodes. 
+ EXPECT_EQ(nodes_to_replace->size, 0); + *partition_params_array = nullptr; + *num_partitions = 0; return kTfLiteOk; }; @@ -685,7 +672,7 @@ TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) { class InterpreterMultiNode : public DelegatedInterpreter { public: - explicit InterpreterMultiNode(bool add_op_first = true) + explicit InterpreterMultiNode(bool both_ops_supported = true) : DelegatedInterpreter(5) { void* builtin_data = malloc(sizeof(int)); EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk); @@ -707,8 +694,8 @@ class InterpreterMultiNode : public DelegatedInterpreter { kTfLiteOk); } - if (add_op_first) { - // Add the ADD op node that GPU delegate supports. + if (both_ops_supported) { + // Add 2 ADD ops. const TfLiteRegistration reg_add0 = { [](TfLiteContext* context, const char* buffer, size_t length) { return reinterpret_cast(new int(1)); @@ -727,8 +714,7 @@ class InterpreterMultiNode : public DelegatedInterpreter { /*registration=*/®_add0), kTfLiteOk); - // Add the GREATER op node that GPU delegate doesn't support. - const TfLiteRegistration reg_greater = { + const TfLiteRegistration reg_add1 = { [](TfLiteContext* context, const char* buffer, size_t length) { return reinterpret_cast(new int(1)); }, @@ -738,12 +724,12 @@ class InterpreterMultiNode : public DelegatedInterpreter { nullptr, nullptr, nullptr, - kTfLiteBuiltinGreater}; + kTfLiteBuiltinAdd}; EXPECT_EQ(interpreter_.AddNodeWithParameters( /*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr, /*init_data_size=*/0, /*builtin_data=*/builtin_data, - /*registration=*/®_greater), + /*registration=*/®_add1), kTfLiteOk); } else { // Add the GREATER op node that GPU delegate doesn't support. 
@@ -828,19 +814,19 @@ class InterpreterMultiNode : public DelegatedInterpreter { } }; -InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode(); +InterpreterMultiNode* interpreter_mn = + new InterpreterMultiNode(/*both_ops_supported*/ false); -TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) { +TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectFp16Nodes_SinglePartition) { // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'. // 'Add' can be replaced by the GPU delegate, but 'Greater' can not. - // t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(4) -> t6 + // t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6 // t1 (FP16) --> Dequant(1) --> t4 (FP32) --/ // --\ - // t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(3) -> t7 + // t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7 // - // OpsToReplace should replace the 'Add' op and the Dequant outputing - // t5, but leave the other Dequant nodes because 'Greater' must run - // on the CPU. + // OpsToReplace should accept 'Add' & the Dequant nodes that only output to + // it (in this case, Dequant(2)). TfLiteContext* context = interpreter_mn->context(); // These functions are meant to be called inside delegates. Swap out @@ -861,33 +847,16 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) { context->PreviewDelegatePartitioning = [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, TfLiteDelegateParams** partition_params_array, int* num_partitions) { + // The FP16GraphPartitioner should only mark the ADD op as accepted. + EXPECT_EQ(nodes_to_replace->size, 1); + EXPECT_EQ(nodes_to_replace->data[0], 4); + // Single partition. 
auto params = interpreter_mn->add_delegate_params(); - // First partition for DequantNode(0) params->nodes_to_replace = TfLiteIntArrayCreate(1); - params->nodes_to_replace->data[0] = 0; - params->input_tensors = TfLiteIntArrayCreate(1); - params->input_tensors->data[0] = 0; - params->output_tensors = TfLiteIntArrayCreate(1); - params->output_tensors->data[0] = 3; - - // Second partition for DequantNode(1) - params = interpreter_mn->add_delegate_params(); - params->nodes_to_replace = TfLiteIntArrayCreate(1); - params->nodes_to_replace->data[0] = 1; - params->input_tensors = TfLiteIntArrayCreate(1); - params->input_tensors->data[0] = 1; - params->output_tensors = TfLiteIntArrayCreate(1); - params->output_tensors->data[0] = 4; - - // Third partition for DequantNode(1), DequantNode(2), ADD(3) - params = interpreter_mn->add_delegate_params(); - params->nodes_to_replace = TfLiteIntArrayCreate(3); - params->nodes_to_replace->data[0] = 1; - params->nodes_to_replace->data[1] = 2; - params->nodes_to_replace->data[2] = 3; + params->nodes_to_replace->data[0] = 4; params->input_tensors = TfLiteIntArrayCreate(2); params->input_tensors->data[0] = 1; - params->input_tensors->data[0] = 3; + params->input_tensors->data[1] = 3; params->output_tensors = TfLiteIntArrayCreate(1); params->output_tensors->data[0] = 7; @@ -898,16 +867,16 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) { TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); + // Post-PreviewDelegatePartitioning, the partitioner will add Dequant(2) to + // ops_to_replace, since it only outputs to a delegated node. EXPECT_EQ(ops_to_replace->size, 2); - // Op at index 2 is the Dequant op (t3 -> t5). - EXPECT_EQ(ops_to_replace->data[0], 2); - // Op at index 3 is the Add op. - EXPECT_EQ(ops_to_replace->data[1], 3); - + // Op at index 4 is the Add op. + EXPECT_EQ(ops_to_replace->data[0], 4); + EXPECT_EQ(ops_to_replace->data[1], 2); + // Verify that Add op has fp16 inputs. 
TfLiteNode* node = nullptr; TfLiteRegistration* registration = nullptr; - // Verify that Add op has fp16 inputs. - context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node, + context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node, ®istration); EXPECT_EQ(context->tensors[node->inputs->data[0]].type, TfLiteType::kTfLiteFloat16); @@ -917,21 +886,18 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) { } InterpreterMultiNode* interpreter_mn2 = - new InterpreterMultiNode(/*add_op_first=*/false); -TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) { - // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'. - // 'Add' can be replaced by the GPU delegate, but 'Greater' can not. - // t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6 + new InterpreterMultiNode(/*both_ops_supported*/ true); +TEST(ModelBuilderTest, + GetOpsToReplaceSelectsCorrectFp16Nodes_MultiplePartitions) { + // A graph with three Dequant nodes feeding two Add ops. + // t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Add(3) -> t6 // t1 (FP16) --> Dequant(1) --> t4 (FP32) --/ // --\ // t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7 // - // Note: the graph dependency is exactly same w/ that in - // GetOpsToReplaceSelectsCorrectDequantsAddFirst, but the unsupported - // 'Greater' op appears first in the execution plan. Despite this, - // OpsToReplace should still replace the 'Add' op and the Dequant outputing - // t5, but leave the other Dequant nodes because 'Greater' must run - // on the CPU. + // In this test case, we purposely partition Add(3) & Add(4) into different + // partitions, to check if Dequant nodes that output *only* to the first + // partition nodes are accepted. 
TfLiteContext* context = interpreter_mn2->context(); @@ -954,33 +920,29 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) { context->PreviewDelegatePartitioning = [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace, TfLiteDelegateParams** partition_params_array, int* num_partitions) { + // The FP16GraphPartitioner should only mark both ADD ops as accepted. + EXPECT_EQ(nodes_to_replace->size, 2); + EXPECT_EQ(nodes_to_replace->data[0], 3); + EXPECT_EQ(nodes_to_replace->data[1], 4); + // Technically, both ADD ops should end up in the same partition. + // But we put them in different partitions to test post-processing with + // DEQUANTIZE nodes. + // First partition with Add(3). auto params = interpreter_mn2->add_delegate_params(); - // First partition for DequantNode(0) params->nodes_to_replace = TfLiteIntArrayCreate(1); - params->nodes_to_replace->data[0] = 0; - params->input_tensors = TfLiteIntArrayCreate(1); - params->input_tensors->data[0] = 0; - params->output_tensors = TfLiteIntArrayCreate(1); - params->output_tensors->data[0] = 3; - - // Second partition for DequantNode(1) - params = interpreter_mn2->add_delegate_params(); - params->nodes_to_replace = TfLiteIntArrayCreate(1); - params->nodes_to_replace->data[0] = 1; - params->input_tensors = TfLiteIntArrayCreate(1); - params->input_tensors->data[0] = 1; - params->output_tensors = TfLiteIntArrayCreate(1); - params->output_tensors->data[0] = 4; - - // Third partition for DequantNode(1), DequantNode(2), ADD(4) - params = interpreter_mn2->add_delegate_params(); - params->nodes_to_replace = TfLiteIntArrayCreate(3); - params->nodes_to_replace->data[0] = 1; - params->nodes_to_replace->data[1] = 2; - params->nodes_to_replace->data[2] = 4; + params->nodes_to_replace->data[0] = 3; params->input_tensors = TfLiteIntArrayCreate(2); - params->input_tensors->data[0] = 1; params->input_tensors->data[0] = 3; + params->input_tensors->data[1] = 4; + params->output_tensors = 
TfLiteIntArrayCreate(1); + params->output_tensors->data[0] = 6; + // Second partition with Add(4). + params = interpreter_mn2->add_delegate_params(); + params->nodes_to_replace = TfLiteIntArrayCreate(1); + params->nodes_to_replace->data[0] = 4; + params->input_tensors = TfLiteIntArrayCreate(2); + params->input_tensors->data[0] = 4; + params->input_tensors->data[1] = 5; params->output_tensors = TfLiteIntArrayCreate(1); params->output_tensors->data[0] = 7; @@ -989,23 +951,32 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) { return kTfLiteOk; }; - TfLiteIntArray* ops_to_replace = GetOpsToReplace(context); + TfLiteIntArray* ops_to_replace = GetOpsToReplace( + context, /*allow_quant_ops*/ false, /*max_delegated_partitions*/ 2); - EXPECT_EQ(ops_to_replace->size, 2); - // Op at index 2 is the Dequant op (t3 -> t5). - EXPECT_EQ(ops_to_replace->data[0], 2); - // Op at index 4 is the Add op. - EXPECT_EQ(ops_to_replace->data[1], 4); + // Three ops should be selected: + // Add(3), Dequant(x), Add(4) + // Since both partitions are of size 1, either could end up as the 'first' + // partition with one Dequant node added for it. + EXPECT_EQ(ops_to_replace->size, 3); TfLiteNode* node = nullptr; TfLiteRegistration* registration = nullptr; - // Verify that Add op has fp16 inputs. - context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node, + // Verify that both Add ops have fp16 inputs. 
+ context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node, ®istration); EXPECT_EQ(context->tensors[node->inputs->data[0]].type, TfLiteType::kTfLiteFloat16); EXPECT_EQ(context->tensors[node->inputs->data[1]].type, TfLiteType::kTfLiteFloat16); + context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node, + ®istration); + EXPECT_EQ(context->tensors[node->inputs->data[0]].type, + TfLiteType::kTfLiteFloat16); + EXPECT_EQ(context->tensors[node->inputs->data[1]].type, + TfLiteType::kTfLiteFloat16); + // Verify that the op at index 1 is a Dequant outputing to a single Add. + EXPECT_TRUE(ops_to_replace->data[1] == 0 || ops_to_replace->data[1] == 2); TfLiteIntArrayFree(ops_to_replace); } diff --git a/tensorflow/lite/delegates/utils.cc b/tensorflow/lite/delegates/utils.cc index 135b4d531f9..873cadc180f 100644 --- a/tensorflow/lite/delegates/utils.cc +++ b/tensorflow/lite/delegates/utils.cc @@ -150,74 +150,112 @@ TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes( return kTfLiteOk; } -TfLiteStatus FP16GraphPartitionHelper::Partition( - std::set* unsupported_nodes_info) { - const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info); - // Clean up those partitions that have a single dequant op. NoteThose - // removed dequant ops have to be reserved in the graph and should not be - // delegated. - RemoveSingleDequantNodePartitions(); - return status; -} - std::vector FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl( int n, int min_nodes_per_partition) { - std::vector ops_to_replace = - GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl( - n, min_nodes_per_partition); - RemapInputTensors(ops_to_replace); - RemoveReservedDequantsFromNodes(&ops_to_replace); + auto first_n_partitions = + GetFirstNLargestPartitions(n, min_nodes_per_partition); + std::vector ops_to_replace; + if (first_n_partitions.empty()) return ops_to_replace; + + // Handle the first delegated partition specially. 
+ // All fp16 DEQUANTIZE nodes whose consumers exist only in this partition can
+ // be added to the ops to delegate. Others have to be preserved in the graph,
+ // since the partitioning algorithm will put such nodes greedily in the first
+ // partition.
+ const auto* first_partition = first_n_partitions[0];
+ std::unordered_map<int, int> delegated_dequant_consumers;
+ for (int i = 0; i < first_partition->nodes_to_replace->size; ++i) {
+ const int node_id = first_partition->nodes_to_replace->data[i];
+ ops_to_replace.push_back(node_id);
+ TfLiteNode* node;
+ TfLiteRegistration* registration;
+ const auto status = context_->GetNodeAndRegistration(context_, node_id,
+ &node, &registration);
+ if (status != kTfLiteOk) {
+ TF_LITE_KERNEL_LOG(context_,
+ "Couldn't get node and registration info for op: %d\n",
+ node_id);
+ ops_to_replace.clear();
+ return ops_to_replace;
+ }
+ // See if any input to the op is a (converted) fp16 value. If yes, increment
+ // its value in delegated_dequant_consumers.
+ for (int j = 0; j < node->inputs->size; ++j) {
+ const int input_tid = node->inputs->data[j];
+ if (dequant_consumers_.find(input_tid) != dequant_consumers_.end()) {
+ delegated_dequant_consumers[input_tid] += 1;
+ }
+ }
+ }
+ // Check all dequant nodes that have some consumers in the first partition.
+ // If the number of delegated consumers is the same as the total number of
+ // consumers, add the corresponding DEQUANTIZE op to the delegated nodes.
+ for (auto tensor_and_consumers : delegated_dequant_consumers) {
+ if (dequant_consumers_[tensor_and_consumers.first] ==
+ tensor_and_consumers.second) {
+ ops_to_replace.emplace_back(dequant_nodes_[tensor_and_consumers.first]);
+ }
+ }
+
+ // For all other partitions after the first one, insert all nodes into
+ // ops_to_replace.
+ for (int i = 1; i < first_n_partitions.size(); ++i) { + auto nodes = first_n_partitions[i]->nodes_to_replace; + ops_to_replace.insert(ops_to_replace.end(), nodes->data, + nodes->data + nodes->size); + } + + // Modify the inputs of relevant ops that support fp16 constants. + // TODO(b/156707497): Ensure that these inputs are remapped during the + // delegate's 'free', so that CPU fallback works for fp16 models. + RemapFp16InputTensors(ops_to_replace); return ops_to_replace; } bool FP16GraphPartitionHelper::IsNodeSupported( TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration, int node_id, std::string* unsupported_details) { - // If we need to handle dequant nodes, we have to remap input tensors of - // this node if some of them come from a dequant node before testing if - // the node is supported. - std::vector orig_inputs; - if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node, - &orig_inputs)) { - // We have a dequant op here. Note that we retrun an Ok status because a - // dequant node is first added as supported. Later, this dequant node - // will be removed if it has to be preserved in the graph which happens - // when its immediate downstream nodes cannot be supported. - return true; - } - const auto status = GraphPartitionHelper::IsNodeSupported( - context, node, registration, node_id, unsupported_details); - RestoreToOrigInputTensors(node, orig_inputs); - return status; -} - -bool FP16GraphPartitionHelper::RecordAndRemapInputTensors( - int32_t op_code, int node_id, TfLiteNode* node, - std::vector* orig_inputs) { - orig_inputs->clear(); - // Record the dequant node. - if (op_code == kTfLiteBuiltinDequantize && + if (registration->builtin_code == kTfLiteBuiltinDequantize && context_->tensors[node->inputs->data[0]].type == TfLiteType::kTfLiteFloat16) { - dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0]; - return true; + // Update mappings if this node is a fp16 DEQUANTIZE node. 
+ dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
+ dequant_nodes_[node->outputs->data[0]] = node_id;
+ // We do not accept these ops right now.
+ // This is done to support use-cases where a DEQUANTIZE output might be
+ // consumed by a CPU op.
+ return false;
}
- // For a dequantize op, there's no need to remap its input tensors.
- if (dequant_nodes_.empty()) return false;
- RemapInputTensors(node, orig_inputs);
- return false;
+
+ // To check if a (possibly) FP16 node is supported, we temporarily point the
+ // node's inputs to the original fp16 tensors. This 'mutated' node is then
+ // passed to the base IsNodeSupported function for checking. After the check,
+ // we remap the original node inputs, so that the TFLite graph remains the
+ // same.
+ std::vector<int> orig_inputs;
+ if (!dequant_nodes_.empty()) {
+ RemapFp16InputTensors(node, &orig_inputs);
+ }
+
+ const auto is_supported = GraphPartitionHelper::IsNodeSupported(
+ context, node, registration, node_id, unsupported_details);
+
+ if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
+ // Remapping happened. Restore original inputs.
+ for (int j = 0; j < node->inputs->size; ++j) {
+ node->inputs->data[j] = orig_inputs[j];
+ if (dequant_nodes_.find(orig_inputs[j]) != dequant_nodes_.end()) {
+ // If it's a fp16 tensor, increment number of consumers of the
+ // corresponding DEQUANTIZE.
+ dequant_consumers_[orig_inputs[j]] += 1; + } + } + } + return is_supported; } -void FP16GraphPartitionHelper::RestoreToOrigInputTensors( - TfLiteNode* node, const std::vector& orig_inputs) { - if (node->inputs->size != orig_inputs.size()) return; - for (int j = 0; j < node->inputs->size; ++j) { - node->inputs->data[j] = orig_inputs[j]; - } -} - -void FP16GraphPartitionHelper::RemapInputTensors( +void FP16GraphPartitionHelper::RemapFp16InputTensors( const std::vector& nodes) const { for (int node_id : nodes) { TfLiteNode* node; @@ -229,56 +267,11 @@ void FP16GraphPartitionHelper::RemapInputTensors( "Couldn't get node and registration info for op: %d\n", node_id); } - RemapInputTensors(node, nullptr /* orig_inputs*/); + RemapFp16InputTensors(node, nullptr /* orig_inputs*/); } } -void FP16GraphPartitionHelper::RemoveSingleDequantNodePartitions() { - auto it = partitions_.begin(); - while (it != partitions_.end()) { - auto p = *it; - if (p->nodes_to_replace->size != 1) { - ++it; - continue; - } - int node_id = p->nodes_to_replace->data[0]; - TfLiteNode* node = nullptr; - TfLiteRegistration* registration = nullptr; - - TfLiteStatus status = context_->GetNodeAndRegistration( - context_, node_id, &node, ®istration); - if (status != kTfLiteOk) { - TF_LITE_KERNEL_LOG(context_, - "Couldn't get node and registration info for op: %d\n", - node_id); - } - if (registration->builtin_code != kTfLiteBuiltinDequantize || - context_->tensors[node->inputs->data[0]].type != - TfLiteType::kTfLiteFloat16) { - ++it; - continue; - } - // Note such dequant nodes have to be preserved in the graph as dequant - // ops are not actually supported in the GPU delegate. 
- dequant_nodes_to_save_.insert(node_id); - it = partitions_.erase(it); - } -} - -void FP16GraphPartitionHelper::RemoveReservedDequantsFromNodes( - std::vector* nodes) { - if (dequant_nodes_to_save_.empty()) return; - auto it = nodes->begin(); - while (it != nodes->end()) { - if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) { - ++it; - continue; - } - it = nodes->erase(it); - } -} - -void FP16GraphPartitionHelper::RemapInputTensors( +void FP16GraphPartitionHelper::RemapFp16InputTensors( TfLiteNode* node, std::vector* orig_inputs) const { TfLiteIntArray* inputs = node->inputs; auto inputs_view = TfLiteIntArrayView(inputs); @@ -296,8 +289,8 @@ void FP16GraphPartitionHelper::RemapInputTensors( bool is_remapped = false; for (int j = 0; j < inputs->size; ++j) { const int input_tid = inputs->data[j]; - const auto it = dequant_nodes_.find(input_tid); - if (it != dequant_nodes_.end()) { + const auto it = dequant_map_.find(input_tid); + if (it != dequant_map_.end()) { inputs->data[j] = it->second; is_remapped = true; } diff --git a/tensorflow/lite/delegates/utils.h b/tensorflow/lite/delegates/utils.h index 6b498b908f9..12684fcb84a 100644 --- a/tensorflow/lite/delegates/utils.h +++ b/tensorflow/lite/delegates/utils.h @@ -127,19 +127,23 @@ class GraphPartitionHelper { TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory }; -// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in -// addition to supported nodes for the delegate, when the DEQUANTIZE node's -// output is an input to the kernel that supports FP16 input. +// Specialized partitioner for graphs that possibly contain fp16 tensors. +// +// From nodes that accept fp16 inputs, this delegates the following: +// 1. All nodes (except DEQUANTIZE) that are supported with fp16 inputs by the +// delegate (in the TFLite graph, these nodes take in dequantized FP32 +// outputs). +// 2. 
All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first* +// delegated partition. This is because TFLite's partitioning algorithm +// greedily puts all such nodes in the first partition. class FP16GraphPartitionHelper : public GraphPartitionHelper { public: FP16GraphPartitionHelper(TfLiteContext* context, IsNodeSupportedFn is_node_supported_fn) : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} - TfLiteStatus Partition( - std::set* unsupported_nodes_info) override; - protected: + // Specialized function to handle fp16 nodes. bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration, int node_id, std::string* unsupported_details) override; @@ -149,39 +153,25 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper { int n, int min_nodes_per_partition) override; private: - // Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true. - // When it's not a dequant op, remap its inputs to the inputs of the preceding - // dequant if there's a one and returns false. 'orig_inputs' records original - // input tensor ids of this node if any input is remapped. - bool RecordAndRemapInputTensors(int32_t op_code, int node_id, - TfLiteNode* node, - std::vector* orig_inputs); + // This remaps fp32 inputs of the given node to their corresponding fp16 + // version, if applicable. Can be summarized as: + // fp16 -> DEQUANTIZE -> fp32 -> OP -> output + // becomes + // fp16 -> OP -> output + void RemapFp16InputTensors(TfLiteNode* node, + std::vector* orig_inputs) const; - // Restore inputs of 'node' to 'orig_inputs' only if two sizes match. - void RestoreToOrigInputTensors(TfLiteNode* node, - const std::vector& orig_inputs); + // Performs the above remapping for all nodes in the given list, without + // tracking the original inputs. + void RemapFp16InputTensors(const std::vector& nodes) const; - // Remap input tensors of every node in 'nodes' (i.e. 
node indices) if some of
- them are from dequant ops.
- void RemapInputTensors(const std::vector<int>& nodes) const;
-
- void RemoveSingleDequantNodePartitions();
-
- void RemoveReservedDequantsFromNodes(std::vector<int>* nodes);
-
- // Remap input tensors of a single 'node' if some of come from a dequant op.
- // If 'orig_inputs' isn't nullptr, it records original input tensor ids of
- // this node if any input is remapped.
- void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const;
-
- // A map recording dequantize nodes's input/output tensors of this selected
- // graph. The key is the output tensor id, and the value is the input tensor
- // id.
+ // ('dequantize' here refers to fp16 DEQUANTIZE)
+ // Mapping of dequantize nodes' output tensor-id to its node id.
std::unordered_map<int, int> dequant_nodes_;
-
- // A set of dequant nodes as in node indices that have to be preserved in the
- // graph.
- std::set<int> dequant_nodes_to_save_;
+ // Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
+ std::unordered_map<int, int> dequant_map_;
+ // Mapping of DEQUANTIZE output tensor-id to its number of consumers.
+ std::unordered_map<int, int> dequant_consumers_;
};
} // namespace delegates
diff --git a/tensorflow/lite/graph_info.cc b/tensorflow/lite/graph_info.cc
index a419a56a9e6..8968fe6cb21 100644
--- a/tensorflow/lite/graph_info.cc
+++ b/tensorflow/lite/graph_info.cc
@@ -29,6 +29,10 @@ namespace {
// PartitionGraphIntoIndependentNodeSubsetsImpl partitioner(
// info, nodes_to_part, node_subsets);
// partitioner.Partition();
+//
+// NOTE: Changing the partitioning logic would require a change to
+// FP16GraphPartitionHelper.
+// LINT.IfChange
class PartitionGraphIntoIndependentNodeSubsetsImpl {
public:
PartitionGraphIntoIndependentNodeSubsetsImpl(
@@ -198,6 +202,7 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl {
// negative values of kEpochNotReady if not assigned.
std::vector<int> node_epochs_;
};
+// LINT.ThenChange(//tensorflow/lite/delegates/utils.h)
} // namespace