Add more documentation to classes/functions and remove some dead code.

PiperOrigin-RevId: 314821136
Change-Id: I2e19f4b68939fd972914471c4373044ac45dc1e8
This commit is contained in:
Karim Nosir 2020-06-04 16:01:52 -07:00 committed by TensorFlower Gardener
parent ec7e854ca9
commit 6d98a59237

View File

@ -32,6 +32,7 @@ namespace tflite {
namespace delegates {
namespace hexagon {
// Wrapper that holds all data representing a single node in the Hexagon graph.
struct OpNode {
std::vector<hexagon_nn_input> inputs;
std::vector<hexagon_nn_output> outputs;
@ -48,6 +49,15 @@ struct OpNode {
class GraphBuilder;
// Represents a single Op in the TFLite graph.
// For each op in TFLite there should be an OpBuidler, this builder is
// responsible for constructing equivalent node(s) in the hexagon graph. A
// single builder can create one or more ops in the hexagon graph. When adding
// new op* users should inherit from this class and implement
// - PopulateSubgraph: which given inputs/outputs should construct the
// equivalent hexagon nodes.
// - RegisterOutputs: Which should have logic that maps final outputs from a
// given node to the equivalent in Hexagon graph.
class OpBuilder {
public:
// Const representing the shape of a scalar value.
@ -65,30 +75,59 @@ class OpBuilder {
virtual ~OpBuilder() {}
// TODO(karimnosseir): Do we need to have builder pattern, or they are few not
// worth it ?
// Sets the op type in the hexagon graph.
void SetOpType(int op_type) { op_node_.op_type = op_type; }
// Sets the node id in the hexagon graph.
void SetNodeId(int node_id) { op_node_.node_id = node_id; }
// Sets the TfLite node index in the TfLite graph.
void SetTFLiteNodeId(int node_index) {
op_node_.tflite_node_index = node_index;
}
// Marks this node as Const node.
void SetConstNode() { op_node_.op_type = OP_Const; }
// Sets the padding type of the current node.
void SetPaddingType(hexagon_nn_padding_type padding_type) {
op_node_.padding_type = padding_type;
}
// Sets the builtin_data of TFLite node that this Builder is responsible for.
void SetBuiltinData(void* builtin_data) { builtin_data_ = builtin_data; }
// Returns true if the current op is a const Op.
bool IsConstNode() const { return op_node_.op_type == OP_Const; }
void print() {}
// Subclasses should override it and have logic which handles initializing
// hexagon node(s) for the current op, given 'inputs' 'outputs'
virtual TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
const TfLiteIntArray* outputs,
TfLiteContext* context) {
return kTfLiteOk;
}
// Subclasses should override it and register the final output(s) from the
// node to the equivalent in hexagon graph.
virtual TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
TfLiteContext* context) {
return kTfLiteOk;
}
// Constructs OpNode which represents a node in the Hexagon graph.
const OpNode* Build();
// Returns the Node index in TFLite graph.
int GetTFLiteNodeID() const { return op_node_.tflite_node_index; }
// Returns the Op type of the current Op (in Hexagon graph)
int GetOpType() const { return op_node_.op_type; }
// Returns the node id in the hexagon graph.
int GetID() const { return op_node_.node_id; }
// Adds tensor identified by 'tensor_id' as input to the current Op.
void AddInput(const TensorID& tensor_id) { input_ids_.push_back(tensor_id); }
// Adds Output to the current node, the output has shape defined in 'dims'.
@ -106,25 +145,13 @@ class OpBuilder {
// Same as above but accepts pointer instead of std::vector.
TensorID AddOutput(int elementsize, int rank, const int* max_sizes_vect);
int GetID() const { return op_node_.node_id; }
int GetTFLiteNodeID() const { return op_node_.tflite_node_index; }
int GetOpType() const { return op_node_.op_type; }
// Sets the node that corresponds to this builder in TFLite graph.
void SetTfLiteNode(const TfLiteNode* node) { tflite_node_ = node; }
virtual TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
const TfLiteIntArray* outputs,
TfLiteContext* context) {
return kTfLiteOk;
}
virtual TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
TfLiteContext* context) {
return kTfLiteOk;
}
// Static
// Computes the min/max values of 'tensor' and sets the values in
// the out params 'min' and 'max'.
// Returns kTfLiteOk on success.
static TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor,
float* min, float* max) {
if (tensor.type == kTfLiteUInt8) {