Add more documentation to classes/functions and remove some dead code.
PiperOrigin-RevId: 314821136 Change-Id: I2e19f4b68939fd972914471c4373044ac45dc1e8
This commit is contained in:
parent
ec7e854ca9
commit
6d98a59237
@ -32,6 +32,7 @@ namespace tflite {
|
|||||||
namespace delegates {
|
namespace delegates {
|
||||||
namespace hexagon {
|
namespace hexagon {
|
||||||
|
|
||||||
|
// Wrapper that holds all data representing a single node in the Hexagon graph.
|
||||||
struct OpNode {
|
struct OpNode {
|
||||||
std::vector<hexagon_nn_input> inputs;
|
std::vector<hexagon_nn_input> inputs;
|
||||||
std::vector<hexagon_nn_output> outputs;
|
std::vector<hexagon_nn_output> outputs;
|
||||||
@ -48,6 +49,15 @@ struct OpNode {
|
|||||||
|
|
||||||
class GraphBuilder;
|
class GraphBuilder;
|
||||||
|
|
||||||
|
// Represents a single Op in the TFLite graph.
|
||||||
|
// For each op in TFLite there should be an OpBuidler, this builder is
|
||||||
|
// responsible for constructing equivalent node(s) in the hexagon graph. A
|
||||||
|
// single builder can create one or more ops in the hexagon graph. When adding
|
||||||
|
// new op* users should inherit from this class and implement
|
||||||
|
// - PopulateSubgraph: which given inputs/outputs should construct the
|
||||||
|
// equivalent hexagon nodes.
|
||||||
|
// - RegisterOutputs: Which should have logic that maps final outputs from a
|
||||||
|
// given node to the equivalent in Hexagon graph.
|
||||||
class OpBuilder {
|
class OpBuilder {
|
||||||
public:
|
public:
|
||||||
// Const representing the shape of a scalar value.
|
// Const representing the shape of a scalar value.
|
||||||
@ -65,30 +75,59 @@ class OpBuilder {
|
|||||||
|
|
||||||
virtual ~OpBuilder() {}
|
virtual ~OpBuilder() {}
|
||||||
|
|
||||||
// TODO(karimnosseir): Do we need to have builder pattern, or they are few not
|
// Sets the op type in the hexagon graph.
|
||||||
// worth it ?
|
|
||||||
void SetOpType(int op_type) { op_node_.op_type = op_type; }
|
void SetOpType(int op_type) { op_node_.op_type = op_type; }
|
||||||
|
|
||||||
|
// Sets the node id in the hexagon graph.
|
||||||
void SetNodeId(int node_id) { op_node_.node_id = node_id; }
|
void SetNodeId(int node_id) { op_node_.node_id = node_id; }
|
||||||
|
|
||||||
|
// Sets the TfLite node index in the TfLite graph.
|
||||||
void SetTFLiteNodeId(int node_index) {
|
void SetTFLiteNodeId(int node_index) {
|
||||||
op_node_.tflite_node_index = node_index;
|
op_node_.tflite_node_index = node_index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Marks this node as Const node.
|
||||||
void SetConstNode() { op_node_.op_type = OP_Const; }
|
void SetConstNode() { op_node_.op_type = OP_Const; }
|
||||||
|
|
||||||
|
// Sets the padding type of the current node.
|
||||||
void SetPaddingType(hexagon_nn_padding_type padding_type) {
|
void SetPaddingType(hexagon_nn_padding_type padding_type) {
|
||||||
op_node_.padding_type = padding_type;
|
op_node_.padding_type = padding_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sets the builtin_data of TFLite node that this Builder is responsible for.
|
||||||
void SetBuiltinData(void* builtin_data) { builtin_data_ = builtin_data; }
|
void SetBuiltinData(void* builtin_data) { builtin_data_ = builtin_data; }
|
||||||
|
|
||||||
|
// Returns true if the current op is a const Op.
|
||||||
bool IsConstNode() const { return op_node_.op_type == OP_Const; }
|
bool IsConstNode() const { return op_node_.op_type == OP_Const; }
|
||||||
|
|
||||||
void print() {}
|
// Subclasses should override it and have logic which handles initializing
|
||||||
|
// hexagon node(s) for the current op, given 'inputs' 'outputs'
|
||||||
|
virtual TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||||
|
const TfLiteIntArray* outputs,
|
||||||
|
TfLiteContext* context) {
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subclasses should override it and register the final output(s) from the
|
||||||
|
// node to the equivalent in hexagon graph.
|
||||||
|
virtual TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
|
||||||
|
TfLiteContext* context) {
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constructs OpNode which represents a node in the Hexagon graph.
|
||||||
const OpNode* Build();
|
const OpNode* Build();
|
||||||
|
|
||||||
|
// Returns the Node index in TFLite graph.
|
||||||
|
int GetTFLiteNodeID() const { return op_node_.tflite_node_index; }
|
||||||
|
|
||||||
|
// Returns the Op type of the current Op (in Hexagon graph)
|
||||||
|
int GetOpType() const { return op_node_.op_type; }
|
||||||
|
|
||||||
|
// Returns the node id in the hexagon graph.
|
||||||
|
int GetID() const { return op_node_.node_id; }
|
||||||
|
|
||||||
|
// Adds tensor identified by 'tensor_id' as input to the current Op.
|
||||||
void AddInput(const TensorID& tensor_id) { input_ids_.push_back(tensor_id); }
|
void AddInput(const TensorID& tensor_id) { input_ids_.push_back(tensor_id); }
|
||||||
|
|
||||||
// Adds Output to the current node, the output has shape defined in 'dims'.
|
// Adds Output to the current node, the output has shape defined in 'dims'.
|
||||||
@ -106,25 +145,13 @@ class OpBuilder {
|
|||||||
// Same as above but accepts pointer instead of std::vector.
|
// Same as above but accepts pointer instead of std::vector.
|
||||||
TensorID AddOutput(int elementsize, int rank, const int* max_sizes_vect);
|
TensorID AddOutput(int elementsize, int rank, const int* max_sizes_vect);
|
||||||
|
|
||||||
int GetID() const { return op_node_.node_id; }
|
// Sets the node that corresponds to this builder in TFLite graph.
|
||||||
|
|
||||||
int GetTFLiteNodeID() const { return op_node_.tflite_node_index; }
|
|
||||||
|
|
||||||
int GetOpType() const { return op_node_.op_type; }
|
|
||||||
|
|
||||||
void SetTfLiteNode(const TfLiteNode* node) { tflite_node_ = node; }
|
void SetTfLiteNode(const TfLiteNode* node) { tflite_node_ = node; }
|
||||||
|
|
||||||
virtual TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
|
// Static
|
||||||
const TfLiteIntArray* outputs,
|
// Computes the min/max values of 'tensor' and sets the values in
|
||||||
TfLiteContext* context) {
|
// the out params 'min' and 'max'.
|
||||||
return kTfLiteOk;
|
// Returns kTfLiteOk on success.
|
||||||
}
|
|
||||||
|
|
||||||
virtual TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
|
|
||||||
TfLiteContext* context) {
|
|
||||||
return kTfLiteOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
static TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor,
|
static TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor,
|
||||||
float* min, float* max) {
|
float* min, float* max) {
|
||||||
if (tensor.type == kTfLiteUInt8) {
|
if (tensor.type == kTfLiteUInt8) {
|
||||||
|
Loading…
Reference in New Issue
Block a user