Merge pull request #42543 from freedomtan:nnapi_delegate_13_hacks

PiperOrigin-RevId: 331184445
Change-Id: Ifb1d311a3a75159a8d198686e0307086bd76c316
This commit is contained in:
TensorFlower Gardener 2020-09-11 11:35:37 -07:00
commit d9c0b26e9e

View File

@ -164,6 +164,48 @@ bool IsQuantized(TfLiteType type) {
}
}
bool IsInt32(TfLiteType type) {
switch (type) {
case kTfLiteInt32:
return true;
default:
return false;
}
}
bool IsFloatOrQuantized(TfLiteType type) {
switch (type) {
case kTfLiteFloat32:
case kTfLiteUInt8:
case kTfLiteInt8:
return true;
default:
return false;
}
}
bool IsFloatOrInt32(TfLiteType type) {
switch (type) {
case kTfLiteFloat32:
case kTfLiteInt32:
return true;
default:
return false;
}
}
bool IsFloatQuantizedOrInt32(TfLiteType type) {
switch (type) {
case kTfLiteFloat32:
case kTfLiteUInt8:
case kTfLiteInt8:
case kTfLiteInt32:
return true;
default:
return false;
}
}
bool IsScalarInputSupported(int builtin_code) {
switch (builtin_code) {
case kTfLiteBuiltinAdd:
@ -1530,11 +1572,30 @@ bool ExpectIsFloatOrQuant8Operator(const TfLiteContext* context,
const TfLiteNode* node,
OpValidationContext* val_ctx) {
const auto input_type = context->tensors[node->inputs->data[0]].type;
return Expect(IsFloat(input_type) || IsQuantized(input_type),
return Expect(IsFloatOrQuantized(input_type),
NNAPIValidationFailureType::kUnsupportedInputType,
"Input should be Float or Quant8", val_ctx);
}
bool ExpectIsFloatOrInt32Operator(const TfLiteContext* context,
const TfLiteNode* node,
OpValidationContext* val_ctx) {
const auto input_type = context->tensors[node->inputs->data[0]].type;
return Expect(IsFloatOrInt32(input_type),
NNAPIValidationFailureType::kUnsupportedInputType,
"Input should be Float or Int32", val_ctx);
}
bool ExpectIsFloatQuant8OrInt32Operator(const TfLiteContext* context,
const TfLiteNode* node,
OpValidationContext* val_ctx) {
const auto input_type = context->tensors[node->inputs->data[0]].type;
return Expect(IsFloatQuantizedOrInt32(input_type),
NNAPIValidationFailureType::kUnsupportedInputType,
"Input should be Float, Quant8, or Int32", val_ctx);
}
// When using NN API version 1.0 or 1.1, the condition below must be true for
// When using NN API version 1.0 or 1.1, the condition below must be true for
// quantized versions of the following ops:
// * CONV_2D
@ -1571,7 +1632,17 @@ bool NNAPIDelegateKernel::Validate(
switch (builtin_code) {
case kTfLiteBuiltinAdd: {
ExpectMaxOpVersion(version, 2, &val_ctx);
if (android_sdk_version >= kMinSdkVersionForNNAPI13) {
ExpectIsFloatQuant8OrInt32Operator(context, node, &val_ctx);
if (IsInt32(context->tensors[node->inputs->data[0]].type)) {
Expect(reinterpret_cast<TfLiteAddParams*>(node->builtin_data)
->activation == kTfLiteActNone,
NNAPIValidationFailureType::kNoActivationExpected,
"No activation function supported", &val_ctx);
}
} else {
ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
}
} break;
case kTfLiteBuiltinArgMax:
case kTfLiteBuiltinArgMin: {
@ -1616,7 +1687,17 @@ bool NNAPIDelegateKernel::Validate(
} break;
case kTfLiteBuiltinMul: {
ExpectMaxOpVersion(version, 2, &val_ctx);
if (android_sdk_version >= kMinSdkVersionForNNAPI13) {
ExpectIsFloatQuant8OrInt32Operator(context, node, &val_ctx);
if (IsInt32(context->tensors[node->inputs->data[0]].type)) {
Expect(reinterpret_cast<TfLiteMulParams*>(node->builtin_data)
->activation == kTfLiteActNone,
NNAPIValidationFailureType::kNoActivationExpected,
"No activation function supported", &val_ctx);
}
} else {
ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
}
} break;
case kTfLiteBuiltinAveragePool2d: {
ExpectMaxOpVersion(version, 2, &val_ctx);
@ -1971,9 +2052,17 @@ bool NNAPIDelegateKernel::Validate(
Expect((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
IsFloat(input_type)) ||
(android_sdk_version >= kMinSdkVersionForNNAPI12 &&
IsQuantized(input_type)),
IsQuantized(input_type)) ||
(android_sdk_version >= kMinSdkVersionForNNAPI13 &&
IsInt32(input_type)),
NNAPIValidationFailureType::kUnsupportedInputType,
"NNAPI only support float sub.", &val_ctx);
if (IsInt32(input_type)) {
Expect(reinterpret_cast<TfLiteSubParams*>(node->builtin_data)
->activation == kTfLiteActNone,
NNAPIValidationFailureType::kNoActivationExpected,
"No activation function supported", &val_ctx);
}
const int input0_rank =
context->tensors[node->inputs->data[0]].dims->size;
const int input1_rank =