Disable FP16 delegation for DEQUANTIZE nodes with a non-constant input

PiperOrigin-RevId: 348054092
Change-Id: Ib9b8f7b0a8d778c8967ec0c0d8b65d6a07684777
Sachin Joglekar 2020-12-17 11:01:32 -08:00 committed by TensorFlower Gardener
parent fd0d7123b1
commit 13de383a2f
4 changed files with 111 additions and 24 deletions
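The change hinges on whether the fp16 input of a DEQUANTIZE node is a constant tensor. In TFLite, a tensor is treated as constant in practice when its allocation type is kTfLiteMmapRo (memory-mapped, read-only weight data); that is what kernel_util's IsConstantTensor checks, and it is why the tests below set allocation_type = kTfLiteMmapRo to simulate constant DEQUANTIZE inputs. A minimal sketch of the condition this commit gates on (paraphrased, not the exact library code; IsConstantFp16Input is an illustrative name):

#include "tensorflow/lite/c/common.h"

// Sketch: a DEQUANTIZE node is only tracked for fp16 delegation when its
// input is an fp16 tensor backed by constant (mmap'ed, read-only) data.
inline bool IsConstantFp16Input(const TfLiteTensor* tensor) {
  return tensor->type == kTfLiteFloat16 &&
         tensor->allocation_type == kTfLiteMmapRo;
}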

View File

@@ -60,6 +60,7 @@ cc_library(
         "//tensorflow/lite:kernel_api",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels:kernel_util",
     ],
 )

View File

@@ -187,7 +187,9 @@ class DelegatedInterpreter {
 class InterpreterFp16 : public DelegatedInterpreter {
  public:
-  explicit InterpreterFp16(TfLiteBuiltinOperator op) : DelegatedInterpreter(3) {
+  explicit InterpreterFp16(TfLiteBuiltinOperator op,
+                           bool const_dequantize_inputs = true)
+      : DelegatedInterpreter(3) {
     void* builtin_data = malloc(sizeof(int));
     EXPECT_EQ(interpreter_.AddTensors(5), kTfLiteOk);
     EXPECT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
@@ -243,6 +245,15 @@ class InterpreterFp16 : public DelegatedInterpreter {
         interpreter_.SetTensorParametersReadWrite(
             2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
         kTfLiteOk);
+    if (const_dequantize_inputs) {
+      // This simulates the dequantize inputs being constants in the graph.
+      // If this is not true, FP16GraphPartitionHelper should not consider the
+      // corresponding DEQUANTIZE ops.
+      auto* tensor0 = interpreter_.tensor(0);
+      auto* tensor2 = interpreter_.tensor(2);
+      tensor0->allocation_type = kTfLiteMmapRo;
+      tensor2->allocation_type = kTfLiteMmapRo;
+    }
     EXPECT_EQ(
         interpreter_.SetTensorParametersReadWrite(
             1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
@@ -337,6 +348,64 @@ TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
   TfLiteIntArrayFree(ops_to_replace);
 }
 
+InterpreterFp16* interpreter_fp16_non_constant =
+    new InterpreterFp16(kTfLiteBuiltinAdd, /*const_dequantize_inputs=*/false);
+
+// Same as GetOpsToReplaceAcceptsFp16DequantizeNodes, but the DEQUANTIZE inputs
+// are not constant. As a result, we don't allow the delegate to accept them.
+TEST(ModelBuilderTest, GetOpsToReplaceRejectsNonConstantFp16DequantizeNodes) {
+  TfLiteContext* context = interpreter_fp16_non_constant->context();
+  // These functions are meant to be called inside delegates. Swap out
+  // for similar functions to permit direct calling of GetOpsToReplace.
+  context->GetExecutionPlan = [](struct TfLiteContext* context,
+                                 TfLiteIntArray** execution_plan) {
+    *execution_plan = interpreter_fp16_non_constant->exec_plan();
+    return kTfLiteOk;
+  };
+  context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
+                                       TfLiteNode** node,
+                                       TfLiteRegistration** registration) {
+    *node = interpreter_fp16_non_constant->node(node_index);
+    *registration = interpreter_fp16_non_constant->registration(node_index);
+    return kTfLiteOk;
+  };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        // The partitioner should accept only the Add op initially.
+        EXPECT_EQ(nodes_to_replace->size, 1);
+        // Single partition output.
+        auto params = interpreter_fp16_non_constant->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 2;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 1;
+        params->input_tensors->data[1] = 3;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 4;
+        *partition_params_array =
+            interpreter_fp16_non_constant->delegate_params();
+        *num_partitions = interpreter_fp16_non_constant->num_delegate_params();
+        return kTfLiteOk;
+      };
+
+  TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
+
+  // Only ADD is delegated, with FP32 (dequantized) inputs.
+  EXPECT_EQ(ops_to_replace->size, 1);
+  TfLiteNode* node = nullptr;
+  TfLiteRegistration* registration = nullptr;
+  context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
+                                  &registration);
+  EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
+            TfLiteType::kTfLiteFloat32);
+  EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
+            TfLiteType::kTfLiteFloat32);
+
+  TfLiteIntArrayFree(ops_to_replace);
+}
+
 InterpreterFp16* interpreter_fp16_gt_op =
     new InterpreterFp16(kTfLiteBuiltinGreater);
@@ -800,6 +869,13 @@ class InterpreterMultiNode : public DelegatedInterpreter {
         interpreter_.SetTensorParametersReadWrite(
             2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
         kTfLiteOk);
+    // Simulate DEQUANTIZE inputs being constants.
+    auto* tensor0 = interpreter_.tensor(0);
+    auto* tensor1 = interpreter_.tensor(1);
+    auto* tensor2 = interpreter_.tensor(2);
+    tensor0->allocation_type = kTfLiteMmapRo;
+    tensor1->allocation_type = kTfLiteMmapRo;
+    tensor2->allocation_type = kTfLiteMmapRo;
     EXPECT_EQ(
         interpreter_.SetTensorParametersReadWrite(
             3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),

View File

@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/context_util.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
 
 namespace tflite {
 namespace delegates {
@@ -183,7 +184,8 @@ FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
     // its value in delegated_dequant_consumers.
     for (int j = 0; j < node->inputs->size; ++j) {
       const int input_tid = node->inputs->data[j];
-      if (dequant_consumers_.find(input_tid) != dequant_consumers_.end()) {
+      if (constant_dequant_consumers_.find(input_tid) !=
+          constant_dequant_consumers_.end()) {
         delegated_dequant_consumers[input_tid] += 1;
       }
     }
@@ -192,9 +194,10 @@ FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
   // If the number of delegated consumers is same as total number of consumers,
   // add the corresponding DEQUANTIZE op to the delegated nodes.
   for (auto tensor_and_consumers : delegated_dequant_consumers) {
-    if (dequant_consumers_[tensor_and_consumers.first] ==
+    if (constant_dequant_consumers_[tensor_and_consumers.first] ==
         tensor_and_consumers.second) {
-      ops_to_replace.emplace_back(dequant_nodes_[tensor_and_consumers.first]);
+      ops_to_replace.emplace_back(
+          constant_dequant_nodes_[tensor_and_consumers.first]);
     }
   }
@@ -216,16 +219,21 @@ FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
 bool FP16GraphPartitionHelper::IsNodeSupported(
     TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
     int node_id, std::string* unsupported_details) {
-  if (registration->builtin_code == kTfLiteBuiltinDequantize &&
-      context_->tensors[node->inputs->data[0]].type ==
-          TfLiteType::kTfLiteFloat16) {
-    // Update mappings if this node is a fp16 DEQUANTIZE node.
-    dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
-    dequant_nodes_[node->outputs->data[0]] = node_id;
-    // We do not accept these ops right now.
-    // This is done to support use-cases where a DEQUANTIZE output might be
-    // consumed by a CPU op.
-    return false;
+  if (registration->builtin_code == kTfLiteBuiltinDequantize) {
+    auto& dequantize_input = context_->tensors[node->inputs->data[0]];
+    if (dequantize_input.type == kTfLiteFloat16 &&
+        IsConstantTensor(&dequantize_input)) {
+      // Update mappings if this node is a fp16 DEQUANTIZE node that
+      // works on a **constant** input tensor.
+      // If the input is not a constant, the remapping that we do here will
+      // cause bugs due to preceding ops such as DENSIFY.
+      constant_dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
+      constant_dequant_nodes_[node->outputs->data[0]] = node_id;
+      // We do not accept these ops right now.
+      // This is done to support use-cases where a DEQUANTIZE output might be
+      // consumed by a CPU op.
+      return false;
+    }
   }
 
   // To check if a (possibly) FP16 node is supported, we temporarily point the
@@ -234,7 +242,7 @@ bool FP16GraphPartitionHelper::IsNodeSupported(
   // we remap the original node inputs, so that the TFLite graph remains the
   // same.
   std::vector<int> orig_inputs;
-  if (!dequant_nodes_.empty()) {
+  if (!constant_dequant_nodes_.empty()) {
     RemapFp16InputTensors(node, &orig_inputs);
   }
@@ -245,10 +253,11 @@ bool FP16GraphPartitionHelper::IsNodeSupported(
     // Remapping happened. Restore original inputs.
     for (int j = 0; j < node->inputs->size; ++j) {
       node->inputs->data[j] = orig_inputs[j];
-      if (dequant_nodes_.find(orig_inputs[j]) != dequant_nodes_.end()) {
+      if (constant_dequant_nodes_.find(orig_inputs[j]) !=
+          constant_dequant_nodes_.end()) {
         // If its a fp16 tensor, increment number of consumers of the
         // corresponding DEQUANTIZE.
-        dequant_consumers_[orig_inputs[j]] += 1;
+        constant_dequant_consumers_[orig_inputs[j]] += 1;
       }
     }
   }
@@ -289,8 +298,8 @@ void FP16GraphPartitionHelper::RemapFp16InputTensors(
   bool is_remapped = false;
   for (int j = 0; j < inputs->size; ++j) {
     const int input_tid = inputs->data[j];
-    const auto it = dequant_map_.find(input_tid);
-    if (it != dequant_map_.end()) {
+    const auto it = constant_dequant_map_.find(input_tid);
+    if (it != constant_dequant_map_.end()) {
       inputs->data[j] = it->second;
       is_remapped = true;
     }
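The remapping referenced in the hunks above can be sketched independently of the helper class. The sketch below is illustrative only (RemapInputs and fp32_to_fp16 are made-up names; the real code keeps the mapping in constant_dequant_map_ and restores the original inputs after the support check): each node input that is the fp32 output of a tracked constant-input fp16 DEQUANTIZE is temporarily pointed at the fp16 tensor, so the delegate is queried as if it consumed the fp16 weights directly.

#include <unordered_map>
#include <vector>

#include "tensorflow/lite/c/common.h"

// Illustrative sketch of the remapping step (not the actual member function).
void RemapInputs(TfLiteIntArray* inputs,
                 const std::unordered_map<int, int>& fp32_to_fp16,
                 std::vector<int>* orig_inputs) {
  orig_inputs->clear();
  for (int j = 0; j < inputs->size; ++j) {
    const int tid = inputs->data[j];
    orig_inputs->push_back(tid);  // Remember, so inputs can be restored later.
    const auto it = fp32_to_fp16.find(tid);
    if (it != fp32_to_fp16.end()) inputs->data[j] = it->second;
  }
}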

View File

@@ -131,8 +131,8 @@ class GraphPartitionHelper {
 // Specialized partitioner for graphs that possibly contain fp16 tensors.
 //
 // From nodes that accept fp16 inputs, this delegates the following:
-// 1. All nodes (except DEQUANTIZE) that are supported with fp16 inputs by the
-//    delegate (in the TFLite graph, these nodes take in dequantized FP32
+// 1. All nodes (except DEQUANTIZE) that are supported with constant fp16 inputs
+//    by the delegate (in the TFLite graph, these nodes take in dequantized FP32
 //    outputs).
 // 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
 //    delegated partition. This is because TFLite's partitioning algorithm
@@ -168,11 +168,12 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper {
   // ('dequantize' here refers to fp16 DEQUANTIZE)
   // Mapping of dequantize nodes' output tensor-id to its node id.
-  std::unordered_map<int, int> dequant_nodes_;
+  // TODO(b/156707497): Use absl hash_maps here.
+  std::unordered_map<int, int> constant_dequant_nodes_;
   // Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
-  std::unordered_map<int, int> dequant_map_;
+  std::unordered_map<int, int> constant_dequant_map_;
   // mapping of DEQUANTIZE output tensor-id to its number of consumers.
-  std::unordered_map<int, int> dequant_consumers_;
+  std::unordered_map<int, int> constant_dequant_consumers_;
 };
 
 } // namespace delegates
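To connect the header comment's rule 2 with the implementation in the .cc hunks above: a tracked (constant-input, fp16) DEQUANTIZE is appended to the delegated ops only when every one of its consumers was itself delegated. A hedged sketch of that folding rule follows, using illustrative names only (the real helper uses constant_dequant_consumers_, constant_dequant_nodes_, and a local delegated_dequant_consumers map):

#include <unordered_map>
#include <vector>

// Sketch only: fold DEQUANTIZE nodes whose consumers are all delegated.
std::vector<int> FoldConstantDequantNodes(
    const std::unordered_map<int, int>& total_consumers,      // tensor-id -> total consumer count
    const std::unordered_map<int, int>& delegated_consumers,  // tensor-id -> delegated consumer count
    const std::unordered_map<int, int>& dequant_node_ids,     // tensor-id -> DEQUANTIZE node-id
    std::vector<int> ops_to_replace) {
  for (const auto& entry : delegated_consumers) {
    const int tensor_id = entry.first;
    const auto total = total_consumers.find(tensor_id);
    const auto node_id = dequant_node_ids.find(tensor_id);
    if (total != total_consumers.end() && node_id != dequant_node_ids.end() &&
        total->second == entry.second) {
      // Every consumer of this dequantized tensor is delegated, so the
      // DEQUANTIZE op itself can be handed to the delegate as well.
      ops_to_replace.push_back(node_id->second);
    }
  }
  return ops_to_replace;
}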