Rectifies delegation of DEQUANTIZE nodes in fp16 graphs. Only those dequant nodes which are consumed in the first delegated partition are now selected.
PiperOrigin-RevId: 314627542 Change-Id: I801922eef3a95d3d4b05d53f5e8d27a0842c3be6
This commit is contained in:
parent
70fd126d3a
commit
c0b6b669e2
|
@ -250,7 +250,7 @@ class InterpreterFp16 : public DelegatedInterpreter {
|
||||||
InterpreterFp16* interpreter_fp16_add_op =
|
InterpreterFp16* interpreter_fp16_add_op =
|
||||||
new InterpreterFp16(kTfLiteBuiltinAdd);
|
new InterpreterFp16(kTfLiteBuiltinAdd);
|
||||||
|
|
||||||
TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
|
TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
|
||||||
// Before pruning, the graph has three nodes:
|
// Before pruning, the graph has three nodes:
|
||||||
//
|
//
|
||||||
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
|
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Add -> t4
|
||||||
|
@ -283,14 +283,15 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
|
||||||
context->PreviewDelegatePartitioning =
|
context->PreviewDelegatePartitioning =
|
||||||
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
||||||
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
||||||
|
// The partitioner should accept only the Add op initially.
|
||||||
|
EXPECT_EQ(nodes_to_replace->size, 1);
|
||||||
|
// Single partition output.
|
||||||
auto params = interpreter_fp16_add_op->add_delegate_params();
|
auto params = interpreter_fp16_add_op->add_delegate_params();
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(3);
|
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
||||||
params->nodes_to_replace->data[0] = 0;
|
params->nodes_to_replace->data[0] = 2;
|
||||||
params->nodes_to_replace->data[1] = 1;
|
|
||||||
params->nodes_to_replace->data[2] = 2;
|
|
||||||
params->input_tensors = TfLiteIntArrayCreate(2);
|
params->input_tensors = TfLiteIntArrayCreate(2);
|
||||||
params->input_tensors->data[0] = 0;
|
params->input_tensors->data[0] = 1;
|
||||||
params->input_tensors->data[1] = 2;
|
params->input_tensors->data[1] = 3;
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
params->output_tensors = TfLiteIntArrayCreate(1);
|
||||||
params->output_tensors->data[0] = 4;
|
params->output_tensors->data[0] = 4;
|
||||||
|
|
||||||
|
@ -301,11 +302,13 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
|
||||||
|
|
||||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
||||||
|
|
||||||
// Replace all nodes.
|
// The Dequant nodes are added to ops_to_replace as a post-processing step by
|
||||||
|
// the FP16GraphPartitioner. ADD is delegated with its inputs pointing to the
|
||||||
|
// FP16 inputs.
|
||||||
EXPECT_EQ(ops_to_replace->size, 3);
|
EXPECT_EQ(ops_to_replace->size, 3);
|
||||||
TfLiteNode* node = nullptr;
|
TfLiteNode* node = nullptr;
|
||||||
TfLiteRegistration* registration = nullptr;
|
TfLiteRegistration* registration = nullptr;
|
||||||
context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
|
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
|
||||||
®istration);
|
®istration);
|
||||||
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
||||||
TfLiteType::kTfLiteFloat16);
|
TfLiteType::kTfLiteFloat16);
|
||||||
|
@ -317,14 +320,14 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
|
||||||
InterpreterFp16* interpreter_fp16_gt_op =
|
InterpreterFp16* interpreter_fp16_gt_op =
|
||||||
new InterpreterFp16(kTfLiteBuiltinGreater);
|
new InterpreterFp16(kTfLiteBuiltinGreater);
|
||||||
|
|
||||||
TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
|
TEST(ModelBuilderTest, GetOpsToReplaceRejectsFp16DequantizeNodes) {
|
||||||
// Before pruning, the graph has three nodes:
|
// Before pruning, the graph has three nodes:
|
||||||
//
|
//
|
||||||
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Greater Op -> t4
|
// t0 (FP16) -> DequantNode -> t1 (FP32) -> Greater Op -> t4
|
||||||
// t2 (FP16) -> DequantNode -> t3 (FP32) --/
|
// t2 (FP16) -> DequantNode -> t3 (FP32) --/
|
||||||
//
|
//
|
||||||
// Because there is no GPU equivalent for the Greater op, we don't prune
|
// Because there is no GPU equivalent for the Greater op, we don't choose any
|
||||||
// the Dequantize nodes.
|
// nodes.
|
||||||
|
|
||||||
TfLiteContext* context = interpreter_fp16_gt_op->context();
|
TfLiteContext* context = interpreter_fp16_gt_op->context();
|
||||||
// These functions are meant to be called inside delegates. Swap out
|
// These functions are meant to be called inside delegates. Swap out
|
||||||
|
@ -346,26 +349,10 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
|
||||||
context->PreviewDelegatePartitioning =
|
context->PreviewDelegatePartitioning =
|
||||||
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
||||||
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
||||||
auto params = interpreter_fp16_gt_op->add_delegate_params();
|
// No selected nodes.
|
||||||
// First partition for DequantNode (t0->t1)
|
EXPECT_EQ(nodes_to_replace->size, 0);
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
*partition_params_array = nullptr;
|
||||||
params->nodes_to_replace->data[0] = 0;
|
*num_partitions = 0;
|
||||||
params->input_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->input_tensors->data[0] = 0;
|
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->output_tensors->data[0] = 1;
|
|
||||||
|
|
||||||
// Second partition for DequantNode (t2->t3)
|
|
||||||
params = interpreter_fp16_gt_op->add_delegate_params();
|
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
|
||||||
params->nodes_to_replace->data[0] = 0;
|
|
||||||
params->input_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->input_tensors->data[0] = 0;
|
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->output_tensors->data[0] = 1;
|
|
||||||
|
|
||||||
*partition_params_array = interpreter_fp16_gt_op->delegate_params();
|
|
||||||
*num_partitions = interpreter_fp16_gt_op->num_delegate_params();
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -685,7 +672,7 @@ TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) {
|
||||||
|
|
||||||
class InterpreterMultiNode : public DelegatedInterpreter {
|
class InterpreterMultiNode : public DelegatedInterpreter {
|
||||||
public:
|
public:
|
||||||
explicit InterpreterMultiNode(bool add_op_first = true)
|
explicit InterpreterMultiNode(bool both_ops_supported = true)
|
||||||
: DelegatedInterpreter(5) {
|
: DelegatedInterpreter(5) {
|
||||||
void* builtin_data = malloc(sizeof(int));
|
void* builtin_data = malloc(sizeof(int));
|
||||||
EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
|
EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
|
||||||
|
@ -707,8 +694,8 @@ class InterpreterMultiNode : public DelegatedInterpreter {
|
||||||
kTfLiteOk);
|
kTfLiteOk);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (add_op_first) {
|
if (both_ops_supported) {
|
||||||
// Add the ADD op node that GPU delegate supports.
|
// Add 2 ADD ops.
|
||||||
const TfLiteRegistration reg_add0 = {
|
const TfLiteRegistration reg_add0 = {
|
||||||
[](TfLiteContext* context, const char* buffer, size_t length) {
|
[](TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
return reinterpret_cast<void*>(new int(1));
|
return reinterpret_cast<void*>(new int(1));
|
||||||
|
@ -727,8 +714,7 @@ class InterpreterMultiNode : public DelegatedInterpreter {
|
||||||
/*registration=*/®_add0),
|
/*registration=*/®_add0),
|
||||||
kTfLiteOk);
|
kTfLiteOk);
|
||||||
|
|
||||||
// Add the GREATER op node that GPU delegate doesn't support.
|
const TfLiteRegistration reg_add1 = {
|
||||||
const TfLiteRegistration reg_greater = {
|
|
||||||
[](TfLiteContext* context, const char* buffer, size_t length) {
|
[](TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
return reinterpret_cast<void*>(new int(1));
|
return reinterpret_cast<void*>(new int(1));
|
||||||
},
|
},
|
||||||
|
@ -738,12 +724,12 @@ class InterpreterMultiNode : public DelegatedInterpreter {
|
||||||
nullptr,
|
nullptr,
|
||||||
nullptr,
|
nullptr,
|
||||||
nullptr,
|
nullptr,
|
||||||
kTfLiteBuiltinGreater};
|
kTfLiteBuiltinAdd};
|
||||||
EXPECT_EQ(interpreter_.AddNodeWithParameters(
|
EXPECT_EQ(interpreter_.AddNodeWithParameters(
|
||||||
/*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
|
/*inputs=*/{3, 4}, /*outputs=*/{6}, /*init_data=*/nullptr,
|
||||||
/*init_data_size=*/0,
|
/*init_data_size=*/0,
|
||||||
/*builtin_data=*/builtin_data,
|
/*builtin_data=*/builtin_data,
|
||||||
/*registration=*/®_greater),
|
/*registration=*/®_add1),
|
||||||
kTfLiteOk);
|
kTfLiteOk);
|
||||||
} else {
|
} else {
|
||||||
// Add the GREATER op node that GPU delegate doesn't support.
|
// Add the GREATER op node that GPU delegate doesn't support.
|
||||||
|
@ -828,19 +814,19 @@ class InterpreterMultiNode : public DelegatedInterpreter {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode();
|
InterpreterMultiNode* interpreter_mn =
|
||||||
|
new InterpreterMultiNode(/*both_ops_supported*/ false);
|
||||||
|
|
||||||
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
|
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectFp16Nodes_SinglePartition) {
|
||||||
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
|
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
|
||||||
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
|
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
|
||||||
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(4) -> t6
|
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
|
||||||
// t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
|
// t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
|
||||||
// --\
|
// --\
|
||||||
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(3) -> t7
|
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
|
||||||
//
|
//
|
||||||
// OpsToReplace should replace the 'Add' op and the Dequant outputing
|
// OpsToReplace should accept 'Add' & the Dequant nodes that only output to
|
||||||
// t5, but leave the other Dequant nodes because 'Greater' must run
|
// it (in this case, Dequant(2)).
|
||||||
// on the CPU.
|
|
||||||
TfLiteContext* context = interpreter_mn->context();
|
TfLiteContext* context = interpreter_mn->context();
|
||||||
|
|
||||||
// These functions are meant to be called inside delegates. Swap out
|
// These functions are meant to be called inside delegates. Swap out
|
||||||
|
@ -861,33 +847,16 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
|
||||||
context->PreviewDelegatePartitioning =
|
context->PreviewDelegatePartitioning =
|
||||||
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
||||||
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
||||||
|
// The FP16GraphPartitioner should only mark the ADD op as accepted.
|
||||||
|
EXPECT_EQ(nodes_to_replace->size, 1);
|
||||||
|
EXPECT_EQ(nodes_to_replace->data[0], 4);
|
||||||
|
// Single partition.
|
||||||
auto params = interpreter_mn->add_delegate_params();
|
auto params = interpreter_mn->add_delegate_params();
|
||||||
// First partition for DequantNode(0)
|
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
||||||
params->nodes_to_replace->data[0] = 0;
|
params->nodes_to_replace->data[0] = 4;
|
||||||
params->input_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->input_tensors->data[0] = 0;
|
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->output_tensors->data[0] = 3;
|
|
||||||
|
|
||||||
// Second partition for DequantNode(1)
|
|
||||||
params = interpreter_mn->add_delegate_params();
|
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
|
||||||
params->nodes_to_replace->data[0] = 1;
|
|
||||||
params->input_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->input_tensors->data[0] = 1;
|
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->output_tensors->data[0] = 4;
|
|
||||||
|
|
||||||
// Third partition for DequantNode(1), DequantNode(2), ADD(3)
|
|
||||||
params = interpreter_mn->add_delegate_params();
|
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(3);
|
|
||||||
params->nodes_to_replace->data[0] = 1;
|
|
||||||
params->nodes_to_replace->data[1] = 2;
|
|
||||||
params->nodes_to_replace->data[2] = 3;
|
|
||||||
params->input_tensors = TfLiteIntArrayCreate(2);
|
params->input_tensors = TfLiteIntArrayCreate(2);
|
||||||
params->input_tensors->data[0] = 1;
|
params->input_tensors->data[0] = 1;
|
||||||
params->input_tensors->data[0] = 3;
|
params->input_tensors->data[1] = 3;
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
params->output_tensors = TfLiteIntArrayCreate(1);
|
||||||
params->output_tensors->data[0] = 7;
|
params->output_tensors->data[0] = 7;
|
||||||
|
|
||||||
|
@ -898,16 +867,16 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
|
||||||
|
|
||||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
||||||
|
|
||||||
|
// Post-PreviewDelegatePartitioning, the partitioner will add Dequant(2) to
|
||||||
|
// ops_to_replace, since it only outputs to a delegated node.
|
||||||
EXPECT_EQ(ops_to_replace->size, 2);
|
EXPECT_EQ(ops_to_replace->size, 2);
|
||||||
// Op at index 2 is the Dequant op (t3 -> t5).
|
// Op at index 4 is the Add op.
|
||||||
EXPECT_EQ(ops_to_replace->data[0], 2);
|
EXPECT_EQ(ops_to_replace->data[0], 4);
|
||||||
// Op at index 3 is the Add op.
|
EXPECT_EQ(ops_to_replace->data[1], 2);
|
||||||
EXPECT_EQ(ops_to_replace->data[1], 3);
|
// Verify that Add op has fp16 inputs.
|
||||||
|
|
||||||
TfLiteNode* node = nullptr;
|
TfLiteNode* node = nullptr;
|
||||||
TfLiteRegistration* registration = nullptr;
|
TfLiteRegistration* registration = nullptr;
|
||||||
// Verify that Add op has fp16 inputs.
|
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
|
||||||
context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
|
|
||||||
®istration);
|
®istration);
|
||||||
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
||||||
TfLiteType::kTfLiteFloat16);
|
TfLiteType::kTfLiteFloat16);
|
||||||
|
@ -917,21 +886,18 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
|
||||||
}
|
}
|
||||||
|
|
||||||
InterpreterMultiNode* interpreter_mn2 =
|
InterpreterMultiNode* interpreter_mn2 =
|
||||||
new InterpreterMultiNode(/*add_op_first=*/false);
|
new InterpreterMultiNode(/*both_ops_supported*/ true);
|
||||||
TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
|
TEST(ModelBuilderTest,
|
||||||
// A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
|
GetOpsToReplaceSelectsCorrectFp16Nodes_MultiplePartitions) {
|
||||||
// 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
|
// A graph with three Dequant nodes feeding two Add ops.
|
||||||
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
|
// t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Add(3) -> t6
|
||||||
// t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
|
// t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
|
||||||
// --\
|
// --\
|
||||||
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
|
// t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
|
||||||
//
|
//
|
||||||
// Note: the graph dependency is exactly same w/ that in
|
// In this test case, we purposely partition Add(3) & Add(4) into different
|
||||||
// GetOpsToReplaceSelectsCorrectDequantsAddFirst, but the unsupported
|
// partitions, to check if Dequant nodes that output *only* to the first
|
||||||
// 'Greater' op appears first in the execution plan. Despite this,
|
// partition nodes are accepted.
|
||||||
// OpsToReplace should still replace the 'Add' op and the Dequant outputing
|
|
||||||
// t5, but leave the other Dequant nodes because 'Greater' must run
|
|
||||||
// on the CPU.
|
|
||||||
|
|
||||||
TfLiteContext* context = interpreter_mn2->context();
|
TfLiteContext* context = interpreter_mn2->context();
|
||||||
|
|
||||||
|
@ -954,33 +920,29 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
|
||||||
context->PreviewDelegatePartitioning =
|
context->PreviewDelegatePartitioning =
|
||||||
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
||||||
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
||||||
|
// The FP16GraphPartitioner should only mark both ADD ops as accepted.
|
||||||
|
EXPECT_EQ(nodes_to_replace->size, 2);
|
||||||
|
EXPECT_EQ(nodes_to_replace->data[0], 3);
|
||||||
|
EXPECT_EQ(nodes_to_replace->data[1], 4);
|
||||||
|
// Technically, both ADD ops should end up in the same partition.
|
||||||
|
// But we put them in different partitions to test post-processing with
|
||||||
|
// DEQUANTIZE nodes.
|
||||||
|
// First partition with Add(3).
|
||||||
auto params = interpreter_mn2->add_delegate_params();
|
auto params = interpreter_mn2->add_delegate_params();
|
||||||
// First partition for DequantNode(0)
|
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
||||||
params->nodes_to_replace->data[0] = 0;
|
params->nodes_to_replace->data[0] = 3;
|
||||||
params->input_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->input_tensors->data[0] = 0;
|
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->output_tensors->data[0] = 3;
|
|
||||||
|
|
||||||
// Second partition for DequantNode(1)
|
|
||||||
params = interpreter_mn2->add_delegate_params();
|
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
|
||||||
params->nodes_to_replace->data[0] = 1;
|
|
||||||
params->input_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->input_tensors->data[0] = 1;
|
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
|
||||||
params->output_tensors->data[0] = 4;
|
|
||||||
|
|
||||||
// Third partition for DequantNode(1), DequantNode(2), ADD(4)
|
|
||||||
params = interpreter_mn2->add_delegate_params();
|
|
||||||
params->nodes_to_replace = TfLiteIntArrayCreate(3);
|
|
||||||
params->nodes_to_replace->data[0] = 1;
|
|
||||||
params->nodes_to_replace->data[1] = 2;
|
|
||||||
params->nodes_to_replace->data[2] = 4;
|
|
||||||
params->input_tensors = TfLiteIntArrayCreate(2);
|
params->input_tensors = TfLiteIntArrayCreate(2);
|
||||||
params->input_tensors->data[0] = 1;
|
|
||||||
params->input_tensors->data[0] = 3;
|
params->input_tensors->data[0] = 3;
|
||||||
|
params->input_tensors->data[1] = 4;
|
||||||
|
params->output_tensors = TfLiteIntArrayCreate(1);
|
||||||
|
params->output_tensors->data[0] = 6;
|
||||||
|
// Second partition with Add(4).
|
||||||
|
params = interpreter_mn2->add_delegate_params();
|
||||||
|
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
||||||
|
params->nodes_to_replace->data[0] = 4;
|
||||||
|
params->input_tensors = TfLiteIntArrayCreate(2);
|
||||||
|
params->input_tensors->data[0] = 4;
|
||||||
|
params->input_tensors->data[1] = 5;
|
||||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
params->output_tensors = TfLiteIntArrayCreate(1);
|
||||||
params->output_tensors->data[0] = 7;
|
params->output_tensors->data[0] = 7;
|
||||||
|
|
||||||
|
@ -989,23 +951,32 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
};
|
};
|
||||||
|
|
||||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
TfLiteIntArray* ops_to_replace = GetOpsToReplace(
|
||||||
|
context, /*allow_quant_ops*/ false, /*max_delegated_partitions*/ 2);
|
||||||
|
|
||||||
EXPECT_EQ(ops_to_replace->size, 2);
|
// Three ops should be selected:
|
||||||
// Op at index 2 is the Dequant op (t3 -> t5).
|
// Add(3), Dequant(x), Add(4)
|
||||||
EXPECT_EQ(ops_to_replace->data[0], 2);
|
// Since both partitions are of size 1, either could end up as the 'first'
|
||||||
// Op at index 4 is the Add op.
|
// partition with one Dequant node added for it.
|
||||||
EXPECT_EQ(ops_to_replace->data[1], 4);
|
EXPECT_EQ(ops_to_replace->size, 3);
|
||||||
|
|
||||||
TfLiteNode* node = nullptr;
|
TfLiteNode* node = nullptr;
|
||||||
TfLiteRegistration* registration = nullptr;
|
TfLiteRegistration* registration = nullptr;
|
||||||
// Verify that Add op has fp16 inputs.
|
// Verify that both Add ops have fp16 inputs.
|
||||||
context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
|
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
|
||||||
®istration);
|
®istration);
|
||||||
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
||||||
TfLiteType::kTfLiteFloat16);
|
TfLiteType::kTfLiteFloat16);
|
||||||
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
|
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
|
||||||
TfLiteType::kTfLiteFloat16);
|
TfLiteType::kTfLiteFloat16);
|
||||||
|
context->GetNodeAndRegistration(context, ops_to_replace->data[2], &node,
|
||||||
|
®istration);
|
||||||
|
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
||||||
|
TfLiteType::kTfLiteFloat16);
|
||||||
|
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
|
||||||
|
TfLiteType::kTfLiteFloat16);
|
||||||
|
// Verify that the op at index 1 is a Dequant outputing to a single Add.
|
||||||
|
EXPECT_TRUE(ops_to_replace->data[1] == 0 || ops_to_replace->data[1] == 2);
|
||||||
TfLiteIntArrayFree(ops_to_replace);
|
TfLiteIntArrayFree(ops_to_replace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -150,74 +150,112 @@ TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus FP16GraphPartitionHelper::Partition(
|
|
||||||
std::set<std::string>* unsupported_nodes_info) {
|
|
||||||
const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
|
|
||||||
// Clean up those partitions that have a single dequant op. NoteThose
|
|
||||||
// removed dequant ops have to be reserved in the graph and should not be
|
|
||||||
// delegated.
|
|
||||||
RemoveSingleDequantNodePartitions();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int>
|
std::vector<int>
|
||||||
FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
||||||
int n, int min_nodes_per_partition) {
|
int n, int min_nodes_per_partition) {
|
||||||
std::vector<int> ops_to_replace =
|
auto first_n_partitions =
|
||||||
GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
GetFirstNLargestPartitions(n, min_nodes_per_partition);
|
||||||
n, min_nodes_per_partition);
|
std::vector<int> ops_to_replace;
|
||||||
RemapInputTensors(ops_to_replace);
|
if (first_n_partitions.empty()) return ops_to_replace;
|
||||||
RemoveReservedDequantsFromNodes(&ops_to_replace);
|
|
||||||
|
// Handle the first delegated partition specially.
|
||||||
|
// All fp16 DEQUANTIZE nodes whose consumers exist only in this partition can
|
||||||
|
// be added to the ops to delegate. Others have to be preserved in the graph,
|
||||||
|
// since the partitioning algorithm will put such nodes greedily in the first
|
||||||
|
// partition.
|
||||||
|
const auto* first_partition = first_n_partitions[0];
|
||||||
|
std::unordered_map<int, int> delegated_dequant_consumers;
|
||||||
|
for (int i = 0; i < first_partition->nodes_to_replace->size; ++i) {
|
||||||
|
const int node_id = first_partition->nodes_to_replace->data[i];
|
||||||
|
ops_to_replace.push_back(node_id);
|
||||||
|
TfLiteNode* node;
|
||||||
|
TfLiteRegistration* registration;
|
||||||
|
const auto status = context_->GetNodeAndRegistration(context_, node_id,
|
||||||
|
&node, ®istration);
|
||||||
|
if (status != kTfLiteOk) {
|
||||||
|
TF_LITE_KERNEL_LOG(context_,
|
||||||
|
"Couldn't get node and registration info for op: %d\n",
|
||||||
|
node_id);
|
||||||
|
ops_to_replace.clear();
|
||||||
|
return ops_to_replace;
|
||||||
|
}
|
||||||
|
// See if any input to the op is a (converted) fp16 value. If yes, increment
|
||||||
|
// its value in delegated_dequant_consumers.
|
||||||
|
for (int j = 0; j < node->inputs->size; ++j) {
|
||||||
|
const int input_tid = node->inputs->data[j];
|
||||||
|
if (dequant_consumers_.find(input_tid) != dequant_consumers_.end()) {
|
||||||
|
delegated_dequant_consumers[input_tid] += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check all dequant nodes that have some consumers in the first partition.
|
||||||
|
// If the number of delegated consumers is same as total number of consumers,
|
||||||
|
// add the corresponding DEQUANTIZE op to the delegated nodes.
|
||||||
|
for (auto tensor_and_consumers : delegated_dequant_consumers) {
|
||||||
|
if (dequant_consumers_[tensor_and_consumers.first] ==
|
||||||
|
tensor_and_consumers.second) {
|
||||||
|
ops_to_replace.emplace_back(dequant_nodes_[tensor_and_consumers.first]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For all other partitions after the first one, insert all nodes into
|
||||||
|
// ops_to_replace.
|
||||||
|
for (int i = 1; i < first_n_partitions.size(); ++i) {
|
||||||
|
auto nodes = first_n_partitions[i]->nodes_to_replace;
|
||||||
|
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
|
||||||
|
nodes->data + nodes->size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify the inputs of relevant ops that support fp16 constants.
|
||||||
|
// TODO(b/156707497): Ensure that these inputs are remapped during the
|
||||||
|
// delegate's 'free', so that CPU fallback works for fp16 models.
|
||||||
|
RemapFp16InputTensors(ops_to_replace);
|
||||||
return ops_to_replace;
|
return ops_to_replace;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FP16GraphPartitionHelper::IsNodeSupported(
|
bool FP16GraphPartitionHelper::IsNodeSupported(
|
||||||
TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
|
TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
|
||||||
int node_id, std::string* unsupported_details) {
|
int node_id, std::string* unsupported_details) {
|
||||||
// If we need to handle dequant nodes, we have to remap input tensors of
|
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
|
||||||
// this node if some of them come from a dequant node before testing if
|
|
||||||
// the node is supported.
|
|
||||||
std::vector<int> orig_inputs;
|
|
||||||
if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
|
|
||||||
&orig_inputs)) {
|
|
||||||
// We have a dequant op here. Note that we retrun an Ok status because a
|
|
||||||
// dequant node is first added as supported. Later, this dequant node
|
|
||||||
// will be removed if it has to be preserved in the graph which happens
|
|
||||||
// when its immediate downstream nodes cannot be supported.
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
const auto status = GraphPartitionHelper::IsNodeSupported(
|
|
||||||
context, node, registration, node_id, unsupported_details);
|
|
||||||
RestoreToOrigInputTensors(node, orig_inputs);
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool FP16GraphPartitionHelper::RecordAndRemapInputTensors(
|
|
||||||
int32_t op_code, int node_id, TfLiteNode* node,
|
|
||||||
std::vector<int>* orig_inputs) {
|
|
||||||
orig_inputs->clear();
|
|
||||||
// Record the dequant node.
|
|
||||||
if (op_code == kTfLiteBuiltinDequantize &&
|
|
||||||
context_->tensors[node->inputs->data[0]].type ==
|
context_->tensors[node->inputs->data[0]].type ==
|
||||||
TfLiteType::kTfLiteFloat16) {
|
TfLiteType::kTfLiteFloat16) {
|
||||||
dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0];
|
// Update mappings if this node is a fp16 DEQUANTIZE node.
|
||||||
return true;
|
dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
|
||||||
}
|
dequant_nodes_[node->outputs->data[0]] = node_id;
|
||||||
// For a dequantize op, there's no need to remap its input tensors.
|
// We do not accept these ops right now.
|
||||||
if (dequant_nodes_.empty()) return false;
|
// This is done to support use-cases where a DEQUANTIZE output might be
|
||||||
RemapInputTensors(node, orig_inputs);
|
// consumed by a CPU op.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FP16GraphPartitionHelper::RestoreToOrigInputTensors(
|
// To check if a (possibly) FP16 node is supported, we temporarily point the
|
||||||
TfLiteNode* node, const std::vector<int>& orig_inputs) {
|
// node's inputs to the original fp16 tensors. This 'mutated' node is then
|
||||||
if (node->inputs->size != orig_inputs.size()) return;
|
// passed to the base IsNodeSupported function for checking. After the check,
|
||||||
|
// we remap the original node inputs, so that the TFLite graph remains the
|
||||||
|
// same.
|
||||||
|
std::vector<int> orig_inputs;
|
||||||
|
if (!dequant_nodes_.empty()) {
|
||||||
|
RemapFp16InputTensors(node, &orig_inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto is_supported = GraphPartitionHelper::IsNodeSupported(
|
||||||
|
context, node, registration, node_id, unsupported_details);
|
||||||
|
|
||||||
|
if (!orig_inputs.empty() && node->inputs->size == orig_inputs.size()) {
|
||||||
|
// Remapping happened. Restore original inputs.
|
||||||
for (int j = 0; j < node->inputs->size; ++j) {
|
for (int j = 0; j < node->inputs->size; ++j) {
|
||||||
node->inputs->data[j] = orig_inputs[j];
|
node->inputs->data[j] = orig_inputs[j];
|
||||||
|
if (dequant_nodes_.find(orig_inputs[j]) != dequant_nodes_.end()) {
|
||||||
|
// If its a fp16 tensor, increment number of consumers of the
|
||||||
|
// corresponding DEQUANTIZE.
|
||||||
|
dequant_consumers_[orig_inputs[j]] += 1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return is_supported;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FP16GraphPartitionHelper::RemapInputTensors(
|
void FP16GraphPartitionHelper::RemapFp16InputTensors(
|
||||||
const std::vector<int>& nodes) const {
|
const std::vector<int>& nodes) const {
|
||||||
for (int node_id : nodes) {
|
for (int node_id : nodes) {
|
||||||
TfLiteNode* node;
|
TfLiteNode* node;
|
||||||
|
@ -229,56 +267,11 @@ void FP16GraphPartitionHelper::RemapInputTensors(
|
||||||
"Couldn't get node and registration info for op: %d\n",
|
"Couldn't get node and registration info for op: %d\n",
|
||||||
node_id);
|
node_id);
|
||||||
}
|
}
|
||||||
RemapInputTensors(node, nullptr /* orig_inputs*/);
|
RemapFp16InputTensors(node, nullptr /* orig_inputs*/);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FP16GraphPartitionHelper::RemoveSingleDequantNodePartitions() {
|
void FP16GraphPartitionHelper::RemapFp16InputTensors(
|
||||||
auto it = partitions_.begin();
|
|
||||||
while (it != partitions_.end()) {
|
|
||||||
auto p = *it;
|
|
||||||
if (p->nodes_to_replace->size != 1) {
|
|
||||||
++it;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
int node_id = p->nodes_to_replace->data[0];
|
|
||||||
TfLiteNode* node = nullptr;
|
|
||||||
TfLiteRegistration* registration = nullptr;
|
|
||||||
|
|
||||||
TfLiteStatus status = context_->GetNodeAndRegistration(
|
|
||||||
context_, node_id, &node, ®istration);
|
|
||||||
if (status != kTfLiteOk) {
|
|
||||||
TF_LITE_KERNEL_LOG(context_,
|
|
||||||
"Couldn't get node and registration info for op: %d\n",
|
|
||||||
node_id);
|
|
||||||
}
|
|
||||||
if (registration->builtin_code != kTfLiteBuiltinDequantize ||
|
|
||||||
context_->tensors[node->inputs->data[0]].type !=
|
|
||||||
TfLiteType::kTfLiteFloat16) {
|
|
||||||
++it;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// Note such dequant nodes have to be preserved in the graph as dequant
|
|
||||||
// ops are not actually supported in the GPU delegate.
|
|
||||||
dequant_nodes_to_save_.insert(node_id);
|
|
||||||
it = partitions_.erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void FP16GraphPartitionHelper::RemoveReservedDequantsFromNodes(
|
|
||||||
std::vector<int>* nodes) {
|
|
||||||
if (dequant_nodes_to_save_.empty()) return;
|
|
||||||
auto it = nodes->begin();
|
|
||||||
while (it != nodes->end()) {
|
|
||||||
if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
|
|
||||||
++it;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
it = nodes->erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void FP16GraphPartitionHelper::RemapInputTensors(
|
|
||||||
TfLiteNode* node, std::vector<int>* orig_inputs) const {
|
TfLiteNode* node, std::vector<int>* orig_inputs) const {
|
||||||
TfLiteIntArray* inputs = node->inputs;
|
TfLiteIntArray* inputs = node->inputs;
|
||||||
auto inputs_view = TfLiteIntArrayView(inputs);
|
auto inputs_view = TfLiteIntArrayView(inputs);
|
||||||
|
@ -296,8 +289,8 @@ void FP16GraphPartitionHelper::RemapInputTensors(
|
||||||
bool is_remapped = false;
|
bool is_remapped = false;
|
||||||
for (int j = 0; j < inputs->size; ++j) {
|
for (int j = 0; j < inputs->size; ++j) {
|
||||||
const int input_tid = inputs->data[j];
|
const int input_tid = inputs->data[j];
|
||||||
const auto it = dequant_nodes_.find(input_tid);
|
const auto it = dequant_map_.find(input_tid);
|
||||||
if (it != dequant_nodes_.end()) {
|
if (it != dequant_map_.end()) {
|
||||||
inputs->data[j] = it->second;
|
inputs->data[j] = it->second;
|
||||||
is_remapped = true;
|
is_remapped = true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,19 +127,23 @@ class GraphPartitionHelper {
|
||||||
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
|
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
|
||||||
};
|
};
|
||||||
|
|
||||||
// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in
|
// Specialized partitioner for graphs that possibly contain fp16 tensors.
|
||||||
// addition to supported nodes for the delegate, when the DEQUANTIZE node's
|
//
|
||||||
// output is an input to the kernel that supports FP16 input.
|
// From nodes that accept fp16 inputs, this delegates the following:
|
||||||
|
// 1. All nodes (except DEQUANTIZE) that are supported with fp16 inputs by the
|
||||||
|
// delegate (in the TFLite graph, these nodes take in dequantized FP32
|
||||||
|
// outputs).
|
||||||
|
// 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
|
||||||
|
// delegated partition. This is because TFLite's partitioning algorithm
|
||||||
|
// greedily puts all such nodes in the first partition.
|
||||||
class FP16GraphPartitionHelper : public GraphPartitionHelper {
|
class FP16GraphPartitionHelper : public GraphPartitionHelper {
|
||||||
public:
|
public:
|
||||||
FP16GraphPartitionHelper(TfLiteContext* context,
|
FP16GraphPartitionHelper(TfLiteContext* context,
|
||||||
IsNodeSupportedFn is_node_supported_fn)
|
IsNodeSupportedFn is_node_supported_fn)
|
||||||
: GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
|
: GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
|
||||||
|
|
||||||
TfLiteStatus Partition(
|
|
||||||
std::set<std::string>* unsupported_nodes_info) override;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
// Specialized function to handle fp16 nodes.
|
||||||
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
|
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
|
||||||
TfLiteRegistration* registration, int node_id,
|
TfLiteRegistration* registration, int node_id,
|
||||||
std::string* unsupported_details) override;
|
std::string* unsupported_details) override;
|
||||||
|
@ -149,39 +153,25 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper {
|
||||||
int n, int min_nodes_per_partition) override;
|
int n, int min_nodes_per_partition) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true.
|
// This remaps fp32 inputs of the given node to their corresponding fp16
|
||||||
// When it's not a dequant op, remap its inputs to the inputs of the preceding
|
// version, if applicable. Can be summarized as:
|
||||||
// dequant if there's a one and returns false. 'orig_inputs' records original
|
// fp16 -> DEQUANTIZE -> fp32 -> OP -> output
|
||||||
// input tensor ids of this node if any input is remapped.
|
// becomes
|
||||||
bool RecordAndRemapInputTensors(int32_t op_code, int node_id,
|
// fp16 -> OP -> output
|
||||||
TfLiteNode* node,
|
void RemapFp16InputTensors(TfLiteNode* node,
|
||||||
std::vector<int>* orig_inputs);
|
std::vector<int>* orig_inputs) const;
|
||||||
|
|
||||||
// Restore inputs of 'node' to 'orig_inputs' only if two sizes match.
|
// Performs the above remapping for all nodes in the given list, without
|
||||||
void RestoreToOrigInputTensors(TfLiteNode* node,
|
// tracking the original inputs.
|
||||||
const std::vector<int>& orig_inputs);
|
void RemapFp16InputTensors(const std::vector<int>& nodes) const;
|
||||||
|
|
||||||
// Remap input tensors of every node in 'nodes' (i.e. node indices) if some of
|
// ('dequantize' here refers to fp16 DEQUANTIZE)
|
||||||
// them are from dequant ops.
|
// Mapping of dequantize nodes' output tensor-id to its node id.
|
||||||
void RemapInputTensors(const std::vector<int>& nodes) const;
|
|
||||||
|
|
||||||
void RemoveSingleDequantNodePartitions();
|
|
||||||
|
|
||||||
void RemoveReservedDequantsFromNodes(std::vector<int>* nodes);
|
|
||||||
|
|
||||||
// Remap input tensors of a single 'node' if some of come from a dequant op.
|
|
||||||
// If 'orig_inputs' isn't nullptr, it records original input tensor ids of
|
|
||||||
// this node if any input is remapped.
|
|
||||||
void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const;
|
|
||||||
|
|
||||||
// A map recording dequantize nodes's input/output tensors of this selected
|
|
||||||
// graph. The key is the output tensor id, and the value is the input tensor
|
|
||||||
// id.
|
|
||||||
std::unordered_map<int, int> dequant_nodes_;
|
std::unordered_map<int, int> dequant_nodes_;
|
||||||
|
// Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
|
||||||
// A set of dequant nodes as in node indices that have to be preserved in the
|
std::unordered_map<int, int> dequant_map_;
|
||||||
// graph.
|
// mapping of DEQUANTIZE output tensor-id to its number of consumers.
|
||||||
std::set<int> dequant_nodes_to_save_;
|
std::unordered_map<int, int> dequant_consumers_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace delegates
|
} // namespace delegates
|
||||||
|
|
|
@ -29,6 +29,10 @@ namespace {
|
||||||
// PartitionGraphIntoIndependentNodeSubsetsImpl partitioner(
|
// PartitionGraphIntoIndependentNodeSubsetsImpl partitioner(
|
||||||
// info, nodes_to_part, node_subsets);
|
// info, nodes_to_part, node_subsets);
|
||||||
// partitioner.Partition();
|
// partitioner.Partition();
|
||||||
|
//
|
||||||
|
// NOTE: Changing the partitioning logic would require a change to
|
||||||
|
// FP16GraphPartitionHelper.
|
||||||
|
// LINT.IfChange
|
||||||
class PartitionGraphIntoIndependentNodeSubsetsImpl {
|
class PartitionGraphIntoIndependentNodeSubsetsImpl {
|
||||||
public:
|
public:
|
||||||
PartitionGraphIntoIndependentNodeSubsetsImpl(
|
PartitionGraphIntoIndependentNodeSubsetsImpl(
|
||||||
|
@ -198,6 +202,7 @@ class PartitionGraphIntoIndependentNodeSubsetsImpl {
|
||||||
// negative values of kEpochNotReady if not assigned.
|
// negative values of kEpochNotReady if not assigned.
|
||||||
std::vector<int> node_epochs_;
|
std::vector<int> node_epochs_;
|
||||||
};
|
};
|
||||||
|
// LINT.ThenChange(//tensorflow/lite/delegates/utils.h)
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue