Add member function in OpBuilder to compute min/max and add them to the graph.

Also, updated activation and arg min/max builders.

PiperOrigin-RevId: 323445668
Change-Id: I558bfe8b1fe3055762022457668770345c8367c1
This commit is contained in:
Karim Nosir 2020-07-27 14:48:33 -07:00 committed by TensorFlower Gardener
parent ac4bda59d6
commit dd3ce26d7b
4 changed files with 26 additions and 16 deletions

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <limits> #include <limits>
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/delegates/hexagon/hexagon_nn/hexagon_nn.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -32,13 +33,7 @@ TfLiteStatus ActivationOpBuilder::PopulateSubGraph(
int tensor_id = inputs->data[0]; int tensor_id = inputs->data[0];
const auto& input_tensor = context->tensors[tensor_id]; const auto& input_tensor = context->tensors[tensor_id];
AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_); TF_LITE_ENSURE_STATUS(ComputeAndAddMinAndMax(context, input_tensor));
auto* input_min_const = graph_builder_->AddConstNodeWithData(
kScalarShape, reinterpret_cast<char*>(&input_min_), sizeof(input_min_));
auto* input_max_const = graph_builder_->AddConstNodeWithData(
kScalarShape, reinterpret_cast<char*>(&input_max_), sizeof(input_max_));
AddInput(TensorID(input_min_const->GetID(), 0));
AddInput(TensorID(input_max_const->GetID(), 0));
if (op_node_.op_type == OP_QuantizedReluX_8) { if (op_node_.op_type == OP_QuantizedReluX_8) {
auto* relu_value_const = graph_builder_->AddConstNodeWithData( auto* relu_value_const = graph_builder_->AddConstNodeWithData(

View File

@ -54,15 +54,7 @@ TfLiteStatus ArgMinMaxOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
AddInput(TensorID(input_axis_const->GetID(), 0)); AddInput(TensorID(input_axis_const->GetID(), 0));
// Compute Min/Max // Compute Min/Max
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(ComputeAndAddMinAndMax(context, input_tensor));
ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_));
auto* input_min_const = graph_builder_->AddConstNodeWithData(
kScalarShape, reinterpret_cast<char*>(&input_min_), sizeof(input_min_));
auto* input_max_const = graph_builder_->AddConstNodeWithData(
kScalarShape, reinterpret_cast<char*>(&input_max_), sizeof(input_max_));
AddInput(TensorID(input_min_const->GetID(), 0));
AddInput(TensorID(input_max_const->GetID(), 0));
// Output Node // Output Node
int output_batch_size, output_height_size, output_width_size, int output_batch_size, output_height_size, output_width_size,

View File

@ -279,6 +279,21 @@ const OpNode* OpBuilder::Build() {
return &op_node_; return &op_node_;
} }
TfLiteStatus OpBuilder::ComputeAndAddMinAndMax(TfLiteContext* context,
const TfLiteTensor& tensor) {
float tensor_min, tensor_max;
TF_LITE_ENSURE_STATUS(
ComputeMinAndMaxQuantValues(tensor, &tensor_min, &tensor_max));
auto* min_const_node = graph_builder_->AddConstNodeWithData(
kScalarShape, reinterpret_cast<char*>(&tensor_min), sizeof(tensor_min));
auto* max_const_node = graph_builder_->AddConstNodeWithData(
kScalarShape, reinterpret_cast<char*>(&tensor_max), sizeof(tensor_max));
AddInput(TensorID(min_const_node->GetID(), 0));
AddInput(TensorID(max_const_node->GetID(), 0));
return kTfLiteOk;
}
// Static // Static
constexpr int OpBuilder::kScalarShape[]; constexpr int OpBuilder::kScalarShape[];

View File

@ -182,6 +182,14 @@ class OpBuilder {
} }
} }
// Computes the min and max for 'tensor' and adds them as input
// to the node.
TfLiteStatus ComputeAndAddMinAndMax(TfLiteContext* context,
const TfLiteTensor& tensor);
// Computes the float min and max for 'tensor', given 'min_value' and
// 'max_value' data range. The float min and max will be set in 'min' and
// 'max' params
template <typename T> template <typename T>
static TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor, static TfLiteStatus ComputeMinAndMaxQuantValues(const TfLiteTensor& tensor,
float* min, float* max, float* min, float* max,