Implementation of comparison operations
PiperOrigin-RevId: 262861510
This commit is contained in:
parent
92b7212e54
commit
3fef7240e3
@ -15,6 +15,7 @@ cc_library(
|
||||
name = "micro_ops",
|
||||
srcs = [
|
||||
"arg_min_max.cc",
|
||||
"comparisons.cc",
|
||||
"conv.cc",
|
||||
"depthwise_conv.cc",
|
||||
"elementwise.cc",
|
||||
@ -62,6 +63,7 @@ cc_library(
|
||||
name = "portable_optimized_micro_ops",
|
||||
srcs = [
|
||||
"arg_min_max.cc",
|
||||
"comparisons.cc",
|
||||
"conv.cc",
|
||||
"elementwise.cc",
|
||||
"floor.cc",
|
||||
@ -260,6 +262,20 @@ tflite_micro_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tflite_micro_cc_test(
|
||||
name = "comparisons_test",
|
||||
srcs = [
|
||||
"comparisons_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":all_ops_resolver",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/kernels:micro_utils",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "micro_utils",
|
||||
hdrs = ["micro_utils.h"],
|
||||
|
@ -39,6 +39,12 @@ TfLiteRegistration* Register_LOGICAL_OR();
|
||||
TfLiteRegistration* Register_LOGICAL_AND();
|
||||
TfLiteRegistration* Register_LOGICAL_NOT();
|
||||
TfLiteRegistration* Register_RESHAPE();
|
||||
TfLiteRegistration* Register_EQUAL();
|
||||
TfLiteRegistration* Register_NOT_EQUAL();
|
||||
TfLiteRegistration* Register_GREATER();
|
||||
TfLiteRegistration* Register_GREATER_EQUAL();
|
||||
TfLiteRegistration* Register_LESS();
|
||||
TfLiteRegistration* Register_LESS_EQUAL();
|
||||
|
||||
AllOpsResolver::AllOpsResolver() {
|
||||
AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
|
||||
@ -66,6 +72,12 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
||||
AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE());
|
||||
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
|
||||
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL());
|
||||
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
|
||||
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
|
||||
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
|
||||
AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL());
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
338
tensorflow/lite/experimental/micro/kernels/comparisons.cc
Normal file
338
tensorflow/lite/experimental/micro/kernels/comparisons.cc
Normal file
@ -0,0 +1,338 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace comparisons {
|
||||
namespace {
|
||||
|
||||
// Tensor indices for all comparison kernels: two inputs, one boolean output.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

// TODO(ruic): optimize macros below to using template functions.
// Generates EvalQuantized<opname>() for a comparison op. For uint8/int8
// inputs the raw values cannot be compared directly (each tensor has its own
// scale and zero point), so both inputs are offset by their zero points and
// rescaled with fixed-point multipliers before the reference comparison runs.
// If the input type is neither uint8 nor int8 the function is a no-op; the
// callers only dispatch here for those two types.
#define TF_LITE_QUANTIZE_COMPARISON(opname)                                    \
  template <typename input_dtype>                                             \
  void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node,         \
                             const TfLiteTensor* input1,                       \
                             const TfLiteTensor* input2, TfLiteTensor* output, \
                             bool requires_broadcast) {                        \
    if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {         \
      auto input1_offset = -input1->params.zero_point;                         \
      auto input2_offset = -input2->params.zero_point;                         \
      const int left_shift = 8;                                                \
                                                                               \
      int32 input1_multiplier;                                                 \
      int input1_shift;                                                        \
      QuantizeMultiplierSmallerThanOneExp(input1->params.scale,                \
                                          &input1_multiplier, &input1_shift);  \
      int32 input2_multiplier;                                                 \
      int input2_shift;                                                        \
      QuantizeMultiplierSmallerThanOneExp(input2->params.scale,                \
                                          &input2_multiplier, &input2_shift);  \
                                                                               \
      ComparisonParams op_params;                                              \
      op_params.left_shift = left_shift;                                       \
      op_params.input1_offset = input1_offset;                                 \
      op_params.input1_multiplier = input1_multiplier;                         \
      op_params.input1_shift = input1_shift;                                   \
      op_params.input2_offset = input2_offset;                                 \
      op_params.input2_multiplier = input2_multiplier;                         \
      op_params.input2_shift = input2_shift;                                   \
      if (requires_broadcast) {                                                \
        reference_ops::Broadcast4DSlow##opname##WithScaling(                   \
            op_params, GetTensorShape(input1),                                 \
            GetTensorData<input_dtype>(input1), GetTensorShape(input2),        \
            GetTensorData<input_dtype>(input2), GetTensorShape(output),        \
            GetTensorData<bool>(output));                                      \
      } else {                                                                 \
        reference_ops::opname##WithScaling(                                    \
            op_params, GetTensorShape(input1),                                 \
            GetTensorData<input_dtype>(input1), GetTensorShape(input2),        \
            GetTensorData<input_dtype>(input2), GetTensorShape(output),        \
            GetTensorData<bool>(output));                                      \
      }                                                                        \
    }                                                                          \
  }
// Instantiate one quantized evaluator per comparison op, then drop the macro.
TF_LITE_QUANTIZE_COMPARISON(Equal);
TF_LITE_QUANTIZE_COMPARISON(NotEqual);
TF_LITE_QUANTIZE_COMPARISON(Greater);
TF_LITE_QUANTIZE_COMPARISON(GreaterEqual);
TF_LITE_QUANTIZE_COMPARISON(Less);
TF_LITE_QUANTIZE_COMPARISON(LessEqual);
#undef TF_LITE_QUANTIZE_COMPARISON
||||
|
||||
// Dispatches one non-quantized comparison. Expands to a block that runs
// either the broadcasting or the same-shape reference implementation,
// depending on `requires_broadcast`. The expansion site must have locals
// named `input1`, `input2` and `output` in scope.
#define TF_LITE_COMPARISON(type, opname, requires_broadcast)                  \
  {                                                                           \
    /* op_params is unused by the NoScaling kernels but required by their  */ \
    /* signatures, so a default-constructed instance is passed through.    */ \
    ComparisonParams op_params;                                               \
    requires_broadcast                                                        \
        ? reference_ops::Broadcast4DSlow##opname##NoScaling(                  \
              op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
              GetTensorShape(input2), GetTensorData<type>(input2),            \
              GetTensorShape(output), GetTensorData<bool>(output))            \
        : reference_ops::opname##NoScaling(                                   \
              op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
              GetTensorShape(input2), GetTensorData<type>(input2),            \
              GetTensorShape(output), GetTensorData<bool>(output));           \
  }
|
||||
|
||||
// Invoke handler for the EQUAL op: writes elementwise (input1 == input2)
// into the boolean output tensor, broadcasting when the shapes differ.
// Returns kTfLiteError for unsupported input types.
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const bool requires_broadcast = !HaveSameShapes(input1, input2);
  // Both inputs are assumed to share input1's type.
  const TfLiteType input_type = input1->type;
  if (input_type == kTfLiteBool) {
    TF_LITE_COMPARISON(bool, Equal, requires_broadcast);
  } else if (input_type == kTfLiteFloat32) {
    TF_LITE_COMPARISON(float, Equal, requires_broadcast);
  } else if (input_type == kTfLiteInt32) {
    TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast);
  } else if (input_type == kTfLiteInt64) {
    TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast);
  } else if (input_type == kTfLiteUInt8) {
    // Quantized paths rescale both inputs before comparing.
    EvalQuantizedEqual<uint8_t>(context, node, input1, input2, output,
                                requires_broadcast);
  } else if (input_type == kTfLiteInt8) {
    EvalQuantizedEqual<int8_t>(context, node, input1, input2, output,
                               requires_broadcast);
  } else {
    context->ReportError(
        context, "Does not support type %d, requires bool|float|int|uint8",
        input_type);
    return kTfLiteError;
  }
  return kTfLiteOk;
}
|
||||
|
||||
// TODO(renjieliu): Refactor the logic to avoid duplications.
|
||||
TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
|
||||
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
bool requires_broadcast = !HaveSameShapes(input1, input2);
|
||||
switch (input1->type) {
|
||||
case kTfLiteBool:
|
||||
TF_LITE_COMPARISON(bool, NotEqual, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteFloat32:
|
||||
TF_LITE_COMPARISON(float, NotEqual, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
EvalQuantizedNotEqual<uint8_t>(context, node, input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
EvalQuantizedNotEqual<int8_t>(context, node, input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Does not support type %d, requires bool|float|int|uint8",
|
||||
input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Invoke handler for the GREATER op: writes elementwise (input1 > input2)
// into the boolean output tensor, broadcasting when the shapes differ.
// Unlike EQUAL/NOT_EQUAL, bool inputs are not supported.
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const bool requires_broadcast = !HaveSameShapes(input1, input2);
  const TfLiteType input_type = input1->type;
  if (input_type == kTfLiteFloat32) {
    TF_LITE_COMPARISON(float, Greater, requires_broadcast);
  } else if (input_type == kTfLiteInt32) {
    TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast);
  } else if (input_type == kTfLiteInt64) {
    TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast);
  } else if (input_type == kTfLiteUInt8) {
    // Quantized paths rescale both inputs before comparing.
    EvalQuantizedGreater<uint8_t>(context, node, input1, input2, output,
                                  requires_broadcast);
  } else if (input_type == kTfLiteInt8) {
    EvalQuantizedGreater<int8_t>(context, node, input1, input2, output,
                                 requires_broadcast);
  } else {
    context->ReportError(context,
                         "Does not support type %d, requires float|int|uint8",
                         input_type);
    return kTfLiteError;
  }
  return kTfLiteOk;
}
|
||||
|
||||
// Invoke handler for the GREATER_EQUAL op: writes elementwise
// (input1 >= input2) into the boolean output tensor, broadcasting when the
// shapes differ. Bool inputs are not supported.
TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const bool requires_broadcast = !HaveSameShapes(input1, input2);
  const TfLiteType input_type = input1->type;
  if (input_type == kTfLiteFloat32) {
    TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast);
  } else if (input_type == kTfLiteInt32) {
    TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast);
  } else if (input_type == kTfLiteInt64) {
    TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast);
  } else if (input_type == kTfLiteUInt8) {
    // Quantized paths rescale both inputs before comparing.
    EvalQuantizedGreaterEqual<uint8_t>(context, node, input1, input2, output,
                                       requires_broadcast);
  } else if (input_type == kTfLiteInt8) {
    EvalQuantizedGreaterEqual<int8_t>(context, node, input1, input2, output,
                                      requires_broadcast);
  } else {
    context->ReportError(context,
                         "Does not support type %d, requires float|int|uint8",
                         input_type);
    return kTfLiteError;
  }
  return kTfLiteOk;
}
|
||||
|
||||
// Invoke handler for the LESS op: writes elementwise (input1 < input2)
// into the boolean output tensor, broadcasting when the shapes differ.
// Bool inputs are not supported.
TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const bool requires_broadcast = !HaveSameShapes(input1, input2);
  const TfLiteType input_type = input1->type;
  if (input_type == kTfLiteFloat32) {
    TF_LITE_COMPARISON(float, Less, requires_broadcast);
  } else if (input_type == kTfLiteInt32) {
    TF_LITE_COMPARISON(int32_t, Less, requires_broadcast);
  } else if (input_type == kTfLiteInt64) {
    TF_LITE_COMPARISON(int64_t, Less, requires_broadcast);
  } else if (input_type == kTfLiteUInt8) {
    // Quantized paths rescale both inputs before comparing.
    EvalQuantizedLess<uint8_t>(context, node, input1, input2, output,
                               requires_broadcast);
  } else if (input_type == kTfLiteInt8) {
    EvalQuantizedLess<int8_t>(context, node, input1, input2, output,
                              requires_broadcast);
  } else {
    context->ReportError(context,
                         "Does not support type %d, requires float|int|uint8",
                         input_type);
    return kTfLiteError;
  }
  return kTfLiteOk;
}
|
||||
|
||||
// Invoke handler for the LESS_EQUAL op: writes elementwise
// (input1 <= input2) into the boolean output tensor, broadcasting when the
// shapes differ. Bool inputs are not supported.
TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  const bool requires_broadcast = !HaveSameShapes(input1, input2);
  const TfLiteType input_type = input1->type;
  if (input_type == kTfLiteFloat32) {
    TF_LITE_COMPARISON(float, LessEqual, requires_broadcast);
  } else if (input_type == kTfLiteInt32) {
    TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast);
  } else if (input_type == kTfLiteInt64) {
    TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast);
  } else if (input_type == kTfLiteUInt8) {
    // Quantized paths rescale both inputs before comparing.
    EvalQuantizedLessEqual<uint8_t>(context, node, input1, input2, output,
                                    requires_broadcast);
  } else if (input_type == kTfLiteInt8) {
    EvalQuantizedLessEqual<int8_t>(context, node, input1, input2, output,
                                   requires_broadcast);
  } else {
    context->ReportError(context,
                         "Does not support type %d, requires float|int|uint8",
                         input_type);
    return kTfLiteError;
  }
  return kTfLiteOk;
}
|
||||
|
||||
} // namespace
|
||||
} // namespace comparisons
|
||||
|
||||
// Registration entry points for the comparison kernels. The comparison ops
// keep no per-node state and need no prepare step, so each registration
// populates only the invoke callback; init/free/prepare stay null.

TfLiteRegistration* Register_EQUAL() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 /*prepare=*/nullptr,
                                 /*invoke=*/comparisons::EqualEval};
  return &r;
}

TfLiteRegistration* Register_NOT_EQUAL() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 /*prepare=*/nullptr,
                                 /*invoke=*/comparisons::NotEqualEval};
  return &r;
}

TfLiteRegistration* Register_GREATER() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 /*prepare=*/nullptr,
                                 /*invoke=*/comparisons::GreaterEval};
  return &r;
}

TfLiteRegistration* Register_GREATER_EQUAL() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 /*prepare=*/nullptr,
                                 /*invoke=*/comparisons::GreaterEqualEval};
  return &r;
}

TfLiteRegistration* Register_LESS() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 /*prepare=*/nullptr,
                                 /*invoke=*/comparisons::LessEval};
  return &r;
}

TfLiteRegistration* Register_LESS_EQUAL() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 /*prepare=*/nullptr,
                                 /*invoke=*/comparisons::LessEqualEval};
  return &r;
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
1100
tensorflow/lite/experimental/micro/kernels/comparisons_test.cc
Normal file
1100
tensorflow/lite/experimental/micro/kernels/comparisons_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_TEST_UTILS_H_
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
#include <limits>
|
||||
|
||||
@ -72,6 +73,11 @@ inline uint8_t F2Q(const float value, const float min, const float max) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Converts a float value into a signed eight-bit quantized value.
|
||||
inline int8_t F2QS(const float value, const float min, const float max) {
|
||||
return F2Q(value, min, max) + std::numeric_limits<int8_t>::min();
|
||||
}
|
||||
|
||||
// Converts a float value into a signed thirty-two-bit quantized value.
|
||||
inline int32_t F2Q32(const float value, const float min, const float max) {
|
||||
return static_cast<int32_t>((value - ZeroPointFromMinMax<int32_t>(min, max)) /
|
||||
@ -123,6 +129,25 @@ inline TfLiteTensor CreateFloatTensor(std::initializer_list<float> data,
|
||||
return CreateFloatTensor(data.begin(), dims, name);
|
||||
}
|
||||
|
||||
// Wraps caller-owned int32 data in a TfLiteTensor for use in kernel tests.
// Neither `data` nor `dims` is copied: the caller keeps ownership and must
// keep both alive for the lifetime of the returned tensor.
inline TfLiteTensor CreateInt32Tensor(const int32_t* data, TfLiteIntArray* dims,
                                      const char* name) {
  TfLiteTensor result;
  result.type = kTfLiteInt32;
  // Tests pass const data; TfLiteTensor stores a mutable pointer, so cast.
  result.data.i32 = const_cast<int32_t*>(data);
  result.dims = dims;
  // No quantization parameters for a plain int32 tensor.
  result.params = {};
  // kTfLiteMemNone: the interpreter must not try to free this buffer.
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(int32_t);
  result.allocation = nullptr;
  result.name = name;
  return result;
}
|
||||
|
||||
// Convenience overload: builds an int32 tensor directly over the storage of
// an initializer list. NOTE(review): the list's backing array must outlive
// the returned tensor — confirm call sites keep it alive for the test's span.
inline TfLiteTensor CreateInt32Tensor(std::initializer_list<int32_t> data,
                                      TfLiteIntArray* dims, const char* name) {
  return CreateInt32Tensor(data.begin(), dims, name);
}
|
||||
|
||||
inline TfLiteTensor CreateBoolTensor(const bool* data, TfLiteIntArray* dims,
|
||||
const char* name) {
|
||||
TfLiteTensor result;
|
||||
@ -166,6 +191,29 @@ inline TfLiteTensor CreateQuantizedTensor(std::initializer_list<uint8_t> data,
|
||||
return CreateQuantizedTensor(data.begin(), dims, name, min, max);
|
||||
}
|
||||
|
||||
// Wraps caller-owned int8 data in a quantized TfLiteTensor whose scale and
// zero point are derived from the [min, max] float range. Neither `data` nor
// `dims` is copied; the caller keeps ownership.
inline TfLiteTensor CreateQuantizedInt8Tensor(const int8_t* data,
                                              TfLiteIntArray* dims,
                                              const char* name, float min,
                                              float max) {
  TfLiteTensor result;
  result.type = kTfLiteInt8;
  // Tests pass const data; TfLiteTensor stores a mutable pointer, so cast.
  result.data.int8 = const_cast<int8_t*>(data);
  result.dims = dims;
  // Scale/zero-point chosen so that [min, max] maps onto the int8 range.
  result.params = {ScaleFromMinMax<int8_t>(min, max),
                   ZeroPointFromMinMax<int8_t>(min, max)};
  // kTfLiteMemNone: the interpreter must not try to free this buffer.
  result.allocation_type = kTfLiteMemNone;
  result.bytes = ElementCount(*dims) * sizeof(int8_t);
  result.allocation = nullptr;
  result.name = name;
  return result;
}
|
||||
|
||||
// Convenience overload: builds a quantized int8 tensor directly over the
// storage of an initializer list. NOTE(review): the list's backing array
// must outlive the returned tensor — confirm call sites keep it alive.
inline TfLiteTensor CreateQuantizedInt8Tensor(
    std::initializer_list<int8_t> data, TfLiteIntArray* dims, const char* name,
    float min, float max) {
  return CreateQuantizedInt8Tensor(data.begin(), dims, name, min, max);
}
|
||||
|
||||
inline TfLiteTensor CreateQuantized32Tensor(const int32_t* data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name, float min,
|
||||
|
@ -108,6 +108,7 @@ tensorflow/lite/kernels/internal/common.h \
|
||||
tensorflow/lite/kernels/internal/compatibility.h \
|
||||
tensorflow/lite/kernels/internal/optimized/neon_check.h \
|
||||
tensorflow/lite/kernels/internal/reference/binary_function.h \
|
||||
tensorflow/lite/kernels/internal/reference/comparisons.h \
|
||||
tensorflow/lite/kernels/internal/reference/conv.h \
|
||||
tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \
|
||||
tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \
|
||||
|
@ -16,8 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
|
||||
|
||||
#include "profiling/instrumentation.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
@ -112,7 +111,6 @@ inline void BroadcastComparison4DSlowImpl(
|
||||
const RuntimeShape& unextended_input1_shape, const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape, const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
@ -155,7 +153,6 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const RuntimeShape& unextended_input1_shape, const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape, const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
@ -204,7 +201,6 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const RuntimeShape& input1_shape, const float* input1_data, \
|
||||
const RuntimeShape& input2_shape, const float* input2_data, \
|
||||
const RuntimeShape& output_shape, bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name); \
|
||||
Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
|
||||
input2_data, output_shape, output_data); \
|
||||
} \
|
||||
@ -214,7 +210,6 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \
|
||||
ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, output_shape, \
|
||||
output_data); \
|
||||
@ -225,7 +220,6 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \
|
||||
ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
@ -236,7 +230,6 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
|
||||
BroadcastComparison4DSlowImpl<T, name##Fn>( \
|
||||
op_params, input1_shape, input1_data, input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
@ -246,7 +239,6 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const float* input1_data, const RuntimeShape& input2_shape, \
|
||||
const float* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \
|
||||
BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
@ -257,7 +249,6 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \
|
||||
BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
|
||||
op_params, input1_shape, input1_data, input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
|
Loading…
x
Reference in New Issue
Block a user