Add more documentation to classes/functions and remove some dead code.
PiperOrigin-RevId: 314821136 Change-Id: I2e19f4b68939fd972914471c4373044ac45dc1e8
This commit is contained in:
parent
ec7e854ca9
commit
6d98a59237
@ -32,6 +32,7 @@ namespace tflite {
|
||||
namespace delegates {
|
||||
namespace hexagon {
|
||||
|
||||
// Wrapper that holds all data representing a single node in the Hexagon graph.
|
||||
struct OpNode {
|
||||
std::vector<hexagon_nn_input> inputs;
|
||||
std::vector<hexagon_nn_output> outputs;
|
||||
@ -48,6 +49,15 @@ struct OpNode {
|
||||
|
||||
class GraphBuilder;
|
||||
|
||||
// Represents a single Op in the TFLite graph.
|
||||
// For each op in TFLite there should be an OpBuidler, this builder is
|
||||
// responsible for constructing equivalent node(s) in the hexagon graph. A
|
||||
// single builder can create one or more ops in the hexagon graph. When adding
|
||||
// new op* users should inherit from this class and implement
|
||||
// - PopulateSubgraph: which given inputs/outputs should construct the
|
||||
// equivalent hexagon nodes.
|
||||
// - RegisterOutputs: Which should have logic that maps final outputs from a
|
||||
// given node to the equivalent in Hexagon graph.
|
||||
class OpBuilder {
|
||||
public:
|
||||
// Const representing the shape of a scalar value.
|
||||
@ -65,30 +75,59 @@ class OpBuilder {
|
||||
|
||||
virtual ~OpBuilder() {}
|
||||
|
||||
// TODO(karimnosseir): Do we need to have builder pattern, or they are few not
|
||||
// worth it ?
|
||||
// Sets the op type in the hexagon graph.
|
||||
void SetOpType(int op_type) { op_node_.op_type = op_type; }
|
||||
|
||||
// Sets the node id in the hexagon graph.
|
||||
void SetNodeId(int node_id) { op_node_.node_id = node_id; }
|
||||
|
||||
// Sets the TfLite node index in the TfLite graph.
|
||||
void SetTFLiteNodeId(int node_index) {
|
||||
op_node_.tflite_node_index = node_index;
|
||||
}
|
||||
|
||||
// Marks this node as Const node.
|
||||
void SetConstNode() { op_node_.op_type = OP_Const; }
|
||||
|
||||
// Sets the padding type of the current node.
|
||||
void SetPaddingType(hexagon_nn_padding_type padding_type) {
|
||||
op_node_.padding_type = padding_type;
|
||||
}
|
||||
|
||||
// Sets the builtin_data of TFLite node that this Builder is responsible for.
|
||||
void SetBuiltinData(void* builtin_data) { builtin_data_ = builtin_data; }
|
||||
|
||||
// Returns true if the current op is a const Op.
|
||||
bool IsConstNode() const { return op_node_.op_type == OP_Const; }
|
||||
|
||||
void print() {}
|
||||
// Subclasses should override it and have logic which handles initializing
|
||||
// hexagon node(s) for the current op, given 'inputs' 'outputs'
|
||||
virtual TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
const TfLiteIntArray* outputs,
|
||||
TfLiteContext* context) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Subclasses should override it and register the final output(s) from the
|
||||
// node to the equivalent in hexagon graph.
|
||||
virtual TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
|
||||
TfLiteContext* context) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Constructs OpNode which represents a node in the Hexagon graph.
|
||||
const OpNode* Build();
|
||||
|
||||
// Returns the Node index in TFLite graph.
|
||||
int GetTFLiteNodeID() const { return op_node_.tflite_node_index; }
|
||||
|
||||
// Returns the Op type of the current Op (in Hexagon graph)
|
||||
int GetOpType() const { return op_node_.op_type; }
|
||||
|
||||
// Returns the node id in the hexagon graph.
|
||||
int GetID() const { return op_node_.node_id; }
|
||||
|
||||
// Adds tensor identified by 'tensor_id' as input to the current Op.
|
||||
void AddInput(const TensorID& tensor_id) { input_ids_.push_back(tensor_id); }
|
||||
|
||||
// Adds Output to the current node, the output has shape defined in 'dims'.
|
||||
@ -106,25 +145,13 @@ class OpBuilder {
|
||||
// Same as above but accepts pointer instead of std::vector.
|
||||
TensorID AddOutput(int elementsize, int rank, const int* max_sizes_vect);
|
||||
|
||||
int GetID() const { return op_node_.node_id; }
|
||||
|
||||
int GetTFLiteNodeID() const { return op_node_.tflite_node_index; }
|
||||
|
||||
int GetOpType() const { return op_node_.op_type; }
|
||||
|
||||
// Sets the node that corresponds to this builder in TFLite graph.
|
||||
void SetTfLiteNode(const TfLiteNode* node) { tflite_node_ = node; }
|
||||
|
||||
virtual TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
const TfLiteIntArray* outputs,
|
||||
TfLiteContext* context) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
virtual TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
|
||||
TfLiteContext* context) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Static
|
||||
// Computes the min/max values of 'tensor' and sets the values in
|
||||
// the out params 'min' and 'max'.
|
||||
// Returns kTfLiteOk on success.
|
||||
static TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor,
|
||||
float* min, float* max) {
|
||||
if (tensor.type == kTfLiteUInt8) {
|
||||
|
Loading…
Reference in New Issue
Block a user