TFLite binary size optimization by removing templates
PiperOrigin-RevId: 310071337 Change-Id: I54b08317e854aadd187dddab6d9db27ac01b921f
This commit is contained in:
parent
eaaf5b6b43
commit
765dceef7f
@ -120,18 +120,18 @@ void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2,
|
||||
GetTensorShape(output), GetTensorData<bool>(output));
|
||||
}
|
||||
|
||||
template <bool (*opname)(const StringRef&, const StringRef&)>
|
||||
void ComparisonString(const TfLiteTensor* input1, const TfLiteTensor* input2,
|
||||
void ComparisonString(bool (*opname)(const StringRef&, const StringRef&),
|
||||
const TfLiteTensor* input1, const TfLiteTensor* input2,
|
||||
TfLiteTensor* output, bool requires_broadcast) {
|
||||
bool* output_data = GetTensorData<bool>(output);
|
||||
if (requires_broadcast) {
|
||||
reference_ops::BroadcastComparison4DSlowStringImpl<opname>(
|
||||
GetTensorShape(input1), input1, GetTensorShape(input2), input2,
|
||||
reference_ops::BroadcastComparison4DSlowStringImpl(
|
||||
opname, GetTensorShape(input1), input1, GetTensorShape(input2), input2,
|
||||
GetTensorShape(output), output_data);
|
||||
} else {
|
||||
reference_ops::ComparisonStringImpl<opname>(
|
||||
GetTensorShape(input1), input1, GetTensorShape(input2), input2,
|
||||
GetTensorShape(output), output_data);
|
||||
reference_ops::ComparisonStringImpl(opname, GetTensorShape(input1), input1,
|
||||
GetTensorShape(input2), input2,
|
||||
GetTensorShape(output), output_data);
|
||||
}
|
||||
}
|
||||
|
||||
@ -166,8 +166,8 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
input1, input2, output, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteString:
|
||||
ComparisonString<reference_ops::StringRefEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
ComparisonString(reference_ops::StringRefEqualFn, input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
@ -210,8 +210,8 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
input1, input2, output, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteString:
|
||||
ComparisonString<reference_ops::StringRefNotEqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2,
|
||||
output, requires_broadcast);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
|
@ -78,8 +78,8 @@ inline void ComparisonImpl(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool (*F)(const StringRef&, const StringRef&)>
|
||||
inline void ComparisonStringImpl(const RuntimeShape& input1_shape,
|
||||
inline void ComparisonStringImpl(bool (*F)(const StringRef&, const StringRef&),
|
||||
const RuntimeShape& input1_shape,
|
||||
const TfLiteTensor* input1,
|
||||
const RuntimeShape& input2_shape,
|
||||
const TfLiteTensor* input2,
|
||||
@ -180,8 +180,8 @@ inline void BroadcastComparison4DSlowImpl(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool (*F)(const StringRef&, const StringRef&)>
|
||||
inline void BroadcastComparison4DSlowStringImpl(
|
||||
bool (*F)(const StringRef&, const StringRef&),
|
||||
const RuntimeShape& unextended_input1_shape, const TfLiteTensor* input1,
|
||||
const RuntimeShape& unextended_input2_shape, const TfLiteTensor* input2,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user