Merge pull request #26970 from ANSHUMAN87:FC-code-refactor

PiperOrigin-RevId: 240185593
This commit is contained in:
TensorFlower Gardener 2019-03-25 11:53:40 -07:00
commit 9bb938c6f9

View File

@ -277,17 +277,6 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk; return kTfLiteOk;
} }
#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
if (params->activation == kTfLiteActNone) { \
macro_name(target_namespace, kNone); \
} \
if (params->activation == kTfLiteActRelu) { \
macro_name(target_namespace, kRelu); \
} \
if (params->activation == kTfLiteActRelu6) { \
macro_name(target_namespace, kRelu6); \
}
namespace { namespace {
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
const TfLiteTensor* filter, const TfLiteTensor* bias, const TfLiteTensor* filter, const TfLiteTensor* bias,
@ -343,38 +332,29 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1);
return EvalHybrid(context, node, params, data, input, filter, bias, return EvalHybrid(context, node, params, data, input, filter, bias,
input_quantized, scaling_factors, output); input_quantized, scaling_factors, output);
} else if (kernel_type == kReference) {
switch (output->type) {
case kTfLiteUInt8:
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
break;
case kTfLiteInt8:
FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
break;
case kTfLiteInt16:
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
break;
default:
context->ReportError(
context,
"Quantized FullyConnected expects output data type uint8 or int16");
return kTfLiteError;
}
} else { } else {
switch (output->type) { switch (output->type) {
case kTfLiteUInt8: case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
} else {
TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t); TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
}
break; break;
case kTfLiteInt8: case kTfLiteInt8:
FullyConnectedInt8(data, input, filter, bias, output, gemm_context); FullyConnectedInt8(data, input, filter, bias, output, gemm_context);
break; break;
case kTfLiteInt16: case kTfLiteInt16:
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
} else {
TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t); TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
}
break; break;
default: default:
context->ReportError( context->ReportError(context,
context, "Quantized FullyConnected expects output data "
"Quantized FullyConnected expects output data type uint8 or int16"); "type uint8, int8 or int16");
return kTfLiteError; return kTfLiteError;
} }
} }
@ -457,8 +437,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk; return kTfLiteOk;
} }
#undef TF_LITE_MACRO_DISPATCH
template <KernelType kernel_type> template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = auto* params =
@ -501,8 +479,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError; return kTfLiteError;
} }
default: default:
context->ReportError(context, "Type %d not currently supported.", context->ReportError(context,
filter->type); "Filter data type %s currently not supported.",
TfLiteTypeGetName(filter->type));
return kTfLiteError; return kTfLiteError;
} }
return kTfLiteOk; return kTfLiteOk;