Move GraphWithDequantPartitionHelper out of delegates/gpu, and put into util.h as the logic remains same w/ other delegates that need to support FP16.
PiperOrigin-RevId: 312243729 Change-Id: I7e2ff7cf80c4860f016cf5dcb60efd94cd2d39dc
This commit is contained in:
parent
a98f72c490
commit
686908251a
@ -116,6 +116,7 @@ cc_library(
|
|||||||
":status",
|
":status",
|
||||||
":tensor",
|
":tensor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"//tensorflow/lite/delegates:utils",
|
||||||
"//tensorflow/lite:context",
|
"//tensorflow/lite:context",
|
||||||
"//tensorflow/lite:kernel_api",
|
"//tensorflow/lite:kernel_api",
|
||||||
"//tensorflow/lite:util",
|
"//tensorflow/lite:util",
|
||||||
|
@ -45,6 +45,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
#include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h"
|
||||||
|
#include "tensorflow/lite/delegates/utils.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
|
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
@ -2809,7 +2810,8 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops,
|
|||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
GraphWithDequantPartitionHelper partition_helper(context, node_supported_fn);
|
delegates::FP16GraphPartitionHelper partition_helper(context,
|
||||||
|
node_supported_fn);
|
||||||
std::set<std::string> unsupported_nodes_info;
|
std::set<std::string> unsupported_nodes_info;
|
||||||
if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
|
if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
|
||||||
return TfLiteIntArrayCreate(0);
|
return TfLiteIntArrayCreate(0);
|
||||||
|
@ -15,9 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
|
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include <fp16.h>
|
#include <fp16.h>
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
@ -33,157 +31,6 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
TfLiteStatus GraphWithDequantPartitionHelper::Partition(
|
|
||||||
std::set<std::string>* unsupported_nodes_info) {
|
|
||||||
const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
|
|
||||||
// Clean up those partitions that have a single dequant op. NoteThose
|
|
||||||
// removed dequant ops have to be reserved in the graph and should not be
|
|
||||||
// delegated.
|
|
||||||
RemoveSingleDequantNodePartitions();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int>
|
|
||||||
GraphWithDequantPartitionHelper::GetNodesOfFirstNLargestPartitions(int n) {
|
|
||||||
// We first get partitions to reduce the number of nodes to be checked in
|
|
||||||
// deciding which dequant ops could actually be replaced. And then we
|
|
||||||
// remap input-tensor to dequant nodes' inputs and remove those
|
|
||||||
// to-be-reserved dequant nodes.
|
|
||||||
auto first_nps = GetFirstNLargestPartitions(n);
|
|
||||||
std::vector<int> ops_to_replace;
|
|
||||||
for (const auto p : first_nps) {
|
|
||||||
auto nodes = p->nodes_to_replace;
|
|
||||||
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
|
|
||||||
nodes->data + nodes->size);
|
|
||||||
}
|
|
||||||
RemapInputTensors(ops_to_replace);
|
|
||||||
RemoveReservedDequantsFromNodes(&ops_to_replace);
|
|
||||||
return ops_to_replace;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GraphWithDequantPartitionHelper::IsNodeSupported(
|
|
||||||
TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
|
|
||||||
int node_id, std::string* unsupported_details) {
|
|
||||||
// If we need to handle dequant nodes, we have to remap input tensors of
|
|
||||||
// this node if some of them come from a dequant node before testing if
|
|
||||||
// the node is supported.
|
|
||||||
std::vector<int> orig_inputs;
|
|
||||||
if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
|
|
||||||
&orig_inputs)) {
|
|
||||||
// We have a dequant op here. Note that we retrun an Ok status because a
|
|
||||||
// dequant node is first added as supported. Later, this dequant node
|
|
||||||
// will be removed if it has to be preserved in the graph which happens
|
|
||||||
// when its immediate downstream nodes cannot be supported.
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
const auto status = GraphPartitionHelper::IsNodeSupported(
|
|
||||||
context, node, registration, node_id, unsupported_details);
|
|
||||||
RestoreToOrigInputTensors(node, orig_inputs);
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GraphWithDequantPartitionHelper::RecordAndRemapInputTensors(
|
|
||||||
int32_t op_code, int node_id, TfLiteNode* node,
|
|
||||||
std::vector<int>* orig_inputs) {
|
|
||||||
orig_inputs->clear();
|
|
||||||
// Record the dequant node.
|
|
||||||
if (op_code == kTfLiteBuiltinDequantize &&
|
|
||||||
context_->tensors[node->inputs->data[0]].type ==
|
|
||||||
TfLiteType::kTfLiteFloat16) {
|
|
||||||
dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0];
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// For a dequantize op, there's no need to remap its input tensors.
|
|
||||||
if (dequant_nodes_.empty()) return false;
|
|
||||||
RemapInputTensors(node, orig_inputs);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphWithDequantPartitionHelper::RestoreToOrigInputTensors(
|
|
||||||
TfLiteNode* node, const std::vector<int>& orig_inputs) {
|
|
||||||
if (node->inputs->size != orig_inputs.size()) return;
|
|
||||||
for (int j = 0; j < node->inputs->size; ++j) {
|
|
||||||
node->inputs->data[j] = orig_inputs[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphWithDequantPartitionHelper::RemapInputTensors(
|
|
||||||
const std::vector<int>& nodes) const {
|
|
||||||
for (int node_id : nodes) {
|
|
||||||
TfLiteNode* node;
|
|
||||||
TfLiteRegistration* registration;
|
|
||||||
GetNodeAndRegistration(context_, node_id, &node, ®istration)
|
|
||||||
.IgnoreError();
|
|
||||||
RemapInputTensors(node, nullptr /* orig_inputs*/);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphWithDequantPartitionHelper::RemoveSingleDequantNodePartitions() {
|
|
||||||
auto it = partitions_.begin();
|
|
||||||
while (it != partitions_.end()) {
|
|
||||||
auto p = *it;
|
|
||||||
if (p->nodes_to_replace->size != 1) {
|
|
||||||
++it;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
int node_id = p->nodes_to_replace->data[0];
|
|
||||||
TfLiteNode* node = nullptr;
|
|
||||||
TfLiteRegistration* registration = nullptr;
|
|
||||||
GetNodeAndRegistration(context_, node_id, &node, ®istration)
|
|
||||||
.IgnoreError();
|
|
||||||
if (registration->builtin_code != kTfLiteBuiltinDequantize ||
|
|
||||||
context_->tensors[node->inputs->data[0]].type !=
|
|
||||||
TfLiteType::kTfLiteFloat16) {
|
|
||||||
++it;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// Note such dequant nodes have to be preserved in the graph as dequant
|
|
||||||
// ops are not actually supported in the GPU delegate.
|
|
||||||
dequant_nodes_to_save_.insert(node_id);
|
|
||||||
it = partitions_.erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphWithDequantPartitionHelper::RemoveReservedDequantsFromNodes(
|
|
||||||
std::vector<int>* nodes) {
|
|
||||||
if (dequant_nodes_to_save_.empty()) return;
|
|
||||||
auto it = nodes->begin();
|
|
||||||
while (it != nodes->end()) {
|
|
||||||
if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
|
|
||||||
++it;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
it = nodes->erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphWithDequantPartitionHelper::RemapInputTensors(
|
|
||||||
TfLiteNode* node, std::vector<int>* orig_inputs) const {
|
|
||||||
TfLiteIntArray* inputs = node->inputs;
|
|
||||||
auto inputs_view = TfLiteIntArrayView(inputs);
|
|
||||||
// Prepopulate 'orig_inputs' first and clear it if there's no input from a
|
|
||||||
// dequant op.
|
|
||||||
if (orig_inputs) {
|
|
||||||
orig_inputs->clear();
|
|
||||||
orig_inputs->reserve(inputs->size);
|
|
||||||
for (auto tid : inputs_view) {
|
|
||||||
orig_inputs->push_back(tid);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Fix this node's inputs (i.e. prune out the preceding dequantize node) in
|
|
||||||
// order to test if it is supported.
|
|
||||||
bool is_remapped = false;
|
|
||||||
for (int j = 0; j < inputs->size; ++j) {
|
|
||||||
const int input_tid = inputs->data[j];
|
|
||||||
const auto it = dequant_nodes_.find(input_tid);
|
|
||||||
if (it != dequant_nodes_.end()) {
|
|
||||||
inputs->data[j] = it->second;
|
|
||||||
is_remapped = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!is_remapped && orig_inputs) orig_inputs->clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
|
absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
|
||||||
TfLiteNode** tflite_node,
|
TfLiteNode** tflite_node,
|
||||||
TfLiteRegistration** registration) {
|
TfLiteRegistration** registration) {
|
||||||
|
@ -16,17 +16,12 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
|
||||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||||
#include "tensorflow/lite/delegates/utils.h"
|
|
||||||
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
|
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
#include "tensorflow/lite/kernels/internal/types.h"
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
@ -35,61 +30,6 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
class GraphWithDequantPartitionHelper : public delegates::GraphPartitionHelper {
|
|
||||||
public:
|
|
||||||
GraphWithDequantPartitionHelper(
|
|
||||||
TfLiteContext* context, delegates::IsNodeSupportedFn is_node_supported_fn)
|
|
||||||
: GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
|
|
||||||
|
|
||||||
TfLiteStatus Partition(
|
|
||||||
std::set<std::string>* unsupported_nodes_info) override;
|
|
||||||
|
|
||||||
// Returns a list of node indices of all nodes from the first n largest
|
|
||||||
// partitions. If there are fewer paritions than n, all nodes will be
|
|
||||||
// returned. The partition is ranked according to the number of nodes.
|
|
||||||
std::vector<int> GetNodesOfFirstNLargestPartitions(int n);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
|
|
||||||
TfLiteRegistration* registration, int node_id,
|
|
||||||
std::string* unsupported_details) override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true.
|
|
||||||
// When it's not a dequant op, remap its inputs to the inputs of the preceding
|
|
||||||
// dequant if there's a one and returns false. 'orig_inputs' records original
|
|
||||||
// input tensor ids of this node if any input is remapped.
|
|
||||||
bool RecordAndRemapInputTensors(int32_t op_code, int node_id,
|
|
||||||
TfLiteNode* node,
|
|
||||||
std::vector<int>* orig_inputs);
|
|
||||||
|
|
||||||
// Restore inputs of 'node' to 'orig_inputs' only if two sizes match.
|
|
||||||
void RestoreToOrigInputTensors(TfLiteNode* node,
|
|
||||||
const std::vector<int>& orig_inputs);
|
|
||||||
|
|
||||||
// Remap input tensors of every node in 'nodes' (i.e. node indices) if some of
|
|
||||||
// them are from dequant ops.
|
|
||||||
void RemapInputTensors(const std::vector<int>& nodes) const;
|
|
||||||
|
|
||||||
void RemoveSingleDequantNodePartitions();
|
|
||||||
|
|
||||||
void RemoveReservedDequantsFromNodes(std::vector<int>* nodes);
|
|
||||||
|
|
||||||
// Remap input tensors of a single 'node' if some of come from a dequant op.
|
|
||||||
// If 'orig_inputs' isn't nullptr, it records original input tensor ids of
|
|
||||||
// this node if any input is remapped.
|
|
||||||
void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const;
|
|
||||||
|
|
||||||
// A map recording dequantize nodes's input/output tensors of this selected
|
|
||||||
// graph. The key is the output tensor id, and the value is the input tensor
|
|
||||||
// id.
|
|
||||||
std::unordered_map<int, int> dequant_nodes_;
|
|
||||||
|
|
||||||
// A set of dequant nodes as in node indices that have to be preserved in the
|
|
||||||
// graph.
|
|
||||||
std::set<int> dequant_nodes_to_save_;
|
|
||||||
};
|
|
||||||
|
|
||||||
absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
|
absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
|
||||||
TfLiteNode** tflite_node,
|
TfLiteNode** tflite_node,
|
||||||
TfLiteRegistration** registration);
|
TfLiteRegistration** registration);
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
#include "tensorflow/lite/context_util.h"
|
#include "tensorflow/lite/context_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -136,5 +137,167 @@ TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes(
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfLiteStatus FP16GraphPartitionHelper::Partition(
|
||||||
|
std::set<std::string>* unsupported_nodes_info) {
|
||||||
|
const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
|
||||||
|
// Clean up those partitions that have a single dequant op. NoteThose
|
||||||
|
// removed dequant ops have to be reserved in the graph and should not be
|
||||||
|
// delegated.
|
||||||
|
RemoveSingleDequantNodePartitions();
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitions(
|
||||||
|
int n) {
|
||||||
|
// We first get partitions to reduce the number of nodes to be checked in
|
||||||
|
// deciding which dequant ops could actually be replaced. And then we
|
||||||
|
// remap input-tensor to dequant nodes' inputs and remove those
|
||||||
|
// to-be-reserved dequant nodes.
|
||||||
|
auto first_nps = GetFirstNLargestPartitions(n);
|
||||||
|
std::vector<int> ops_to_replace;
|
||||||
|
for (const auto p : first_nps) {
|
||||||
|
auto nodes = p->nodes_to_replace;
|
||||||
|
ops_to_replace.insert(ops_to_replace.end(), nodes->data,
|
||||||
|
nodes->data + nodes->size);
|
||||||
|
}
|
||||||
|
RemapInputTensors(ops_to_replace);
|
||||||
|
RemoveReservedDequantsFromNodes(&ops_to_replace);
|
||||||
|
return ops_to_replace;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool FP16GraphPartitionHelper::IsNodeSupported(
|
||||||
|
TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
|
||||||
|
int node_id, std::string* unsupported_details) {
|
||||||
|
// If we need to handle dequant nodes, we have to remap input tensors of
|
||||||
|
// this node if some of them come from a dequant node before testing if
|
||||||
|
// the node is supported.
|
||||||
|
std::vector<int> orig_inputs;
|
||||||
|
if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
|
||||||
|
&orig_inputs)) {
|
||||||
|
// We have a dequant op here. Note that we retrun an Ok status because a
|
||||||
|
// dequant node is first added as supported. Later, this dequant node
|
||||||
|
// will be removed if it has to be preserved in the graph which happens
|
||||||
|
// when its immediate downstream nodes cannot be supported.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const auto status = GraphPartitionHelper::IsNodeSupported(
|
||||||
|
context, node, registration, node_id, unsupported_details);
|
||||||
|
RestoreToOrigInputTensors(node, orig_inputs);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool FP16GraphPartitionHelper::RecordAndRemapInputTensors(
|
||||||
|
int32_t op_code, int node_id, TfLiteNode* node,
|
||||||
|
std::vector<int>* orig_inputs) {
|
||||||
|
orig_inputs->clear();
|
||||||
|
// Record the dequant node.
|
||||||
|
if (op_code == kTfLiteBuiltinDequantize &&
|
||||||
|
context_->tensors[node->inputs->data[0]].type ==
|
||||||
|
TfLiteType::kTfLiteFloat16) {
|
||||||
|
dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0];
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// For a dequantize op, there's no need to remap its input tensors.
|
||||||
|
if (dequant_nodes_.empty()) return false;
|
||||||
|
RemapInputTensors(node, orig_inputs);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void FP16GraphPartitionHelper::RestoreToOrigInputTensors(
|
||||||
|
TfLiteNode* node, const std::vector<int>& orig_inputs) {
|
||||||
|
if (node->inputs->size != orig_inputs.size()) return;
|
||||||
|
for (int j = 0; j < node->inputs->size; ++j) {
|
||||||
|
node->inputs->data[j] = orig_inputs[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FP16GraphPartitionHelper::RemapInputTensors(
|
||||||
|
const std::vector<int>& nodes) const {
|
||||||
|
for (int node_id : nodes) {
|
||||||
|
TfLiteNode* node;
|
||||||
|
TfLiteRegistration* registration;
|
||||||
|
TfLiteStatus status = context_->GetNodeAndRegistration(
|
||||||
|
context_, node_id, &node, ®istration);
|
||||||
|
if (status != kTfLiteOk) {
|
||||||
|
TF_LITE_KERNEL_LOG(context_,
|
||||||
|
"Couldn't get node and registration info for op: %d\n",
|
||||||
|
node_id);
|
||||||
|
}
|
||||||
|
RemapInputTensors(node, nullptr /* orig_inputs*/);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FP16GraphPartitionHelper::RemoveSingleDequantNodePartitions() {
|
||||||
|
auto it = partitions_.begin();
|
||||||
|
while (it != partitions_.end()) {
|
||||||
|
auto p = *it;
|
||||||
|
if (p->nodes_to_replace->size != 1) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int node_id = p->nodes_to_replace->data[0];
|
||||||
|
TfLiteNode* node = nullptr;
|
||||||
|
TfLiteRegistration* registration = nullptr;
|
||||||
|
|
||||||
|
TfLiteStatus status = context_->GetNodeAndRegistration(
|
||||||
|
context_, node_id, &node, ®istration);
|
||||||
|
if (status != kTfLiteOk) {
|
||||||
|
TF_LITE_KERNEL_LOG(context_,
|
||||||
|
"Couldn't get node and registration info for op: %d\n",
|
||||||
|
node_id);
|
||||||
|
}
|
||||||
|
if (registration->builtin_code != kTfLiteBuiltinDequantize ||
|
||||||
|
context_->tensors[node->inputs->data[0]].type !=
|
||||||
|
TfLiteType::kTfLiteFloat16) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Note such dequant nodes have to be preserved in the graph as dequant
|
||||||
|
// ops are not actually supported in the GPU delegate.
|
||||||
|
dequant_nodes_to_save_.insert(node_id);
|
||||||
|
it = partitions_.erase(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FP16GraphPartitionHelper::RemoveReservedDequantsFromNodes(
|
||||||
|
std::vector<int>* nodes) {
|
||||||
|
if (dequant_nodes_to_save_.empty()) return;
|
||||||
|
auto it = nodes->begin();
|
||||||
|
while (it != nodes->end()) {
|
||||||
|
if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
it = nodes->erase(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FP16GraphPartitionHelper::RemapInputTensors(
|
||||||
|
TfLiteNode* node, std::vector<int>* orig_inputs) const {
|
||||||
|
TfLiteIntArray* inputs = node->inputs;
|
||||||
|
auto inputs_view = TfLiteIntArrayView(inputs);
|
||||||
|
// Prepopulate 'orig_inputs' first and clear it if there's no input from a
|
||||||
|
// dequant op.
|
||||||
|
if (orig_inputs) {
|
||||||
|
orig_inputs->clear();
|
||||||
|
orig_inputs->reserve(inputs->size);
|
||||||
|
for (auto tid : inputs_view) {
|
||||||
|
orig_inputs->push_back(tid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fix this node's inputs (i.e. prune out the preceding dequantize node) in
|
||||||
|
// order to test if it is supported.
|
||||||
|
bool is_remapped = false;
|
||||||
|
for (int j = 0; j < inputs->size; ++j) {
|
||||||
|
const int input_tid = inputs->data[j];
|
||||||
|
const auto it = dequant_nodes_.find(input_tid);
|
||||||
|
if (it != dequant_nodes_.end()) {
|
||||||
|
inputs->data[j] = it->second;
|
||||||
|
is_remapped = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!is_remapped && orig_inputs) orig_inputs->clear();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace delegates
|
} // namespace delegates
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -20,6 +20,8 @@ limitations under the License.
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
@ -109,6 +111,70 @@ class GraphPartitionHelper {
|
|||||||
// Contains an array of supported node indices.
|
// Contains an array of supported node indices.
|
||||||
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
|
TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in
|
||||||
|
// addition to supported nodes for the delegate, when the DEQUANTIZE node's
|
||||||
|
// output is an input to the kernel that supports FP16 input.
|
||||||
|
// Noth that you have to use `GetNodesOfFirstNLargestPartitions` instead of
|
||||||
|
// superclass' `GetFirstNLargestPartitions` to do actual remapping of FP16
|
||||||
|
// inputs.
|
||||||
|
class FP16GraphPartitionHelper : public GraphPartitionHelper {
|
||||||
|
public:
|
||||||
|
FP16GraphPartitionHelper(TfLiteContext* context,
|
||||||
|
IsNodeSupportedFn is_node_supported_fn)
|
||||||
|
: GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
|
||||||
|
|
||||||
|
TfLiteStatus Partition(
|
||||||
|
std::set<std::string>* unsupported_nodes_info) override;
|
||||||
|
|
||||||
|
// Returns a list of node indices of all nodes from the first n largest
|
||||||
|
// partitions. If there are fewer paritions than n, all nodes will be
|
||||||
|
// returned. The partition is ranked according to the number of nodes.
|
||||||
|
// TODO(b/156707497): Add this to superclass besides
|
||||||
|
// GetFirstNLargestPartitions (one that returns partitions instead of nodes)
|
||||||
|
std::vector<int> GetNodesOfFirstNLargestPartitions(int n);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
|
||||||
|
TfLiteRegistration* registration, int node_id,
|
||||||
|
std::string* unsupported_details) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true.
|
||||||
|
// When it's not a dequant op, remap its inputs to the inputs of the preceding
|
||||||
|
// dequant if there's a one and returns false. 'orig_inputs' records original
|
||||||
|
// input tensor ids of this node if any input is remapped.
|
||||||
|
bool RecordAndRemapInputTensors(int32_t op_code, int node_id,
|
||||||
|
TfLiteNode* node,
|
||||||
|
std::vector<int>* orig_inputs);
|
||||||
|
|
||||||
|
// Restore inputs of 'node' to 'orig_inputs' only if two sizes match.
|
||||||
|
void RestoreToOrigInputTensors(TfLiteNode* node,
|
||||||
|
const std::vector<int>& orig_inputs);
|
||||||
|
|
||||||
|
// Remap input tensors of every node in 'nodes' (i.e. node indices) if some of
|
||||||
|
// them are from dequant ops.
|
||||||
|
void RemapInputTensors(const std::vector<int>& nodes) const;
|
||||||
|
|
||||||
|
void RemoveSingleDequantNodePartitions();
|
||||||
|
|
||||||
|
void RemoveReservedDequantsFromNodes(std::vector<int>* nodes);
|
||||||
|
|
||||||
|
// Remap input tensors of a single 'node' if some of come from a dequant op.
|
||||||
|
// If 'orig_inputs' isn't nullptr, it records original input tensor ids of
|
||||||
|
// this node if any input is remapped.
|
||||||
|
void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const;
|
||||||
|
|
||||||
|
// A map recording dequantize nodes's input/output tensors of this selected
|
||||||
|
// graph. The key is the output tensor id, and the value is the input tensor
|
||||||
|
// id.
|
||||||
|
std::unordered_map<int, int> dequant_nodes_;
|
||||||
|
|
||||||
|
// A set of dequant nodes as in node indices that have to be preserved in the
|
||||||
|
// graph.
|
||||||
|
std::set<int> dequant_nodes_to_save_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace delegates
|
} // namespace delegates
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user