Merge pull request #42543 from freedomtan:nnapi_delegate_13_hacks
PiperOrigin-RevId: 331184445 Change-Id: Ifb1d311a3a75159a8d198686e0307086bd76c316
This commit is contained in:
commit
d9c0b26e9e
@ -164,6 +164,48 @@ bool IsQuantized(TfLiteType type) {
|
||||
}
|
||||
}
|
||||
|
||||
bool IsInt32(TfLiteType type) {
|
||||
switch (type) {
|
||||
case kTfLiteInt32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsFloatOrQuantized(TfLiteType type) {
|
||||
switch (type) {
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsFloatOrInt32(TfLiteType type) {
|
||||
switch (type) {
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteInt32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsFloatQuantizedOrInt32(TfLiteType type) {
|
||||
switch (type) {
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteInt32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsScalarInputSupported(int builtin_code) {
|
||||
switch (builtin_code) {
|
||||
case kTfLiteBuiltinAdd:
|
||||
@ -1530,11 +1572,30 @@ bool ExpectIsFloatOrQuant8Operator(const TfLiteContext* context,
|
||||
const TfLiteNode* node,
|
||||
OpValidationContext* val_ctx) {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
return Expect(IsFloat(input_type) || IsQuantized(input_type),
|
||||
return Expect(IsFloatOrQuantized(input_type),
|
||||
NNAPIValidationFailureType::kUnsupportedInputType,
|
||||
"Input should be Float or Quant8", val_ctx);
|
||||
}
|
||||
|
||||
bool ExpectIsFloatOrInt32Operator(const TfLiteContext* context,
|
||||
const TfLiteNode* node,
|
||||
OpValidationContext* val_ctx) {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
return Expect(IsFloatOrInt32(input_type),
|
||||
NNAPIValidationFailureType::kUnsupportedInputType,
|
||||
"Input should be Float or Int32", val_ctx);
|
||||
}
|
||||
|
||||
bool ExpectIsFloatQuant8OrInt32Operator(const TfLiteContext* context,
|
||||
const TfLiteNode* node,
|
||||
OpValidationContext* val_ctx) {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
return Expect(IsFloatQuantizedOrInt32(input_type),
|
||||
NNAPIValidationFailureType::kUnsupportedInputType,
|
||||
"Input should be Float, Quant8, or Int32", val_ctx);
|
||||
}
|
||||
|
||||
// When using NN API version 1.0 or 1.1, the condition below must be true for
|
||||
// When using NN API version 1.0 or 1.1, the condition below must be true for
|
||||
// quantized versions of the following ops:
|
||||
// * CONV_2D
|
||||
@ -1571,7 +1632,17 @@ bool NNAPIDelegateKernel::Validate(
|
||||
switch (builtin_code) {
|
||||
case kTfLiteBuiltinAdd: {
|
||||
ExpectMaxOpVersion(version, 2, &val_ctx);
|
||||
ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
|
||||
if (android_sdk_version >= kMinSdkVersionForNNAPI13) {
|
||||
ExpectIsFloatQuant8OrInt32Operator(context, node, &val_ctx);
|
||||
if (IsInt32(context->tensors[node->inputs->data[0]].type)) {
|
||||
Expect(reinterpret_cast<TfLiteAddParams*>(node->builtin_data)
|
||||
->activation == kTfLiteActNone,
|
||||
NNAPIValidationFailureType::kNoActivationExpected,
|
||||
"No activation function supported", &val_ctx);
|
||||
}
|
||||
} else {
|
||||
ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinArgMax:
|
||||
case kTfLiteBuiltinArgMin: {
|
||||
@ -1616,7 +1687,17 @@ bool NNAPIDelegateKernel::Validate(
|
||||
} break;
|
||||
case kTfLiteBuiltinMul: {
|
||||
ExpectMaxOpVersion(version, 2, &val_ctx);
|
||||
ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
|
||||
if (android_sdk_version >= kMinSdkVersionForNNAPI13) {
|
||||
ExpectIsFloatQuant8OrInt32Operator(context, node, &val_ctx);
|
||||
if (IsInt32(context->tensors[node->inputs->data[0]].type)) {
|
||||
Expect(reinterpret_cast<TfLiteMulParams*>(node->builtin_data)
|
||||
->activation == kTfLiteActNone,
|
||||
NNAPIValidationFailureType::kNoActivationExpected,
|
||||
"No activation function supported", &val_ctx);
|
||||
}
|
||||
} else {
|
||||
ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinAveragePool2d: {
|
||||
ExpectMaxOpVersion(version, 2, &val_ctx);
|
||||
@ -1971,9 +2052,17 @@ bool NNAPIDelegateKernel::Validate(
|
||||
Expect((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
|
||||
IsFloat(input_type)) ||
|
||||
(android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
IsQuantized(input_type)),
|
||||
IsQuantized(input_type)) ||
|
||||
(android_sdk_version >= kMinSdkVersionForNNAPI13 &&
|
||||
IsInt32(input_type)),
|
||||
NNAPIValidationFailureType::kUnsupportedInputType,
|
||||
"NNAPI only support float sub.", &val_ctx);
|
||||
if (IsInt32(input_type)) {
|
||||
Expect(reinterpret_cast<TfLiteSubParams*>(node->builtin_data)
|
||||
->activation == kTfLiteActNone,
|
||||
NNAPIValidationFailureType::kNoActivationExpected,
|
||||
"No activation function supported", &val_ctx);
|
||||
}
|
||||
const int input0_rank =
|
||||
context->tensors[node->inputs->data[0]].dims->size;
|
||||
const int input1_rank =
|
||||
|
Loading…
Reference in New Issue
Block a user