Merge pull request #26970 from ANSHUMAN87:FC-code-refactor
PiperOrigin-RevId: 240185593
This commit is contained in:
commit
9bb938c6f9
@ -277,17 +277,6 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
|
||||
if (params->activation == kTfLiteActNone) { \
|
||||
macro_name(target_namespace, kNone); \
|
||||
} \
|
||||
if (params->activation == kTfLiteActRelu) { \
|
||||
macro_name(target_namespace, kRelu); \
|
||||
} \
|
||||
if (params->activation == kTfLiteActRelu6) { \
|
||||
macro_name(target_namespace, kRelu6); \
|
||||
}
|
||||
|
||||
namespace {
|
||||
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter, const TfLiteTensor* bias,
|
||||
@ -343,38 +332,29 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
|
||||
return EvalHybrid(context, node, params, data, input, filter, bias,
|
||||
input_quantized, scaling_factors, output);
|
||||
} else if (kernel_type == kReference) {
|
||||
switch (output->type) {
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context,
|
||||
"Quantized FullyConnected expects output data type uint8 or int16");
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
switch (output->type) {
|
||||
case kTfLiteUInt8:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
|
||||
} else {
|
||||
TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
|
||||
} else {
|
||||
TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context,
|
||||
"Quantized FullyConnected expects output data type uint8 or int16");
|
||||
context->ReportError(context,
|
||||
"Quantized FullyConnected expects output data "
|
||||
"type uint8, int8 or int16");
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
@ -457,8 +437,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
#undef TF_LITE_MACRO_DISPATCH
|
||||
|
||||
template <KernelType kernel_type>
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params =
|
||||
@ -501,8 +479,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
default:
|
||||
context->ReportError(context, "Type %d not currently supported.",
|
||||
filter->type);
|
||||
context->ReportError(context,
|
||||
"Filter data type %s currently not supported.",
|
||||
TfLiteTypeGetName(filter->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
Loading…
Reference in New Issue
Block a user