Merge pull request #26970 from ANSHUMAN87:FC-code-refactor
PiperOrigin-RevId: 240185593
This commit is contained in:
commit
9bb938c6f9
@ -277,17 +277,6 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
|
|
||||||
if (params->activation == kTfLiteActNone) { \
|
|
||||||
macro_name(target_namespace, kNone); \
|
|
||||||
} \
|
|
||||||
if (params->activation == kTfLiteActRelu) { \
|
|
||||||
macro_name(target_namespace, kRelu); \
|
|
||||||
} \
|
|
||||||
if (params->activation == kTfLiteActRelu6) { \
|
|
||||||
macro_name(target_namespace, kRelu6); \
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
|
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
|
||||||
const TfLiteTensor* filter, const TfLiteTensor* bias,
|
const TfLiteTensor* filter, const TfLiteTensor* bias,
|
||||||
@ -343,38 +332,29 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
|||||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
|
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
|
||||||
return EvalHybrid(context, node, params, data, input, filter, bias,
|
return EvalHybrid(context, node, params, data, input, filter, bias,
|
||||||
input_quantized, scaling_factors, output);
|
input_quantized, scaling_factors, output);
|
||||||
} else if (kernel_type == kReference) {
|
|
||||||
switch (output->type) {
|
|
||||||
case kTfLiteUInt8:
|
|
||||||
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
|
|
||||||
break;
|
|
||||||
case kTfLiteInt8:
|
|
||||||
FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
|
|
||||||
break;
|
|
||||||
case kTfLiteInt16:
|
|
||||||
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
context->ReportError(
|
|
||||||
context,
|
|
||||||
"Quantized FullyConnected expects output data type uint8 or int16");
|
|
||||||
return kTfLiteError;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
switch (output->type) {
|
switch (output->type) {
|
||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
|
if (kernel_type == kReference) {
|
||||||
|
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
|
||||||
|
} else {
|
||||||
TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
|
TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt8:
|
case kTfLiteInt8:
|
||||||
FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
|
FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt16:
|
case kTfLiteInt16:
|
||||||
|
if (kernel_type == kReference) {
|
||||||
|
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
|
||||||
|
} else {
|
||||||
TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
|
TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
context->ReportError(
|
context->ReportError(context,
|
||||||
context,
|
"Quantized FullyConnected expects output data "
|
||||||
"Quantized FullyConnected expects output data type uint8 or int16");
|
"type uint8, int8 or int16");
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -457,8 +437,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
#undef TF_LITE_MACRO_DISPATCH
|
|
||||||
|
|
||||||
template <KernelType kernel_type>
|
template <KernelType kernel_type>
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
auto* params =
|
auto* params =
|
||||||
@ -501,8 +479,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
context->ReportError(context, "Type %d not currently supported.",
|
context->ReportError(context,
|
||||||
filter->type);
|
"Filter data type %s currently not supported.",
|
||||||
|
TfLiteTypeGetName(filter->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
|
Loading…
Reference in New Issue
Block a user