Improve GPU delegate graph partitioning logic

1. Allow offloading computation starting from intermediate nodes of a model in
   the GPU delegate. When there are multiple partitions, offload only the
   first largest partition.
2. Unify GetOpsToReplace and GetOpsToReplaceFromGraphWithDequantize into a
   single function.

RELNOTES: [TFLite] Allow GPU acceleration starting with internal nodes
PiperOrigin-RevId: 300889638
Change-Id: I09da4043532a95a9efb4f167c8af4243c3889a45

This commit is contained in:
parent 01641aee30
commit 3492a0f355

All changed files live under tensorflow/lite.
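
The core idea of change 1, as a minimal standalone sketch (illustrative names
only; the committed implementation is the GraphPartitionHelper class in the
diff below): among the partitions proposed by the TfLite partitioner, keep
only the one with the most nodes.

    #include <vector>
    #include "tensorflow/lite/c/common.h"

    // Sketch only: pick the node ids of the largest proposed partition.
    std::vector<int> PickLargestPartition(const TfLiteDelegateParams* partitions,
                                          int num_partitions) {
      const TfLiteIntArray* best = nullptr;
      for (int i = 0; i < num_partitions; ++i) {
        const TfLiteIntArray* nodes = partitions[i].nodes_to_replace;
        if (best == nullptr || nodes->size > best->size) best = nodes;
      }
      std::vector<int> ops_to_replace;
      if (best != nullptr) {
        ops_to_replace.assign(best->data, best->data + best->size);
      }
      return ops_to_replace;
    }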
@@ -416,6 +416,7 @@ cc_library(
     hdrs = ["util.h"],
     copts = TFLITE_DEFAULT_COPTS + tflite_copts(),
     deps = [
+        ":kernel_api",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",
     ],
@@ -432,6 +433,7 @@ cc_test(
     deps = [
         ":util",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/schema:schema_fbs",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -20,9 +20,12 @@ limitations under the License.
 #include <algorithm>
 #include <cstdint>
 #include <cstring>
+#include <list>
 #include <memory>
+#include <set>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 #include <fp16.h>
@@ -36,6 +39,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/context.h"
+#include "tensorflow/lite/context_util.h"
 #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/model.h"
@@ -2746,96 +2750,309 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
   return absl::make_unique<UnsupportedOperationParser>();
 }
 
-// A utility class to remove Dequantize op in the graph.
-class DequantizeOpRemover {
- public:
-  bool MayAddNode(const TfLiteNode& node, int32_t op_code, int node_id,
-                  const TfLiteTensor* tensors) {
-    if (op_code == kTfLiteBuiltinDequantize &&
-        tensors[node.inputs->data[0]].type == TfLiteType::kTfLiteFloat16) {
-      dequant_nodes_[node.outputs->data[0]] = {node_id, node.inputs->data[0]};
-      input_tensors_.insert(node.inputs->data[0]);
-      return true;
-    }
-    return false;
-  }
-
-  // Remap inputs of 'node' to the inputs of the preceding dequant if there is
-  // one.
-  // inputs_from_dequant: records the node id of a dequantize node whose
-  // output tensor is one of the input tensors of this 'node'.
-  // orig_inputs: records original input tensor ids of this node if any input
-  // is remapped.
-  void MayRemapInputTensors(TfLiteNode* node,
-                            std::vector<int>* inputs_from_dequant,
-                            std::vector<int>* orig_inputs) const {
-    inputs_from_dequant->clear();
-    orig_inputs->clear();
-    if (dequant_nodes_.empty()) return;
-
-    TfLiteIntArray* inputs = node->inputs;
-    orig_inputs->reserve(inputs->size);
-    // Fix the node's inputs (i.e. prune out the preceding dequantize node)
-    // in order to test if it is supported on the GPU.
-    for (int j = 0; j < inputs->size; ++j) {
-      const int input_tid = inputs->data[j];
-      orig_inputs->push_back(input_tid);
-      const auto it = dequant_nodes_.find(input_tid);
-      if (it != dequant_nodes_.end()) {
-        inputs_from_dequant->push_back(it->second.node_id);
-        // Remap inputs of this node to the inputs of the preceding dequant.
-        inputs->data[j] = it->second.input_tensor_id;
-      }
-    }
-  }
-
-  // May restore inputs of 'node' to 'orig_inputs' if there are inputs from
-  // dequant nodes (i.e. denoted by 'inputs_from_dequant'). We will also mark
-  // such dequantize nodes to be preserved.
-  void MayRestoreInputTensors(TfLiteNode* node,
-                              const std::vector<int>& inputs_from_dequant,
-                              const std::vector<int>& orig_inputs) {
-    if (inputs_from_dequant.empty()) return;
-
-    for (int j = 0; j < node->inputs->size; ++j) {
-      node->inputs->data[j] = orig_inputs[j];
-    }
-    // Mark those dequantize nodes to be preserved in the graph.
-    dequant_nodes_to_save_.insert(dequant_nodes_to_save_.end(),
-                                  inputs_from_dequant.begin(),
-                                  inputs_from_dequant.end());
-  }
-
-  void RemovePreservedNodesFrom(std::vector<int>* nodes) const {
-    for (const int nid : dequant_nodes_to_save_) {
-      auto it = std::find(nodes->begin(), nodes->end(), nid);
-      if (it != nodes->end()) nodes->erase(it);
-    }
-  }
-
- private:
-  struct NodeInfo {
-    int node_id;
-    int input_tensor_id;
-  };
-
-  // A map recording dequantize nodes of this graph. A dequantize node is
-  // identified by the output tensor id. The value is a NodeInfo.
-  std::unordered_map<int, NodeInfo> dequant_nodes_;
-
-  // A set of input tensor ids of dequantize nodes.
-  std::set<int> input_tensors_;
-
-  // The node ids of dequantize nodes that have to be preserved in the graph.
-  std::vector<int> dequant_nodes_to_save_;
-};
-}  // namespace
-
-Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
-                                      TensorRef<BHWC>* tensor_ref) {
-  tensor_ref->type = ToDataType(tflite_tensor.type);
-  return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
-}
+Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
+                              TfLiteNode** tflite_node,
+                              TfLiteRegistration** registration) {
+  if (context->GetNodeAndRegistration(context, node_id, tflite_node,
+                                      registration) != kTfLiteOk) {
+    return InvalidArgumentError(absl::StrCat(
+        "Couldn't get node and registration info for op: ", node_id));
+  }
+  return OkStatus();
+}
+
+using IsNodeSupportedFn =
+    std::function<Status(TfLiteContext*, TfLiteNode*, TfLiteRegistration*)>;
+
+// A utility class to help with model graph partitioning, and to decide the
+// partition to be offloaded to the GPU.
+// TODO(b/151152967): move the following to lite/delegates/utils
+class GraphPartitionHelper {
+ public:
+  GraphPartitionHelper(TfLiteContext* context,
+                       IsNodeSupportedFn is_node_supported_fn)
+      : is_node_supported_fn_(is_node_supported_fn), context_(context) {}
+
+  virtual ~GraphPartitionHelper() { TfLiteIntArrayFree(supported_nodes_); }
+
+  // Partitions the graph into multiple subgraphs, each of which is in
+  // dependency order with others.
+  virtual Status Partition(std::set<std::string>* unsupported_nodes_info) {
+    RETURN_IF_ERROR(PrepareSupportedNodes(unsupported_nodes_info));
+
+    TfLiteDelegateParams* partition_params_array_ = nullptr;
+    int num_partitions_ = 0;
+    if (context_->PreviewDelegatePartitioning(context_, supported_nodes_,
+                                              &partition_params_array_,
+                                              &num_partitions_) != kTfLiteOk) {
+      return InvalidArgumentError("Unable to preview delegate partition.");
+    }
+
+    for (int i = 0; i < num_partitions_; ++i) {
+      partitions_.push_back(partition_params_array_ + i);
+    }
+
+    return OkStatus();
+  }
+
+  // Returns the first n largest partitions, or all if #partitions is less than
+  // 'n'. Note that partitions are ranked according to the number of nodes that
+  // a partition has, and the returned TfLiteDelegateParams objects are *owned*
+  // by the TfLite runtime.
+  std::vector<TfLiteDelegateParams*> GetFirstNLargestPartitions(int n) {
+    const int total = num_partitions();
+    // We only sort partitions according to their sizes if necessary.
+    if (n < total) {
+      partitions_.sort(CompareTwoPartitions);
+    }
+    std::vector<TfLiteDelegateParams*> results;
+    auto p_it = partitions_.begin();
+    for (int i = 0; i < std::min(total, n); ++i, ++p_it) {
+      results.push_back(*p_it);
+    }
+    return results;
+  }
+
+  int num_total_nodes() const { return num_total_nodes_; }
+  int num_partitions() const { return partitions_.size(); }
+
+ private:
+  static bool CompareTwoPartitions(TfLiteDelegateParams* left,
+                                   TfLiteDelegateParams* right) {
+    // Reverse sort
+    return left->nodes_to_replace->size > right->nodes_to_replace->size;
+  }
+
+  Status PrepareSupportedNodes(
+      std::set<std::string>* unsupported_nodes_info = nullptr) {
+    TfLiteIntArray* execution_plan = nullptr;
+    if (context_->GetExecutionPlan(context_, &execution_plan) != kTfLiteOk) {
+      return InvalidArgumentError("Unable to get graph execution plan.");
+    }
+
+    num_total_nodes_ = execution_plan->size;
+    supported_nodes_ = TfLiteIntArrayCreate(num_total_nodes_);
+    supported_nodes_->size = 0;
+    for (int node_id : TfLiteIntArrayView(execution_plan)) {
+      TfLiteNode* node;
+      TfLiteRegistration* registration;
+      auto status =
+          GetNodeAndRegistration(context_, node_id, &node, &registration);
+      if (!status.ok()) {
+        supported_nodes_->size = 0;
+        return status;
+      }
+
+      status = IsNodeSupported(context_, node, registration, node_id);
+      if (status.ok()) {
+        supported_nodes_->data[supported_nodes_->size++] = node_id;
+      } else if (unsupported_nodes_info) {
+        unsupported_nodes_info->insert(
+            absl::StrCat(GetOpNameByRegistration(*registration), ": ",
+                         status.error_message()));
+      }
+    }
+    return OkStatus();
+  }
+
+  // The number of total nodes passed in for partitioning (i.e. the
+  // execution_plan size).
+  int num_total_nodes_ = 0;
+
+  // Tells whether a node is replaceable.
+  const IsNodeSupportedFn is_node_supported_fn_;
+  TfLiteIntArray* supported_nodes_;  // owns the memory
+
+ protected:
+  virtual Status IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
+                                 TfLiteRegistration* registration,
+                                 int node_id) {
+    return is_node_supported_fn_(context, node, registration);
+  }
+
+  TfLiteContext* const context_ = nullptr;
+
+  // Doesn't own the memory of each TfLiteDelegateParams object as it's
+  // managed by the TfLite runtime itself. See
+  // TfLiteContext::PreviewDelegatePartitioning for details.
+  std::list<TfLiteDelegateParams*> partitions_;
+};
+
+class GraphWithDequantPartitionHelper : public GraphPartitionHelper {
+ public:
+  GraphWithDequantPartitionHelper(TfLiteContext* context,
+                                  IsNodeSupportedFn is_node_supported_fn)
+      : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}
+
+  Status Partition(std::set<std::string>* unsupported_nodes_info) override {
+    auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
+    // Clean up those partitions that have a single dequant op. Note: those
+    // removed dequant ops have to be preserved in the graph and should not be
+    // delegated.
+    RemoveSingleDequantNodePartitions();
+    return status;
+  }
+
+  // Returns a list of node indices of all nodes from the first n largest
+  // partitions. If there are fewer partitions than n, all nodes will be
+  // returned. The partitions are ranked according to the number of nodes.
+  std::vector<int> GetNodesOfFirstNLargestPartitions(int n) {
+    // We first get partitions to reduce the number of nodes to be checked in
+    // deciding which dequant ops could actually be replaced. And then we
+    // remap input-tensor to dequant nodes' inputs and remove those
+    // to-be-reserved dequant nodes.
+    auto first_nps = GetFirstNLargestPartitions(n);
+    std::vector<int> ops_to_replace;
+    for (const auto p : first_nps) {
+      auto nodes = p->nodes_to_replace;
+      ops_to_replace.insert(ops_to_replace.end(), nodes->data,
+                            nodes->data + nodes->size);
+    }
+    RemapInputTensors(ops_to_replace);
+    RemoveReservedDequantsFromNodes(&ops_to_replace);
+    return ops_to_replace;
+  }
+
+ protected:
+  Status IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
+                         TfLiteRegistration* registration,
+                         int node_id) override {
+    // If we need to handle dequant nodes, we have to remap input tensors of
+    // this node if some of them come from a dequant node before testing if
+    // the node is supported.
+    std::vector<int> orig_inputs;
+    if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
+                                   &orig_inputs)) {
+      // We have a dequant op here. Note that we return an Ok status because a
+      // dequant node is first added as supported. Later, this dequant node
+      // will be removed if it has to be preserved in the graph, which happens
+      // when its immediate downstream nodes cannot be supported.
+      return OkStatus();
+    }
+    const auto status = GraphPartitionHelper::IsNodeSupported(
+        context, node, registration, node_id);
+    RestoreToOrigInputTensors(node, orig_inputs);
+    return status;
+  }
+
+ private:
+  // Records 'node' if it is a dequant op (i.e. a fp16 one here) and returns
+  // true. When it's not a dequant op, remaps its inputs to the inputs of the
+  // preceding dequant if there is one and returns false. 'orig_inputs' records
+  // original input tensor ids of this node if any input is remapped.
+  bool RecordAndRemapInputTensors(int32_t op_code, int node_id,
+                                  TfLiteNode* node,
+                                  std::vector<int>* orig_inputs) {
+    orig_inputs->clear();
+    // Record the dequant node.
+    if (op_code == kTfLiteBuiltinDequantize &&
+        context_->tensors[node->inputs->data[0]].type ==
+            TfLiteType::kTfLiteFloat16) {
+      dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0];
+      return true;
+    }
+    // For a dequantize op, there's no need to remap its input tensors.
+    if (dequant_nodes_.empty()) return false;
+    RemapInputTensors(node, orig_inputs);
+    return false;
+  }
+
+  // Restore inputs of 'node' to 'orig_inputs' only if the two sizes match.
+  void RestoreToOrigInputTensors(TfLiteNode* node,
+                                 const std::vector<int>& orig_inputs) {
+    if (node->inputs->size != orig_inputs.size()) return;
+    for (int j = 0; j < node->inputs->size; ++j) {
+      node->inputs->data[j] = orig_inputs[j];
+    }
+  }
+
+  // Remap input tensors of every node in 'nodes' (i.e. node indices) if some
+  // of them are from dequant ops.
+  void RemapInputTensors(const std::vector<int>& nodes) const {
+    for (int node_id : nodes) {
+      TfLiteNode* node;
+      TfLiteRegistration* registration;
+      GetNodeAndRegistration(context_, node_id, &node, &registration)
+          .IgnoreError();
+      RemapInputTensors(node, nullptr /* orig_inputs*/);
+    }
+  }
+
+  void RemoveSingleDequantNodePartitions() {
+    auto it = partitions_.begin();
+    while (it != partitions_.end()) {
+      auto p = *it;
+      if (p->nodes_to_replace->size != 1) {
+        ++it;
+        continue;
+      }
+      int node_id = p->nodes_to_replace->data[0];
+      TfLiteNode* node = nullptr;
+      TfLiteRegistration* registration = nullptr;
+      GetNodeAndRegistration(context_, node_id, &node, &registration)
+          .IgnoreError();
+      if (registration->builtin_code != kTfLiteBuiltinDequantize) {
+        ++it;
+        continue;
+      }
+      // Note such dequant nodes have to be preserved in the graph as dequant
+      // ops are not actually supported in the GPU delegate.
+      dequant_nodes_to_save_.insert(node_id);
+      it = partitions_.erase(it);
+    }
+  }
+
+  void RemoveReservedDequantsFromNodes(std::vector<int>* nodes) {
+    if (dequant_nodes_to_save_.empty()) return;
+    auto it = nodes->begin();
+    while (it != nodes->end()) {
+      if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
+        ++it;
+        continue;
+      }
+      it = nodes->erase(it);
+    }
+  }
+
+  // Remap input tensors of a single 'node' if some of them come from a dequant
+  // op. If 'orig_inputs' isn't nullptr, it records original input tensor ids
+  // of this node if any input is remapped.
+  void RemapInputTensors(TfLiteNode* node,
+                         std::vector<int>* orig_inputs) const {
+    TfLiteIntArray* inputs = node->inputs;
+    auto inputs_view = TfLiteIntArrayView(inputs);
+    // Prepopulate 'orig_inputs' first and clear it if there's no input from a
+    // dequant op.
+    if (orig_inputs) {
+      orig_inputs->clear();
+      orig_inputs->reserve(inputs->size);
+      for (auto tid : inputs_view) {
+        orig_inputs->push_back(tid);
+      }
+    }
+    // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
+    // order to test if it is supported.
+    bool is_remapped = false;
+    for (int j = 0; j < inputs->size; ++j) {
+      const int input_tid = inputs->data[j];
+      const auto it = dequant_nodes_.find(input_tid);
+      if (it != dequant_nodes_.end()) {
+        inputs->data[j] = it->second;
+        is_remapped = true;
+      }
+    }
+    if (!is_remapped && orig_inputs) orig_inputs->clear();
+  }
+
+  // A map recording dequantize nodes' input/output tensors of this selected
+  // graph. The key is the output tensor id, and the value is the input tensor
+  // id.
+  std::unordered_map<int, int> dequant_nodes_;
+
+  // A set of dequant nodes as in node indices that have to be preserved in the
+  // graph.
+  std::set<int> dequant_nodes_to_save_;
+};
 
 Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
                    const TfLiteRegistration* registration) {
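
The FP16 handling above hinges on a tensor-id remap: each recorded Dequantize
maps its output tensor id to its FP16 input tensor id, and candidate nodes are
temporarily rewired to read the FP16 tensor directly. Just that remap step, as
a standalone sketch (illustrative, not the committed code):

    #include <unordered_map>
    #include "tensorflow/lite/c/common.h"

    // Rewire 'node' to bypass any recorded Dequantize producers: if an input
    // tensor id is the output of a recorded dequant, read the dequant's
    // (FP16) input tensor instead.
    void BypassDequant(const std::unordered_map<int, int>& dequant_out_to_in,
                       TfLiteNode* node) {
      for (int j = 0; j < node->inputs->size; ++j) {
        const auto it = dequant_out_to_in.find(node->inputs->data[j]);
        if (it != dequant_out_to_in.end()) node->inputs->data[j] = it->second;
      }
    }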
@@ -2856,158 +3073,63 @@ bool IsAllFloatTensors(const TfLiteContext* context,
   return true;
 }
 
-std::string GetOpNameByRegistration(const TfLiteRegistration* registration) {
-  auto op = registration->builtin_code;
-  std::string result =
-      EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
-  if (op == kTfLiteBuiltinCustom) {
-    result += " " + std::string(registration->custom_name);
-  }
-  return result;
-}
-
-Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
-                              TfLiteNode** tflite_node,
-                              TfLiteRegistration** registration) {
-  if (context->GetNodeAndRegistration(context, node_id, tflite_node,
-                                      registration) != kTfLiteOk) {
-    return InvalidArgumentError(absl::StrCat(
-        "Couldn't get node and registration info for op: ", node_id));
-  }
-  return OkStatus();
-}
-
-TfLiteIntArray* GetOpsToReplaceFromGraphWithDequantize(TfLiteContext* context) {
-  TfLiteIntArray* execution_plan = nullptr;
-  if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
-    context->ReportError(context, "Unable to get graph execution plan.");
-    return nullptr;
-  }
-  std::set<std::string> errors;
-  std::vector<int> ops_to_replace;
-  DequantizeOpRemover dequant_remover;
-
-  for (int i = 0; i < execution_plan->size; ++i) {
-    const int node_id = execution_plan->data[i];
-    TfLiteNode* node = nullptr;
-    TfLiteRegistration* registration = nullptr;
-    auto status =
-        GetNodeAndRegistration(context, node_id, &node, &registration);
-    if (!status.ok()) {
-      context->ReportError(context, status.error_message().c_str());
-      return nullptr;
-    }
-
-    if (dequant_remover.MayAddNode(*node, registration->builtin_code, node_id,
-                                   context->tensors)) {
-      // For now, add the node to the list of ops to replace.
-      ops_to_replace.push_back(node_id);
-      continue;
-    }
-
-    // Record the node id of a dequantize node whose output tensor is one of
-    // the input tensors of this node.
-    std::vector<int> inputs_from_dequant;
-    // Original input tensor ids of this node.
-    std::vector<int> orig_inputs;
-    dequant_remover.MayRemapInputTensors(node, &inputs_from_dequant,
-                                         &orig_inputs);
-
-    status = IsSupported(context, node, registration);
-    if (status.ok() &&
-        // TODO(eignasheva): resolve sub operation support for metal delegate
-        // registration->builtin_code != kTfLiteBuiltinSub &&
-        IsAllFloatTensors(context, node->inputs) &&
-        IsAllFloatTensors(context, node->outputs) && errors.empty()) {
-      // Node is supported and there were no previous errors.
-      ops_to_replace.push_back(node_id);
-    } else {
-      // The node is not replaceable, record an error message.
-      errors.insert(absl::StrCat(GetOpNameByRegistration(registration), ": ",
-                                 status.error_message()));
-      dequant_remover.MayRestoreInputTensors(node, inputs_from_dequant,
-                                             orig_inputs);
-    }
-  }
-
-  if (!errors.empty()) {
-    std::string unsupported = absl::StrJoin(errors, "\n");
-    std::string error_message =
-        "Next operations are not supported by GPU delegate:\n" + unsupported +
-        "\nFirst " + std::to_string(ops_to_replace.size()) +
-        " operations will run on the GPU, and the remaining " +
-        std::to_string(execution_plan->size - ops_to_replace.size()) +
-        " on the CPU.";
-    context->ReportError(context, error_message.c_str());
-  }
-
-  dequant_remover.RemovePreservedNodesFrom(&ops_to_replace);
-  return ConvertVectorToTfLiteIntArray(ops_to_replace);
-}
+}  // namespace
+
+Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
+                                      TensorRef<BHWC>* tensor_ref) {
+  tensor_ref->type = ToDataType(tflite_tensor.type);
+  return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
+}
 
 // TODO(impjdi): Check number of input/output tensors and their dimensions.
 // TODO(impjdi): Check ops' parameters.
 TfLiteIntArray* GetOpsToReplace(TfLiteContext* context) {
-  TfLiteIntArray* execution_plan = nullptr;
-  if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
-    context->ReportError(context, "Unable to get graph execution plan.");
-    return nullptr;
-  }
-
-  // Dispatch to another function if graph has Dequantize nodes.
-  for (int i = 0; i < execution_plan->size; ++i) {
-    const int node_id = execution_plan->data[i];
-    TfLiteNode* node = nullptr;
-    TfLiteRegistration* registration = nullptr;
-    auto status =
-        GetNodeAndRegistration(context, node_id, &node, &registration);
-    if (!status.ok()) {
-      context->ReportError(context, status.error_message().c_str());
-      return nullptr;
-    }
-    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
-        context->tensors[node->inputs->data[0]].type ==
-            TfLiteType::kTfLiteFloat16) {
-      return GetOpsToReplaceFromGraphWithDequantize(context);
-    }
-  }
-
-  // No Dequantize nodes. Iterate through graph and find ops to replace.
-  TfLiteIntArray* subgraph = TfLiteIntArrayCreate(execution_plan->size);
-  subgraph->size = 0;
-  std::set<std::string> errors;
-  for (int i = 0; i < execution_plan->size; ++i) {
-    const int node_id = execution_plan->data[i];
-    TfLiteNode* node;
-    TfLiteRegistration* registration;
-    auto status =
-        GetNodeAndRegistration(context, node_id, &node, &registration);
-    if (!status.ok()) {
-      context->ReportError(context, status.error_message().c_str());
-      return nullptr;
-    }
-    status = IsSupported(context, node, registration);
-    if (status.ok() &&
-        // TODO(eignasheva): resolve sub operation support for metal delegate
-        // registration->builtin_code != kTfLiteBuiltinSub &&
-        IsAllFloatTensors(context, node->inputs) &&
-        IsAllFloatTensors(context, node->outputs)) {
-      if (errors.empty()) subgraph->data[subgraph->size++] = node_id;
-    } else {
-      errors.insert(absl::StrCat(GetOpNameByRegistration(registration), ": ",
-                                 status.error_message()));
-    }
-  }
-  if (!errors.empty()) {
-    std::string unsupported = absl::StrJoin(errors, "\n");
-    std::string error_message =
-        "Next operations are not supported by GPU delegate:\n" + unsupported +
-        "\nFirst " + std::to_string(subgraph->size) +
-        " operations will run on the GPU, and the remaining " +
-        std::to_string(execution_plan->size - subgraph->size) + " on the CPU.";
-    context->ReportError(context, error_message.c_str());
-  }
-  return subgraph;
+  IsNodeSupportedFn node_supported_fn =
+      [=](TfLiteContext* context, TfLiteNode* node,
+          TfLiteRegistration* registration) -> Status {
+    RETURN_IF_ERROR(IsSupported(context, node, registration));
+    return (IsAllFloatTensors(context, node->inputs) &&
+            IsAllFloatTensors(context, node->outputs))
+               ? OkStatus()
+               : FailedPreconditionError(
+                     "OP is supported, but tensor type isn't matched!");
+  };
+
+  GraphWithDequantPartitionHelper partition_helper(context, node_supported_fn);
+  std::set<std::string> unsupported_nodes_info;
+  auto status = partition_helper.Partition(&unsupported_nodes_info);
+  if (!status.ok()) {
+    TF_LITE_KERNEL_LOG(context, status.error_message().c_str());
+    return nullptr;
+  }
+
+  // We simply get the 1st largest partition, but we could later explore
+  // whether getting more partitions could lead to better performance, i.e. by
+  // parameterizing '1' here.
+  std::vector<int> ops_to_replace =
+      partition_helper.GetNodesOfFirstNLargestPartitions(1);
+
+  if (!unsupported_nodes_info.empty()) {
+    std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
+    std::string error_message = absl::StrCat(
+        "Following operations are not supported by GPU delegate:\n",
+        unsupported, "\n");
+    if (!ops_to_replace.empty()) {
+      absl::StrAppendFormat(
+          &error_message,
+          "%d operations will run on the GPU (first node: "
+          "%d, last node: %d), and the remaining %d",
+          ops_to_replace.size(), ops_to_replace.front(), ops_to_replace.back(),
+          partition_helper.num_total_nodes() - ops_to_replace.size());
+    } else {
+      absl::StrAppend(&error_message,
+                      "No operations will run on the GPU, and all ",
+                      partition_helper.num_total_nodes());
+    }
+    absl::StrAppend(&error_message, " operations will run on the CPU.");
+    TF_LITE_KERNEL_LOG(context, error_message.c_str());
+  }
+  return ConvertVectorToTfLiteIntArray(ops_to_replace);
 }
 
 Status BuildModel(TfLiteContext* context,
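
For context, a delegate's Prepare hook consumes the array returned by
GetOpsToReplace roughly as sketched below. This is an assumption about the
surrounding (unchanged) delegate code, not part of this commit;
kernel_registration stands in for the GPU delegate's kernel registration.

    extern TfLiteRegistration kernel_registration;  // assumed: defined by the delegate

    TfLiteStatus DelegatePrepare(TfLiteContext* context,
                                 TfLiteDelegate* delegate) {
      TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
      // Hand the selected partition's nodes over to the delegate kernel.
      const TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
          context, kernel_registration, ops_to_replace, delegate);
      TfLiteIntArrayFree(ops_to_replace);
      return status;
    }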
@@ -3047,7 +3169,7 @@ Status BuildModel(TfLiteContext* context,
     const auto status =
         operations[i]->Parse(tflite_node, registration, graph, &reader);
     if (!status.ok()) {
-      return InternalError(absl::StrCat(GetOpNameByRegistration(registration),
+      return InternalError(absl::StrCat(GetOpNameByRegistration(*registration),
                                         ": ", status.error_message()));
     }
   }
@@ -120,9 +120,52 @@ TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) {
   EXPECT_FALSE(status.ok());
 }
 
-class InterpreterFp16 {
- public:
-  explicit InterpreterFp16(TfLiteBuiltinOperator op) {
+class DelegatedInterpreter {
+ public:
+  explicit DelegatedInterpreter(int num_nodes) {
+    exec_plan_ = TfLiteIntArrayCreate(num_nodes);
+  }
+  virtual ~DelegatedInterpreter() {
+    TfLiteIntArrayFree(exec_plan_);
+    for (auto params : delegate_params_) {
+      TfLiteIntArrayFree(params.nodes_to_replace);
+      TfLiteIntArrayFree(params.input_tensors);
+      TfLiteIntArrayFree(params.output_tensors);
+    }
+  }
+
+  // Get the TfLiteContext to be mocked for swapping out functions that have
+  // to be called inside the delegate (i.e. in delegate kernel mode).
+  TfLiteContext* context() { return interpreter_.primary_subgraph().context(); }
+
+  std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
+  nodes_and_registration() {
+    return interpreter_.primary_subgraph().nodes_and_registration();
+  }
+
+  TfLiteIntArray* exec_plan() const { return exec_plan_; }
+  TfLiteDelegateParams* add_delegate_params() {
+    delegate_params_.push_back(TfLiteDelegateParams());
+    return &delegate_params_.back();
+  }
+  TfLiteDelegateParams* delegate_params() { return &delegate_params_.front(); }
+  int num_delegate_params() { return delegate_params_.size(); }
+
+ protected:
+  Interpreter interpreter_;
+
+ private:
+  // The manually-set execution plan for this delegated interpreter.
+  TfLiteIntArray* exec_plan_;
+
+  // The TfLiteDelegateParams objects that are manually populated inside the
+  // mocked TfLiteContext::PreviewDelegatePartitioning.
+  std::vector<TfLiteDelegateParams> delegate_params_;
+};
+
+class InterpreterFp16 : public DelegatedInterpreter {
+ public:
+  explicit InterpreterFp16(TfLiteBuiltinOperator op) : DelegatedInterpreter(3) {
     void* builtin_data = malloc(sizeof(int));
     EXPECT_EQ(interpreter_.AddTensors(5), kTfLiteOk);
     EXPECT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
@@ -187,22 +230,23 @@ class InterpreterFp16 {
             3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
         kTfLiteOk);
 
-    exec_plan_ = TfLiteIntArrayCreate(3);
-    exec_plan_->data[0] = 0;
-    exec_plan_->data[1] = 1;
-    exec_plan_->data[2] = 2;
+    exec_plan()->data[0] = 0;
+    exec_plan()->data[1] = 1;
+    exec_plan()->data[2] = 2;
   }
-
-  ~InterpreterFp16() { TfLiteIntArrayFree(exec_plan_); }
-
-  Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
-  TfLiteIntArray* exec_plan() const { return exec_plan_; }
-
- private:
-  Interpreter interpreter_;
-  TfLiteIntArray* exec_plan_;
 };
 
+// **NOTE**: we have several interpreter instances created at global scope to
+// test *exactly* the GetOpsToReplace function alone, and not the sequence of
+// function calls that includes GetOpsToReplace when calling
+// ModifyGraphWithDelegate. A TfLiteContext is needed to test GetOpsToReplace,
+// but TfLiteContexts intentionally make it difficult to call certain functions
+// in a non-delegate context (see tensorflow/lite/subgraph/subgraph.cc for
+// details). We create our own GetExecutionPlan, GetNodeAndRegistration and
+// PreviewDelegatePartitioning lambdas inside each test, but we can't use local
+// captures without changing the function signature. Therefore, this test data
+// lives at global scope in order to be accessible inside the lambda.
+
 InterpreterFp16* interpreter_fp16_add_op =
     new InterpreterFp16(kTfLiteBuiltinAdd);
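
The pattern that NOTE describes, in isolation (an illustrative sketch, not
part of the commit): TfLiteContext hooks are plain function pointers, so the
tests install capture-less lambdas that reach a global fixture.

    // Global fixture so the capture-less lambda below can reach it.
    InterpreterFp16* g_fixture = new InterpreterFp16(kTfLiteBuiltinAdd);

    void InstallExecutionPlanMock(TfLiteContext* context) {
      // A capture-less lambda converts to the plain function pointer that the
      // TfLiteContext hook expects; a capturing lambda would not compile here.
      context->GetExecutionPlan = [](struct TfLiteContext* context,
                                     TfLiteIntArray** execution_plan) {
        *execution_plan = g_fixture->exec_plan();
        return kTfLiteOk;
      };
    }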
@@ -218,7 +262,8 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
   //   t0 (FP16) --> Add -> t4
   //   t2 (FP16) --/
   //
-  TfLiteContext* context = interpreter_fp16_add_op->GetSubgraph()->context();
+  TfLiteContext* context = interpreter_fp16_add_op->context();
+
   // These functions are meant to be called inside delegates. Swap out
   // for similar functions to permit direct calling of GetOpsToReplace.
   context->GetExecutionPlan = [](struct TfLiteContext* context,
@@ -229,12 +274,30 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg = interpreter_fp16_add_op->GetSubgraph()
-                             ->nodes_and_registration()[node_index];
+    auto& node_and_reg =
+        interpreter_fp16_add_op->nodes_and_registration()[node_index];
     *node = &node_and_reg.first;
     *registration = &node_and_reg.second;
     return kTfLiteOk;
   };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_fp16_add_op->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(3);
+        params->nodes_to_replace->data[0] = 0;
+        params->nodes_to_replace->data[1] = 1;
+        params->nodes_to_replace->data[2] = 2;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 0;
+        params->input_tensors->data[1] = 2;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 4;
+
+        *partition_params_array = interpreter_fp16_add_op->delegate_params();
+        *num_partitions = interpreter_fp16_add_op->num_delegate_params();
+        return kTfLiteOk;
+      };
 
   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
@@ -251,17 +314,6 @@ TEST(ModelBuilderTest, GetOpsToReplacePrunesFp16DequantizeNodes) {
   TfLiteIntArrayFree(ops_to_replace);
 }
 
-// This interpreter instance is created at global scope to test *exactly*
-// the GetOpsToReplace function alone, and not the sequence of function calls
-// that includes GetOpsToReplace when calling ModifyGraphWithDelegate.
-// A TfLiteContext is needed to test GetOpsToReplace, but TfLiteContexts
-// intentionally make it difficult to call certain functions in a
-// non-delegate context (see tensorflow/lite/subgraph/subgraph.cc for details)
-// We create our own GetExecutionPlan and GetNodeAndRegistration lambdas
-// inside each test, but we can't use local captures without changing the
-// function signature. Therefore, this test data lives at global scope
-// in order to be accessible inside the lambda.
-
 InterpreterFp16* interpreter_fp16_gt_op =
     new InterpreterFp16(kTfLiteBuiltinGreater);
@@ -274,7 +326,7 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
   // Because there is no GPU equivalent for the Greater op, we don't prune
   // the Dequantize nodes.
 
-  TfLiteContext* context = interpreter_fp16_gt_op->GetSubgraph()->context();
+  TfLiteContext* context = interpreter_fp16_gt_op->context();
   // These functions are meant to be called inside delegates. Swap out
   // for similar functions to permit direct calling of GetOpsToReplace.
   context->GetExecutionPlan = [](struct TfLiteContext* context,
@@ -285,12 +337,37 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg = interpreter_fp16_gt_op->GetSubgraph()
-                             ->nodes_and_registration()[node_index];
+    auto& node_and_reg =
+        interpreter_fp16_gt_op->nodes_and_registration()[node_index];
     *node = &node_and_reg.first;
     *registration = &node_and_reg.second;
     return kTfLiteOk;
   };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_fp16_gt_op->add_delegate_params();
+        // First partition for DequantNode (t0->t1)
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 0;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 0;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 1;
+
+        // Second partition for DequantNode (t2->t3)
+        params = interpreter_fp16_add_op->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 0;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 0;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 1;
+
+        *partition_params_array = interpreter_fp16_gt_op->delegate_params();
+        *num_partitions = interpreter_fp16_gt_op->num_delegate_params();
+        return kTfLiteOk;
+      };
 
   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
@@ -309,9 +386,9 @@ TEST(ModelBuilderTest, GetOpsToReplaceKeepsFp16DequantizeNodes) {
   TfLiteIntArrayFree(ops_to_replace);
 }
 
-class InterpreterFp32 {
+class InterpreterFp32 : public DelegatedInterpreter {
  public:
-  InterpreterFp32() {
+  InterpreterFp32() : DelegatedInterpreter(2) {
     void* builtin_data = malloc(sizeof(int));
     EXPECT_EQ(interpreter_.AddTensors(4), kTfLiteOk);
     EXPECT_EQ(interpreter_.SetInputs({0, 2}), kTfLiteOk);
@@ -363,34 +440,24 @@ class InterpreterFp32 {
         interpreter_.SetTensorParametersReadWrite(
             2, TfLiteType::kTfLiteFloat32, "t2", dims, quantization, false),
         kTfLiteOk);
-    exec_plan_ = TfLiteIntArrayCreate(2);
-    exec_plan_->data[0] = 0;
-    exec_plan_->data[1] = 1;
+    exec_plan()->data[0] = 0;
+    exec_plan()->data[1] = 1;
   }
-
-  ~InterpreterFp32() { TfLiteIntArrayFree(exec_plan_); }
-
-  Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
-  TfLiteIntArray* exec_plan() const { return exec_plan_; }
-
- private:
-  Interpreter interpreter_;
-  TfLiteIntArray* exec_plan_;
 };
 
 InterpreterFp32* interpreter_fp32 = new InterpreterFp32();
 
 TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
-  // A graph with a Dequant node with uint8 input
-  // is not pruned. The delegate will attempt to replace it
-  // with a GPU op, but this op is currently not supported on
-  // the GPU. Therefore, the Dequant op and all downstream ops
-  // will be scheduled to run on the CPU.
+  // A graph with a Dequant node with uint8 input is not pruned, as this op is
+  // currently not supported on the GPU. Therefore, the Dequant op will be
+  // scheduled to run on the CPU while the remaining supported Add op runs on
+  // the GPU.
   //
   //   t0 (uint8) --> Dequant --> t1 (FP32) --> Add -> t3
   //                              t2 (FP32) --/
   //
-  TfLiteContext* context = interpreter_fp32->GetSubgraph()->context();
+  TfLiteContext* context = interpreter_fp32->context();
 
   // These functions are meant to be called inside delegates. Swap out
   // for similar functions to permit direct calling of GetOpsToReplace.
@@ -402,24 +469,43 @@ TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_fp32->GetSubgraph()->nodes_and_registration()[node_index];
+    auto& node_and_reg = interpreter_fp32->nodes_and_registration()[node_index];
     *node = &node_and_reg.first;
     *registration = &node_and_reg.second;
     return kTfLiteOk;
   };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_fp32->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 1;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 1;
+        params->input_tensors->data[1] = 2;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 3;
+
+        *partition_params_array = interpreter_fp32->delegate_params();
+        *num_partitions = interpreter_fp32->num_delegate_params();
+        return kTfLiteOk;
+      };
 
   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
 
-  // No ops are run on the GPU, since the Dequant op is not pruned and must run
-  // on the CPU.
-  EXPECT_EQ(ops_to_replace->size, 0);
+  // As the Dequant op is not pruned and the ADD op could run on GPU, we have
+  // 1 partition.
+  EXPECT_EQ(ops_to_replace->size, 1);
+  // ADD at index 1.
+  EXPECT_EQ(1, ops_to_replace->data[0]);
 
   TfLiteIntArrayFree(ops_to_replace);
 }
 
-class InterpreterMultiNode {
+class InterpreterMultiNode : public DelegatedInterpreter {
  public:
-  explicit InterpreterMultiNode(bool add_op_first = true) {
+  explicit InterpreterMultiNode(bool add_op_first = true)
+      : DelegatedInterpreter(5) {
     void* builtin_data = malloc(sizeof(int));
     EXPECT_EQ(interpreter_.AddTensors(8), kTfLiteOk);
     EXPECT_EQ(interpreter_.SetInputs({0, 1, 2}), kTfLiteOk);
@@ -552,38 +638,29 @@ class InterpreterMultiNode {
         interpreter_.SetTensorParametersReadWrite(
             7, TfLiteType::kTfLiteFloat32, "t5", dims, quantization, false),
         kTfLiteOk);
-    exec_plan_ = TfLiteIntArrayCreate(5);
-    exec_plan_->data[0] = 0;
-    exec_plan_->data[1] = 1;
-    exec_plan_->data[2] = 2;
-    exec_plan_->data[3] = 3;
-    exec_plan_->data[4] = 4;
+    exec_plan()->data[0] = 0;
+    exec_plan()->data[1] = 1;
+    exec_plan()->data[2] = 2;
+    exec_plan()->data[3] = 3;
+    exec_plan()->data[4] = 4;
   }
-
-  ~InterpreterMultiNode() { TfLiteIntArrayFree(exec_plan_); }
-
-  Subgraph* GetSubgraph() { return interpreter_.subgraph(0); }
-  TfLiteIntArray* exec_plan() const { return exec_plan_; }
-
- private:
-  Interpreter interpreter_;
-  TfLiteIntArray* exec_plan_;
 };
 
 InterpreterMultiNode* interpreter_mn = new InterpreterMultiNode();
 
-TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequants) {
+TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
   // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
   // 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
-  //   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
-  //   t1 (FP16) --> Dequant --> t4 (FP32) --/
-  //                                       --\
-  //   t3 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7
+  //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(4) -> t6
+  //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
+  //                                          --\
+  //   t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(3) -> t7
   //
   // OpsToReplace should replace the 'Add' op and the Dequant outputting
   // t5, but leave the other Dequant nodes because 'Greater' must run
   // on the CPU.
-  TfLiteContext* context = interpreter_mn->GetSubgraph()->context();
+  TfLiteContext* context = interpreter_mn->context();
 
   // These functions are meant to be called inside delegates. Swap out
   // for similar functions to permit direct calling of GetOpsToReplace.
@@ -595,12 +672,48 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_mn->GetSubgraph()->nodes_and_registration()[node_index];
+    auto& node_and_reg = interpreter_mn->nodes_and_registration()[node_index];
     *node = &node_and_reg.first;
     *registration = &node_and_reg.second;
     return kTfLiteOk;
   };
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_mn->add_delegate_params();
+        // First partition for DequantNode(0)
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 0;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 0;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 3;
+
+        // Second partition for DequantNode(1)
+        params = interpreter_mn->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 1;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 1;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 4;
+
+        // Third partition for DequantNode(1), DequantNode(2), ADD(3)
+        params = interpreter_mn->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(3);
+        params->nodes_to_replace->data[0] = 1;
+        params->nodes_to_replace->data[1] = 2;
+        params->nodes_to_replace->data[2] = 3;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 1;
+        params->input_tensors->data[0] = 3;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 7;
+
+        *partition_params_array = interpreter_mn->delegate_params();
+        *num_partitions = interpreter_mn->num_delegate_params();
+        return kTfLiteOk;
+      };
 
   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
@@ -624,19 +737,22 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsAddFirst) {
 
 InterpreterMultiNode* interpreter_mn2 =
     new InterpreterMultiNode(/*add_op_first=*/false);
-TEST(ModelBuilderTest, GetOpsToReplaceRestoresInputsOnErrors) {
-  // A graph with three Dequant nodes feeding two ops, 'Greater' and 'Add'.
+
+TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectDequantsGreaterFirst) {
+  // A graph with three Dequant nodes feeding two ops, 'Add' and 'Greater'.
   // 'Add' can be replaced by the GPU delegate, but 'Greater' can not.
-  //   t0 (FP16) --> Dequant --> t3 (FP32) --> Greater -> t6
-  //   t1 (FP16) --> Dequant --> t4 (FP32) --/
-  //                                       --\
-  //   t3 (FP16) --> Dequant --> t5 (FP32) --> Add -> t7
+  //   t0 (FP16) --> Dequant(0) --> t3 (FP32) --> Greater(3) -> t6
+  //   t1 (FP16) --> Dequant(1) --> t4 (FP32) --/
+  //                                          --\
+  //   t3 (FP16) --> Dequant(2) --> t5 (FP32) --> Add(4) -> t7
   //
-  // 'Greater' comes first in the execution plan though, so Add should not
-  // be scheduled to run on the GPU. Further, its inputs should remain t4
-  // and t5.
-  TfLiteContext* context = interpreter_mn->GetSubgraph()->context();
+  // Note: the graph dependency is exactly the same as that in
+  // GetOpsToReplaceSelectsCorrectDequantsAddFirst, but the unsupported
+  // 'Greater' op appears first in the execution plan. Despite this,
+  // OpsToReplace should still replace the 'Add' op and the Dequant outputting
+  // t5, but leave the other Dequant nodes because 'Greater' must run
+  // on the CPU.
+
+  TfLiteContext* context = interpreter_mn2->context();
 
   // These functions are meant to be called inside delegates. Swap out
   // for similar functions to permit direct calling of GetOpsToReplace.
@@ -648,27 +764,67 @@ TEST(ModelBuilderTest, GetOpsToReplaceRestoresInputsOnErrors) {
   context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
                                        TfLiteNode** node,
                                        TfLiteRegistration** registration) {
-    auto& node_and_reg =
-        interpreter_mn2->GetSubgraph()->nodes_and_registration()[node_index];
+    auto& node_and_reg = interpreter_mn2->nodes_and_registration()[node_index];
     *node = &node_and_reg.first;
     *registration = &node_and_reg.second;
    return kTfLiteOk;
   };
 
+  context->PreviewDelegatePartitioning =
+      [](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
+         TfLiteDelegateParams** partition_params_array, int* num_partitions) {
+        auto params = interpreter_mn2->add_delegate_params();
+        // First partition for DequantNode(0)
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 0;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 0;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 3;
+
+        // Second partition for DequantNode(1)
+        params = interpreter_mn2->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(1);
+        params->nodes_to_replace->data[0] = 1;
+        params->input_tensors = TfLiteIntArrayCreate(1);
+        params->input_tensors->data[0] = 1;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 4;
+
+        // Third partition for DequantNode(1), DequantNode(2), ADD(4)
+        params = interpreter_mn2->add_delegate_params();
+        params->nodes_to_replace = TfLiteIntArrayCreate(3);
+        params->nodes_to_replace->data[0] = 1;
+        params->nodes_to_replace->data[1] = 2;
+        params->nodes_to_replace->data[2] = 4;
+        params->input_tensors = TfLiteIntArrayCreate(2);
+        params->input_tensors->data[0] = 1;
+        params->input_tensors->data[1] = 3;
+        params->output_tensors = TfLiteIntArrayCreate(1);
+        params->output_tensors->data[0] = 7;
+
+        *partition_params_array = interpreter_mn2->delegate_params();
+        *num_partitions = interpreter_mn2->num_delegate_params();
+        return kTfLiteOk;
+      };
+
   TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
 
-  // Verify that no ops will be replaced.
-  EXPECT_EQ(ops_to_replace->size, 0);
+  EXPECT_EQ(ops_to_replace->size, 2);
+  // Op at index 2 is the Dequant op (t3 -> t5).
+  EXPECT_EQ(ops_to_replace->data[0], 2);
+  // Op at index 4 is the Add op.
+  EXPECT_EQ(ops_to_replace->data[1], 4);
 
   TfLiteNode* node = nullptr;
   TfLiteRegistration* registration = nullptr;
-  // Verify that Add op has fp32 inputs.
-  context->GetNodeAndRegistration(context, 4, &node, &registration);
-  EXPECT_EQ(registration->builtin_code, kTfLiteBuiltinAdd);
+  // Verify that Add op has fp16 inputs.
+  context->GetNodeAndRegistration(context, ops_to_replace->data[1], &node,
+                                  &registration);
   EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
-            TfLiteType::kTfLiteFloat32);
+            TfLiteType::kTfLiteFloat16);
   EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
-            TfLiteType::kTfLiteFloat32);
+            TfLiteType::kTfLiteFloat16);
   TfLiteIntArrayFree(ops_to_replace);
 }
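The fp16 expectations above capture the visible effect of pruning Dequantize nodes: the consumer op's inputs are remapped from the dequantized fp32 tensor back to its fp16 source. A minimal sketch of that remapping, with hypothetical names (not the committed code):

#include "tensorflow/lite/c/common.h"

// Hypothetical helper: point `consumer` at the fp16 origin of a pruned
// Dequantize node instead of the Dequantize's fp32 output tensor.
void RemapDequantInput(TfLiteNode* consumer, int dequant_output_tensor,
                       int fp16_input_tensor) {
  for (int i = 0; i < consumer->inputs->size; ++i) {
    if (consumer->inputs->data[i] == dequant_output_tensor) {
      consumer->inputs->data[i] = fp16_input_tensor;  // now reads fp16
    }
  }
}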
@@ -17,6 +17,7 @@ limitations under the License.
 #include <complex>
 #include <cstring>
 
+#include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 
@@ -131,4 +132,15 @@ bool IsUnresolvedCustomOp(const TfLiteRegistration& registration) {
          registration.invoke == &UnresolvedOpInvoke;
 }
 
+std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
+  auto op = registration.builtin_code;
+  std::string result =
+      EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
+  if ((op == kTfLiteBuiltinCustom || op == kTfLiteBuiltinDelegate) &&
+      registration.custom_name) {
+    result += " " + std::string(registration.custom_name);
+  }
+  return result;
+}
+
 }  // namespace tflite
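The helper added to util.cc is handy wherever delegate code wants readable op names in diagnostics. An illustrative usage sketch (the wrapping function and message are assumptions, not part of this change):

#include <string>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/util.h"

// Illustrative only: format a message for a node a delegate rejected.
// Produces e.g. "ADD", "CUSTOM TestOp", or "DELEGATE TestDelegate".
std::string RejectionMessage(const TfLiteRegistration& registration) {
  return "Not supported by the GPU delegate: " +
         tflite::GetOpNameByRegistration(registration);
}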
@@ -21,6 +21,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_UTIL_H_
 #define TENSORFLOW_LITE_UTIL_H_
 
+#include <string>
 #include <vector>
 
 #include "tensorflow/lite/c/common.h"
@@ -73,6 +74,8 @@ TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name);
 // Checks whether the provided op is an unresolved custom op.
 bool IsUnresolvedCustomOp(const TfLiteRegistration& registration);
 
+// Returns a descriptive name with the given op TfLiteRegistration.
+std::string GetOpNameByRegistration(const TfLiteRegistration& registration);
 }  // namespace tflite
 
 #endif  // TENSORFLOW_LITE_UTIL_H_
@@ -20,6 +20,7 @@ limitations under the License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/schema/schema_generated.h"
 
 namespace tflite {
 namespace {
@@ -90,6 +91,32 @@ TEST(CombineHashes, TestHashOutputsDifferent) {
   EXPECT_NE(output1, output2);
 }
 
+TEST(GetOpNameByRegistration, ValidBuiltinCode) {
+  TfLiteRegistration registration;
+  registration.builtin_code = tflite::BuiltinOperator_ADD;
+  const auto op_name = GetOpNameByRegistration(registration);
+  EXPECT_EQ("ADD", op_name);
+}
+
+TEST(GetOpNameByRegistration, InvalidBuiltinCode) {
+  TfLiteRegistration registration;
+  registration.builtin_code = -1;
+  const auto op_name = GetOpNameByRegistration(registration);
+  EXPECT_EQ("", op_name);
+}
+
+TEST(GetOpNameByRegistration, CustomName) {
+  TfLiteRegistration registration;
+  registration.builtin_code = tflite::BuiltinOperator_CUSTOM;
+  registration.custom_name = "TestOp";
+  auto op_name = GetOpNameByRegistration(registration);
+  EXPECT_EQ("CUSTOM TestOp", op_name);
+
+  registration.builtin_code = tflite::BuiltinOperator_DELEGATE;
+  registration.custom_name = "TestDelegate";
+  op_name = GetOpNameByRegistration(registration);
+  EXPECT_EQ("DELEGATE TestDelegate", op_name);
+}
+
 }  // namespace
 }  // namespace tflite
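The empty-string expectation in InvalidBuiltinCode follows from EnumNameBuiltinOperator being a FlatBuffers-generated helper that bounds-checks the enum before indexing its name table. A sketch of the generated helper's shape (illustrative; the real code lives in schema_generated.h and differs in detail):

// Shape of the FlatBuffers-generated helper, not the verbatim source.
inline const char* EnumNameBuiltinOperator(BuiltinOperator e) {
  if (e < BuiltinOperator_MIN || e > BuiltinOperator_MAX) return "";
  return EnumNamesBuiltinOperator()[static_cast<int>(e)];
}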