Rectifies delegation of DEQUANTIZE nodes in fp16 graphs. Only those dequant nodes that are consumed in the first delegated partition are now selected.

PiperOrigin-RevId: 314627542
Change-Id: I801922eef3a95d3d4b05d53f5e8d27a0842c3be6
This commit is contained in:
Sachin Joglekar 2020-06-03 16:32:37 -07:00 committed by TensorFlower Gardener
parent 70fd126d3a
commit c0b6b669e2
4 changed files with 216 additions and 257 deletions

View File

@ -250,7 +250,7 @@ class InterpreterFp16 : public DelegatedInterpreter {
InterpreterFp16* interpreter_fp16_add_op =
new InterpreterFp16(kTfLiteBuiltinAdd);
TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
// Before pruning, the graph has three nodes:
//
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
@ -283,14 +283,15 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
context->PreviewDelegatePartitioning =
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
// The partitioner should accept only the Add op initially.
EXPECT_EQ(nodes_to_replace->size, 1);
// Single partition output.
auto params = interpreter_fp16_add_op->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(3);
params->nodes_to_replace->data[0] = 0;
params->nodes_to_replace->data[1] = 1;
params->nodes_to_replace->data[2] = 2;
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 2;
params->input_tensors = TfLiteIntArrayCreate(2);
params->input_tensors->data[0] = 0;
params->input_tensors->data[1] = 2;
params->input_tensors->data[0] = 1;
params->input_tensors->data[1] = 3;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 4;
@ -301,11 +302,13 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
// Replace all nodes.
// The Dequant nodes are added to ops_to_replace as a post-processing step by
// the FP16GraphPartitioner. ADD is delegated with its inputs pointing to the
// FP16 inputs.
EXPECT_EQ(ops_to_replace->size, 3);
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
&registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16);
@ -317,14 +320,14 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
InterpreterFp16* interpreter_fp16_gt_op =
new InterpreterFp16(kTfLiteBuiltinGreater);
TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
TEST(ModelBuilderTest, GetOpsToReplaceRejectsFp16DequantizeNodes) {
// Before pruning, the graph has three nodes:
//
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Greater Op -> t4
// t2 (FP16) -> DequantNode -> t3 (FP32) --/
//
// Because there is no GPU equivalent for the Greater op, we don't prune
// the Dequantize nodes.
// Because there is no GPU equivalent for the Greater op, we don't choose any
// nodes.
TfLiteContext* context = interpreter_fp16_gt_op->context();
// These functions are meant to be called inside delegates. Swap out
@ -346,26 +349,10 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
context->PreviewDelegatePartitioning =
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
auto params = interpreter_fp16_gt_op->add_delegate_params();
// First partition for DequantNode (t0->t1)
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 0;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 0;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 1;
// Second partition for DequantNode (t2->t3)
params = interpreter_fp16_gt_op->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 0;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 0;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 1;
*partition_params_array = interpreter_fp16_gt_op->delegate_params();
*num_partitions = interpreter_fp16_gt_op->num_delegate_params();
// No selected nodes.
EXPECT_EQ(nodes_to_replace->size, 0);
*partition_params_array = nullptr;
*num_partitions = 0;
return kTfLiteOk;
};
@ -685,7 +672,7 @@ TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) {
class InterpreterMultiNode : public DelegatedInterpreter {
public:
explicit InterpreterMultiNode(bool add_op_first = true)
explicit InterpreterMultiNode(bool both_ops_supported = true)
: DelegatedInterpreter(5) {
void* builtin_data = malloc(sizeof(int));
EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
@ -707,8 +694,8 @@ class InterpreterMultiNode : public DelegatedInterpreter {
kTfLiteOk);
}
if (add_op_first) {
// Add the ADD op node that GPU delegate supports.
if (both_ops_supported) {
// Add 2 ADD ops.
const TfLiteRegistration reg_add0 = {
[](TfLiteContext* context, const char* buffer, size_t length) {
return reinterpret_cast<void*>(new int(1));
@ -727,8 +714,7 @@ class InterpreterMultiNode : public DelegatedInterpreter {
/*registration=*/&reg_add0),
kTfLiteOk);
// Add the GREATER op node that GPU delegate doesn't support.
const TfLiteRegistration reg_greater = {
const TfLiteRegistration reg_add1 = {
[](TfLiteContext* context, const char* buffer, size_t length) {
return reinterpret_cast<void*>(new int(1));
},
@ -738,12 +724,12 @@ class InterpreterMultiNode : public DelegatedInterpreter {
nullptr,
nullptr,
nullptr,
kTfLiteBuiltinGreater};
kTfLiteBuiltinAdd};
EXPECT_EQ(interpreter_.AddNodeWithParameters(
/*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
/*init_data_size=*/0,
/*builtin_data=*/builtin_data,
/*registration=*/&reg_greater),
/*registration=*/&reg_add1),
kTfLiteOk);
} else {
// Add the GREATER op node that GPU delegate doesn't support.
@ -828,19 +814,19 @@ class InterpreterMultiNode : public DelegatedInterpreter {
}
};
InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode();
InterpreterMultiNode* interpreter_mn =
new InterpreterMultiNode(/*both_ops_supported*/ false);
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectFp16Nodes_SinglePartition) {
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(4) -> t6
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
// t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
// --\
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(3) -> t7
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
//
// OpsToReplace should replace the 'Add' op and the Dequant outputting
// t5, but leave the other Dequant nodes because 'Greater' must run
// on the CPU.
// OpsToReplace should accept 'Add' & the Dequant nodes that only output to
// it (in this case, Dequant(2)).
TfLiteContext* context = interpreter_mn->context();
// These functions are meant to be called inside delegates. Swap out
@ -861,33 +847,16 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
context->PreviewDelegatePartitioning =
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
// The FP16GraphPartitioner should only mark the ADD op as accepted.
EXPECT_EQ(nodes_to_replace->size, 1);
EXPECT_EQ(nodes_to_replace->data[0], 4);
// Single partition.
auto params = interpreter_mn->add_delegate_params();
// First partition for DequantNode(0)
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 0;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 0;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 3;
// Second partition for DequantNode(1)
params = interpreter_mn->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 1;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 1;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 4;
// Third partition for DequantNode(1), DequantNode(2), ADD(3)
params = interpreter_mn->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(3);
params->nodes_to_replace->data[0] = 1;
params->nodes_to_replace->data[1] = 2;
params->nodes_to_replace->data[2] = 3;
params->nodes_to_replace->data[0] = 4;
params->input_tensors = TfLiteIntArrayCreate(2);
params->input_tensors->data[0] = 1;
params->input_tensors->data[0] = 3;
params->input_tensors->data[1] = 3;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 7;
@ -898,16 +867,16 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
// Post-PreviewDelegatePartitioning, the partitioner will add Dequant(2) to
// ops_to_replace, since it only outputs to a delegated node.
EXPECT_EQ(ops_to_replace->size, 2);
// Op at index 2 is the Dequant op (t3 -> t5).
EXPECT_EQ(ops_to_replace->data[0], 2);
// Op at index 3 is the Add op.
EXPECT_EQ(ops_to_replace->data[1], 3);
// Op at index 4 is the Add op.
EXPECT_EQ(ops_to_replace->data[0], 4);
EXPECT_EQ(ops_to_replace->data[1], 2);
// Verify that Add op has fp16 inputs.
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
// Verify that Add op has fp16 inputs.
context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
&registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16);
@ -917,21 +886,18 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
}
InterpreterMultiNode* interpreter_mn2 =
new InterpreterMultiNode(/*add_op_first=*/false);
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
new InterpreterMultiNode(/*both_ops_supported*/ true);
TEST(ModelBuilderTest,
GetOpsToReplaceSelectsCorrectFp16Nodes_MultiplePartitions) {
// A graph with three Dequant nodes feeding two Add ops.
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Add(3) -> t6
// t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
// --\
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
//
// Note: the graph dependency is exactly same w/ that in
// GetOpsToReplaceSelectsCorrectDequantsAddFirst, but the unsupported
// 'Greater' op appears first in the execution plan. Despite this,
// OpsToReplace should still replace the 'Add' op and the Dequant outputting
// t5, but leave the other Dequant nodes because 'Greater' must run
// on the CPU.
// In this test case, we purposely partition Add(3) & Add(4) into different
// partitions, to check if Dequant nodes that output *only* to the first
// partition nodes are accepted.
TfLiteContext* context = interpreter_mn2->context();
@ -954,33 +920,29 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
context->PreviewDelegatePartitioning =
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
// The FP16GraphPartitioner should only mark both ADD ops as accepted.
EXPECT_EQ(nodes_to_replace->size, 2);
EXPECT_EQ(nodes_to_replace->data[0], 3);
EXPECT_EQ(nodes_to_replace->data[1], 4);
// Technically, both ADD ops should end up in the same partition.
// But we put them in different partitions to test post-processing with
// DEQUANTIZE nodes.
// First partition with Add(3).
auto params = interpreter_mn2->add_delegate_params();
// First partition for DequantNode(0)
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 0;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 0;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 3;
// Second partition for DequantNode(1)
params = interpreter_mn2->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 1;
params->input_tensors = TfLiteIntArrayCreate(1);
params->input_tensors->data[0] = 1;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 4;
// Third partition for DequantNode(1), DequantNode(2), ADD(4)
params = interpreter_mn2->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(3);
params->nodes_to_replace->data[0] = 1;
params->nodes_to_replace->data[1] = 2;
params->nodes_to_replace->data[2] = 4;
params->nodes_to_replace->data[0] = 3;
params->input_tensors = TfLiteIntArrayCreate(2);
params->input_tensors->data[0] = 1;
params->input_tensors->data[0] = 3;
params->input_tensors->data[1] = 4;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 6;
// Second partition with Add(4).
params = interpreter_mn2->add_delegate_params();
params->nodes_to_replace = TfLiteIntArrayCreate(1);
params->nodes_to_replace->data[0] = 4;
params->input_tensors = TfLiteIntArrayCreate(2);
params->input_tensors->data[0] = 4;
params->input_tensors->data[1] = 5;
params->output_tensors = TfLiteIntArrayCreate(1);
params->output_tensors->data[0] = 7;
@ -989,23 +951,32 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
return kTfLiteOk;
};
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
TfLiteIntArray* ops_to_replace = GetOpsToReplace(
context, /*allow_quant_ops*/ false, /*max_delegated_partitions*/ 2);
EXPECT_EQ(ops_to_replace->size, 2);
// Op at index 2 is the Dequant op (t3 -> t5).
EXPECT_EQ(ops_to_replace->data[0], 2);
// Op at index 4 is the Add op.
EXPECT_EQ(ops_to_replace->data[1], 4);
// Three ops should be selected:
// Add(3), Dequant(x), Add(4)
// Since both partitions are of size 1, either could end up as the 'first'
// partition with one Dequant node added for it.
EXPECT_EQ(ops_to_replace->size, 3);
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
// Verify that Add op has fp16 inputs.
context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
// Verify that both Add ops have fp16 inputs.
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
&registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16);
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
TfLiteType::kTfLiteFloat16);
context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
&registration);
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
TfLiteType::kTfLiteFloat16);
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
TfLiteType::kTfLiteFloat16);
// Verify that the op at index 1 is a Dequant outputting to a single Add.
EXPECT_TRUE(ops_to_replace->data[1] == 0 || ops_to_replace->data[1] == 2);
TfLiteIntArrayFree(ops_to_replace);
}

View File

@ -150,74 +150,112 @@ TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
return kTfLiteOk;
}
TfLiteStatus FP16GraphPartitionHelper::Partition(
std::set<std::string>* unsupported_nodes_info) {
const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
// Clean up those partitions that have a single dequant op. Note: those
// removed dequant ops have to be preserved in the graph and should not be
// delegated.
RemoveSingleDequantNodePartitions();
return status;
}
std::vector<int>
FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
int n, int min_nodes_per_partition) {
std::vector<int> ops_to_replace =
GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
n, min_nodes_per_partition);
RemapInputTensors(ops_to_replace);
RemoveReservedDequantsFromNodes(&ops_to_replace);
auto first_n_partitions =
GetFirstNLargestPartitions(n, min_nodes_per_partition);
std::vector<int> ops_to_replace;
if (first_n_partitions.empty()) return ops_to_replace;
// Handle the first delegated partition specially.
// All fp16 DEQUANTIZE nodes whose consumers exist only in this partition can
// be added to the ops to delegate. Others have to be preserved in the graph,
// since the partitioning algorithm will put such nodes greedily in the first
// partition.
const auto* first_partition = first_n_partitions[0];
std::unordered_map<int, int> delegated_dequant_consumers;
for (int i = 0; i < first_partition->nodes_to_replace->size; ++i) {
const int node_id = first_partition->nodes_to_replace->data[i];
ops_to_replace.push_back(node_id);
TfLiteNode* node;
TfLiteRegistration* registration;
const auto status = context_->GetNodeAndRegistration(context_, node_id,
&node, &registration);
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context_,
"Couldn't get node and registration info for op: %d\n",
node_id);
ops_to_replace.clear();
return ops_to_replace;
}
// See if any input to the op is a (converted) fp16 value. If yes, increment
// its value in delegated_dequant_consumers.
for (int j = 0; j < node->inputs->size; ++j) {
const int input_tid = node->inputs->data[j];
if (dequant_consumers_.find(input_tid) != dequant_consumers_.end()) {
delegated_dequant_consumers[input_tid] += 1;
}
}
}
// Check all dequant nodes that have some consumers in the first partition.
// If the number of delegated consumers is same as total number of consumers,
// add the corresponding DEQUANTIZE op to the delegated nodes.
for (auto tensor_and_consumers : delegated_dequant_consumers) {
if (dequant_consumers_[tensor_and_consumers.first] ==
tensor_and_consumers.second) {
ops_to_replace.emplace_back(dequant_nodes_[tensor_and_consumers.first]);
}
}
// For all other partitions after the first one, insert all nodes into
// ops_to_replace.
for (int i = 1; i < first_n_partitions.size(); ++i) {
auto nodes = first_n_partitions[i]->nodes_to_replace;
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
nodes->data + nodes->size);
}
// Modify the inputs of relevant ops that support fp16 constants.
// TODO(b/156707497): Ensure that these inputs are remapped during the
// delegate's 'free', so that CPU fallback works for fp16 models.
RemapFp16InputTensors(ops_to_replace);
return ops_to_replace;
}
bool FP16GraphPartitionHelper::IsNodeSupported(
TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
int node_id, std::string* unsupported_details) {
// If we need to handle dequant nodes, we have to remap input tensors of
// this node if some of them come from a dequant node before testing if
// the node is supported.
std::vector<int> orig_inputs;
if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
&orig_inputs)) {
// We have a dequant op here. Note that we return an Ok status because a
// dequant node is first added as supported. Later, this dequant node
// will be removed if it has to be preserved in the graph which happens
// when its immediate downstream nodes cannot be supported.
return true;
}
const auto status = GraphPartitionHelper::IsNodeSupported(
context, node, registration, node_id, unsupported_details);
RestoreToOrigInputTensors(node, orig_inputs);
return status;
}
bool FP16GraphPartitionHelper::RecordAndRemapInputTensors(
int32_t op_code, int node_id, TfLiteNode* node,
std::vector<int>* orig_inputs) {
orig_inputs->clear();
// Record the dequant node.
if (op_code == kTfLiteBuiltinDequantize &&
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
context_->tensors[node->inputs->data[0]].type ==
TfLiteType::kTfLiteFloat16) {
dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0];
return true;
}
// For a dequantize op, there's no need to remap its input tensors.
if (dequant_nodes_.empty()) return false;
RemapInputTensors(node, orig_inputs);
// Update mappings if this node is a fp16 DEQUANTIZE node.
dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
dequant_nodes_[node->outputs->data[0]] = node_id;
// We do not accept these ops right now.
// This is done to support use-cases where a DEQUANTIZE output might be
// consumed by a CPU op.
return false;
}
}
void FP16GraphPartitionHelper::RestoreToOrigInputTensors(
TfLiteNode* node, const std::vector<int>& orig_inputs) {
if (node->inputs->size != orig_inputs.size()) return;
// To check if a (possibly) FP16 node is supported, we temporarily point the
// node's inputs to the original fp16 tensors. This 'mutated' node is then
// passed to the base IsNodeSupported function for checking. After the check,
// we remap the original node inputs, so that the TFLite graph remains the
// same.
std::vector<int> orig_inputs;
if (!dequant_nodes_.empty()) {
RemapFp16InputTensors(node, &orig_inputs);
}
const auto is_supported = GraphPartitionHelper::IsNodeSupported(
context, node, registration, node_id, unsupported_details);
if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
// Remapping happened. Restore original inputs.
for (int j = 0; j < node->inputs->size; ++j) {
node->inputs->data[j] = orig_inputs[j];
if (dequant_nodes_.find(orig_inputs[j]) != dequant_nodes_.end()) {
// If its a fp16 tensor, increment number of consumers of the
// corresponding DEQUANTIZE.
dequant_consumers_[orig_inputs[j]] += 1;
}
}
}
return is_supported;
}
void FP16GraphPartitionHelper::RemapInputTensors(
void FP16GraphPartitionHelper::RemapFp16InputTensors(
const std::vector<int>& nodes) const {
for (int node_id : nodes) {
TfLiteNode* node;
@ -229,56 +267,11 @@ void FP16GraphPartitionHelper::RemapInputTensors(
"Couldn't get node and registration info for op: %d\n",
node_id);
}
RemapInputTensors(node, nullptr /* orig_inputs*/);
RemapFp16InputTensors(node, nullptr /* orig_inputs*/);
}
}
void FP16GraphPartitionHelper::RemoveSingleDequantNodePartitions() {
auto it = partitions_.begin();
while (it != partitions_.end()) {
auto p = *it;
if (p->nodes_to_replace->size != 1) {
++it;
continue;
}
int node_id = p->nodes_to_replace->data[0];
TfLiteNode* node = nullptr;
TfLiteRegistration* registration = nullptr;
TfLiteStatus status = context_->GetNodeAndRegistration(
context_, node_id, &node, &registration);
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context_,
"Couldn't get node and registration info for op: %d\n",
node_id);
}
if (registration->builtin_code != kTfLiteBuiltinDequantize ||
context_->tensors[node->inputs->data[0]].type !=
TfLiteType::kTfLiteFloat16) {
++it;
continue;
}
// Note such dequant nodes have to be preserved in the graph as dequant
// ops are not actually supported in the GPU delegate.
dequant_nodes_to_save_.insert(node_id);
it = partitions_.erase(it);
}
}
void FP16GraphPartitionHelper::RemoveReservedDequantsFromNodes(
std::vector<int>* nodes) {
if (dequant_nodes_to_save_.empty()) return;
auto it = nodes->begin();
while (it != nodes->end()) {
if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
++it;
continue;
}
it = nodes->erase(it);
}
}
void FP16GraphPartitionHelper::RemapInputTensors(
void FP16GraphPartitionHelper::RemapFp16InputTensors(
TfLiteNode* node, std::vector<int>* orig_inputs) const {
TfLiteIntArray* inputs = node->inputs;
auto inputs_view = TfLiteIntArrayView(inputs);
@ -296,8 +289,8 @@ void FP16GraphPartitionHelper::RemapInputTensors(
bool is_remapped = false;
for (int j = 0; j < inputs->size; ++j) {
const int input_tid = inputs->data[j];
const auto it = dequant_nodes_.find(input_tid);
if (it != dequant_nodes_.end()) {
const auto it = dequant_map_.find(input_tid);
if (it != dequant_map_.end()) {
inputs->data[j] = it->second;
is_remapped = true;
}

View File

@ -127,19 +127,23 @@ class GraphPartitionHelper {
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
};
// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in
// addition to supported nodes for the delegate, when the DEQUANTIZE node's
// output is an input to the kernel that supports FP16 input.
// Specialized partitioner for graphs that possibly contain fp16 tensors.
//
// From nodes that accept fp16 inputs, this delegates the following:
// 1. All nodes (except DEQUANTIZE) that are supported with fp16 inputs by the
// delegate (in the TFLite graph, these nodes take in dequantized FP32
// outputs).
// 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
// delegated partition. This is because TFLite's partitioning algorithm
// greedily puts all such nodes in the first partition.
class FP16GraphPartitionHelper : public GraphPartitionHelper {
public:
FP16GraphPartitionHelper(TfLiteContext* context,
IsNodeSupportedFn is_node_supported_fn)
: GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
TfLiteStatus Partition(
std::set<std::string>* unsupported_nodes_info) override;
protected:
// Specialized function to handle fp16 nodes.
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
TfLiteRegistration* registration, int node_id,
std::string* unsupported_details) override;
@ -149,39 +153,25 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper {
int n, int min_nodes_per_partition) override;
private:
// Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true.
// When it's not a dequant op, remap its inputs to the inputs of the preceding
// dequant if there's a one and returns false. 'orig_inputs' records original
// input tensor ids of this node if any input is remapped.
bool RecordAndRemapInputTensors(int32_t op_code, int node_id,
TfLiteNode* node,
std::vector<int>* orig_inputs);
// This remaps fp32 inputs of the given node to their corresponding fp16
// version, if applicable. Can be summarized as:
// fp16 -> DEQUANTIZE -> fp32 -> OP -> output
// becomes
// fp16 -> OP -> output
void RemapFp16InputTensors(TfLiteNode* node,
std::vector<int>* orig_inputs) const;
// Restore inputs of 'node' to 'orig_inputs' only if two sizes match.
void RestoreToOrigInputTensors(TfLiteNode* node,
const std::vector<int>& orig_inputs);
// Performs the above remapping for all nodes in the given list, without
// tracking the original inputs.
void RemapFp16InputTensors(const std::vector<int>& nodes) const;
// Remap input tensors of every node in 'nodes' (i.e. node indices) if some of
// them are from dequant ops.
void RemapInputTensors(const std::vector<int>& nodes) const;
void RemoveSingleDequantNodePartitions();
void RemoveReservedDequantsFromNodes(std::vector<int>* nodes);
// Remap input tensors of a single 'node' if some of them come from a dequant op.
// If 'orig_inputs' isn't nullptr, it records original input tensor ids of
// this node if any input is remapped.
void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const;
// A map recording dequantize nodes's input/output tensors of this selected
// graph. The key is the output tensor id, and the value is the input tensor
// id.
// ('dequantize' here refers to fp16 DEQUANTIZE)
// Mapping of dequantize nodes' output tensor-id to its node id.
std::unordered_map<int, int> dequant_nodes_;
// A set of dequant nodes as in node indices that have to be preserved in the
// graph.
std::set<int> dequant_nodes_to_save_;
// Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
std::unordered_map<int, int> dequant_map_;
// mapping of DEQUANTIZE output tensor-id to its number of consumers.
std::unordered_map<int, int> dequant_consumers_;
};
} // namespace delegates

View File

@ -29,6 +29,10 @@ namespace {
// PartitionGraphIntoIndependentNodeSubsetsImpl partitioner(
// info, nodes_to_part, node_subsets);
// partitioner.Partition();
//
// NOTE: Changing the partitioning logic would require a change to
// FP16GraphPartitionHelper.
// LINT.IfChange
class PartitionGraphIntoIndependentNodeSubsetsImpl {
public:
PartitionGraphIntoIndependentNodeSubsetsImpl(
@ -198,6 +202,7 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl {
// negative values of kEpochNotReady if not assigned.
std::vector<int> node_epochs_;
};
// LINT.ThenChange(//tensorflow/lite/delegates/utils.h)
} // namespace