Lite: Fully_connected Op code refactored

This commit is contained in:
ANSHUMAN TRIPATHY 2019-03-21 10:27:43 +05:30
parent 395d32f274
commit 22453e13c0

View File

@ -277,17 +277,6 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}
#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
if (params->activation == kTfLiteActNone) { \
macro_name(target_namespace, kNone); \
} \
if (params->activation == kTfLiteActRelu) { \
macro_name(target_namespace, kRelu); \
} \
if (params->activation == kTfLiteActRelu6) { \
macro_name(target_namespace, kRelu6); \
}
namespace {
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias,
@ -343,38 +332,29 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
return EvalHybrid(context, node, params, data, input, filter, bias,
input_quantized, scaling_factors, output);
} else if (kernel_type == kReference) {
switch (output->type) {
case kTfLiteUInt8:
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
break;
case kTfLiteInt8:
FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
break;
case kTfLiteInt16:
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
break;
default:
context->ReportError(
context,
"Quantized FullyConnected expects output data type uint8 or int16");
return kTfLiteError;
}
} else {
switch (output->type) {
case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
} else {
TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
}
break;
case kTfLiteInt8:
FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
break;
case kTfLiteInt16:
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
} else {
TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
}
break;
default:
context->ReportError(
context,
"Quantized FullyConnected expects output data type uint8 or int16");
context->ReportError(context,
"Quantized FullyConnected expects output data "
"type uint8, int8 or int16");
return kTfLiteError;
}
}
@ -457,8 +437,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}
#undef TF_LITE_MACRO_DISPATCH
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
@ -501,8 +479,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
default:
context->ReportError(context, "Type %d not currently supported.",
filter->type);
context->ReportError(context,
"Filter data type %s currently not supported.",
TfLiteTypeGetName(filter->type));
return kTfLiteError;
}
return kTfLiteOk;