TFLite binary size optimization by removing templates

PiperOrigin-RevId: 310071337
Change-Id: I54b08317e854aadd187dddab6d9db27ac01b921f
This commit is contained in:
Hyeonjong Ryu 2020-05-05 19:59:15 -07:00 committed by TensorFlower Gardener
parent eaaf5b6b43
commit 765dceef7f
2 changed files with 14 additions and 14 deletions

View File

@ -120,18 +120,18 @@ void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2,
GetTensorShape(output), GetTensorData<bool>(output));
}
template <bool (*opname)(const StringRef&, const StringRef&)>
void ComparisonString(const TfLiteTensor* input1, const TfLiteTensor* input2,
void ComparisonString(bool (*opname)(const StringRef&, const StringRef&),
const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output, bool requires_broadcast) {
bool* output_data = GetTensorData<bool>(output);
if (requires_broadcast) {
reference_ops::BroadcastComparison4DSlowStringImpl<opname>(
GetTensorShape(input1), input1, GetTensorShape(input2), input2,
reference_ops::BroadcastComparison4DSlowStringImpl(
opname, GetTensorShape(input1), input1, GetTensorShape(input2), input2,
GetTensorShape(output), output_data);
} else {
reference_ops::ComparisonStringImpl<opname>(
GetTensorShape(input1), input1, GetTensorShape(input2), input2,
GetTensorShape(output), output_data);
reference_ops::ComparisonStringImpl(opname, GetTensorShape(input1), input1,
GetTensorShape(input2), input2,
GetTensorShape(output), output_data);
}
}
@ -166,8 +166,8 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
input1, input2, output, requires_broadcast);
break;
case kTfLiteString:
ComparisonString<reference_ops::StringRefEqualFn>(input1, input2, output,
requires_broadcast);
ComparisonString(reference_ops::StringRefEqualFn, input1, input2, output,
requires_broadcast);
break;
default:
context->ReportError(
@ -210,8 +210,8 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
input1, input2, output, requires_broadcast);
break;
case kTfLiteString:
ComparisonString<reference_ops::StringRefNotEqualFn>(
input1, input2, output, requires_broadcast);
ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2,
output, requires_broadcast);
break;
default:
context->ReportError(

View File

@ -78,8 +78,8 @@ inline void ComparisonImpl(
}
}
template <bool (*F)(const StringRef&, const StringRef&)>
inline void ComparisonStringImpl(const RuntimeShape& input1_shape,
inline void ComparisonStringImpl(bool (*F)(const StringRef&, const StringRef&),
const RuntimeShape& input1_shape,
const TfLiteTensor* input1,
const RuntimeShape& input2_shape,
const TfLiteTensor* input2,
@ -180,8 +180,8 @@ inline void BroadcastComparison4DSlowImpl(
}
}
template <bool (*F)(const StringRef&, const StringRef&)>
inline void BroadcastComparison4DSlowStringImpl(
bool (*F)(const StringRef&, const StringRef&),
const RuntimeShape& unextended_input1_shape, const TfLiteTensor* input1,
const RuntimeShape& unextended_input2_shape, const TfLiteTensor* input2,
const RuntimeShape& unextended_output_shape, bool* output_data) {