Merge pull request #26655 from Dayananda-V:tflite_hybrid_op
PiperOrigin-RevId: 239124760
This commit is contained in:
commit 9201689e42
@@ -95,9 +95,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, output, output_size_array));
 
-  bool is_hybrid =
-      input->type == kTfLiteFloat32 && (input_weights->type == kTfLiteUInt8 ||
-                                        input_weights->type == kTfLiteInt8);
+  const bool is_hybrid = IsHybridOp(input, input_weights);
 
   // Allocate temporary tensors to store quantized values of input and
   // hidden_state tensors.
@@ -509,8 +509,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                      context->ResizeTensor(context, fw_output, fw_output_size));
 
   // The weights are of consistent type, so it suffices to check one.
-  const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8 ||
-                             fw_input_to_output_weights->type == kTfLiteInt8);
+  const bool is_hybrid_op = IsHybridOp(input, fw_input_to_output_weights);
 
   TfLiteIntArrayFree(node->temporaries);
   if (is_hybrid_op) {
@@ -168,11 +168,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                       bw_aux_input_weights->dims->data[1]);
   }
 
-  const bool is_hybrid_op = ((fw_input_weights->type == kTfLiteUInt8 ||
-                              fw_input_weights->type == kTfLiteInt8) &&
-                             input->type == kTfLiteFloat32);
-
-  if (is_hybrid_op) {
+  if (IsHybridOp(input, fw_input_weights)) {
     int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
 
     TfLiteIntArrayFree(node->temporaries);
@@ -84,6 +84,13 @@ inline void SetTensorToDynamic(TfLiteTensor* tensor) {
   }
 }
 
+// Determines whether it is a hybrid op - one that has float inputs and
+// quantized weights.
+inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) {
+  return ((weight->type == kTfLiteUInt8 || weight->type == kTfLiteInt8) &&
+          input->type == kTfLiteFloat32);
+}
+
 // Check dimensionality match and populate OpData for Conv and DepthwiseConv.
 TfLiteStatus PopulateConvolutionQuantizationParams(
     TfLiteContext* context, const TfLiteTensor* input,
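
The helper centralizes the float-activation / quantized-weight check that each kernel's Prepare previously spelled out inline. Below is a minimal usage sketch, not code from this commit: the function and namespace names are placeholders, and the only assumptions are the IsHybridOp signature added above and the tensorflow/lite/kernels/kernel_util.h include path.

#include "tensorflow/lite/kernels/kernel_util.h"  // IsHybridOp, TfLiteTensor

namespace tflite {
namespace hybrid_sketch {

// Pattern the kernel hunks in this commit converge on: take the hybrid path
// only when the activations are float32 and the weights are uint8/int8.
void SelectKernelPath(const TfLiteTensor* input, const TfLiteTensor* weights) {
  if (IsHybridOp(input, weights)) {
    // Hybrid path: Prepare() allocates temporary tensors for the quantized
    // copy of the input and its scaling factors; Eval() quantizes the float
    // input on the fly and runs the matmuls in 8-bit.
  } else {
    // Plain float (or fully quantized) path: no extra temporaries needed.
  }
}

}  // namespace hybrid_sketch
}  // namespace tflite
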
@@ -381,10 +381,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                      context->ResizeTensor(context, output, output_size));
 
   // The weights are of consistent type, so it suffices to check one.
-  // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
-  const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
-                              input_to_output_weights->type == kTfLiteInt8) &&
-                             input->type == kTfLiteFloat32);
+  const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
 
   TfLiteIntArrayFree(node->temporaries);
   if (is_hybrid_op) {
@@ -176,9 +176,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                      context->ResizeTensor(context, output, output_size_array));
 
   // The weights are of consistent type, so it suffices to check one.
-  const bool is_hybrid_op = (input->type == kTfLiteFloat32 &&
-                             (weights_feature->type == kTfLiteUInt8 ||
-                              weights_feature->type == kTfLiteInt8));
+  const bool is_hybrid_op = IsHybridOp(input, weights_feature);
 
   // Resize scratch.
   TfLiteIntArrayFree(node->temporaries);
@@ -304,14 +304,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, output, output_size));
 
-  // The weights are of consistent type, so it suffices to check one.
-  // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
-  const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
-                              input_to_output_weights->type == kTfLiteInt8) &&
-                             input->type == kTfLiteFloat32);
-
   TfLiteIntArrayFree(node->temporaries);
-  if (is_hybrid_op) {
+  if (IsHybridOp(input, input_to_output_weights)) {
     node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
   } else {
     node->temporaries = TfLiteIntArrayCreate(1);
@@ -338,7 +332,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
 
-  if (is_hybrid_op) {
+  if (IsHybridOp(input, input_to_output_weights)) {
     // Allocate temporary tensors to store quantized values of input,
     // activation_state and cell_state tensors.
     node->temporaries->data[kInputQuantized] =
@@ -96,9 +96,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, output, output_size_array));
 
-  const bool is_hybrid =
-      input->type == kTfLiteFloat32 && (input_weights->type == kTfLiteUInt8 ||
-                                        input_weights->type == kTfLiteInt8);
+  const bool is_hybrid = IsHybridOp(input, input_weights);
 
   // Allocate temporary tensors to store quantized values of input and
   // hidden_state tensors.
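
The refactor is intended to be behavior-preserving, so the new predicate can be checked against one of the inline expressions it replaces. A hypothetical test sketch follows (not part of this commit); it assumes the gtest setup used elsewhere in TF Lite and only exercises the type fields of stack-allocated tensors.

#include <initializer_list>

#include <gtest/gtest.h>

#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace {

TEST(IsHybridOpTest, MatchesFormerInlineCheck) {
  TfLiteTensor input = {};
  TfLiteTensor weights = {};
  for (TfLiteType in_type : {kTfLiteFloat32, kTfLiteUInt8, kTfLiteInt8}) {
    for (TfLiteType w_type : {kTfLiteFloat32, kTfLiteUInt8, kTfLiteInt8}) {
      input.type = in_type;
      weights.type = w_type;
      // The expression each kernel used before this commit.
      const bool inline_check =
          in_type == kTfLiteFloat32 &&
          (w_type == kTfLiteUInt8 || w_type == kTfLiteInt8);
      EXPECT_EQ(inline_check, IsHybridOp(&input, &weights));
    }
  }
}

}  // namespace
}  // namespace tflite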