String input support on TFLite Equal/Not_Equal op
PiperOrigin-RevId: 307754184
Change-Id: Ib46faf4aa8b37a698f6ba67ada4477cab1f1658e
This commit is contained in:
parent 79abfee5c3
commit 599c8d34ef
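Before the diff itself, a sketch of what the change enables at the API surface: once a converted model carries the string variant of EQUAL, an application packs its string inputs with tflite::DynamicBuffer and reads a boolean output. This is illustrative only and not part of the commit; the model file name, input shapes, and values are assumptions.

// Sketch (assumptions: a converted model "equal_strings.tflite" whose two
// inputs are 1-D string tensors with 3 elements each).
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"

int main() {
  auto model = tflite::FlatBufferModel::BuildFromFile("equal_strings.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  interpreter->AllocateTensors();

  // String tensors are written as a whole with DynamicBuffer rather than
  // through typed_tensor<>().
  const std::vector<std::string> lhs = {"A", "B", "C"};
  const std::vector<std::string> rhs = {"A", "X", "C"};
  for (int i = 0; i < 2; ++i) {
    tflite::DynamicBuffer buf;
    for (const std::string& s : (i == 0 ? lhs : rhs)) {
      buf.AddString(s.data(), s.size());
    }
    buf.WriteToTensor(interpreter->tensor(interpreter->inputs()[i]),
                      /*new_shape=*/nullptr);
  }

  interpreter->Invoke();
  // EQUAL produces a bool tensor: {true, false, true} for the inputs above.
  const bool* out = interpreter->typed_output_tensor<bool>(0);
  (void)out;
  return 0;
}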
@@ -1228,8 +1228,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
 
   let arguments = (
     ins
-    TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$x,
-    TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$y
+    TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$x,
+    TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$y
   );
 
   let results = (outs TFL_BoolTensor:$output);
@@ -27,7 +27,8 @@ constexpr int kInputTensor1 = 0;
 constexpr int kInputTensor2 = 1;
 constexpr int kOutputTensor = 0;
 
-TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node,
+                                     bool is_string_allowed) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
@@ -36,7 +37,9 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   // Don't support string.
-  TF_LITE_ENSURE(context, input1->type != kTfLiteString);
+  if (!is_string_allowed) {
+    TF_LITE_ENSURE(context, input1->type != kTfLiteString);
+  }
   // Currently only support tensors have the same type.
   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
   output->type = kTfLiteBool;
@@ -54,6 +57,15 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
   return context->ResizeTensor(context, output, output_size);
 }
 
+TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
+  return ComparisonPrepareCommon(context, node, false);
+}
+
+TfLiteStatus ComparisonPrepareStringAllowed(TfLiteContext* context,
+                                            TfLiteNode* node) {
+  return ComparisonPrepareCommon(context, node, true);
+}
+
 template <typename input_dtype, reference_ops::ComparisonFn<int32> opname>
 void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
                          TfLiteTensor* output, bool requires_broadcast) {
@@ -108,6 +120,21 @@ void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2,
                   GetTensorShape(output), GetTensorData<bool>(output));
 }
 
+template <bool (*opname)(const StringRef&, const StringRef&)>
+void ComparisonString(const TfLiteTensor* input1, const TfLiteTensor* input2,
+                      TfLiteTensor* output, bool requires_broadcast) {
+  bool* output_data = GetTensorData<bool>(output);
+  if (requires_broadcast) {
+    reference_ops::BroadcastComparison4DSlowStringImpl<opname>(
+        GetTensorShape(input1), input1, GetTensorShape(input2), input2,
+        GetTensorShape(output), output_data);
+  } else {
+    reference_ops::ComparisonStringImpl<opname>(
+        GetTensorShape(input1), input1, GetTensorShape(input2), input2,
+        GetTensorShape(output), output_data);
+  }
+}
+
 TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
@@ -138,9 +165,14 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
       ComparisonQuantized<int8_t, reference_ops::EqualFn>(
           input1, input2, output, requires_broadcast);
       break;
+    case kTfLiteString:
+      ComparisonString<reference_ops::StringRefEqualFn>(input1, input2, output,
+                                                        requires_broadcast);
+      break;
     default:
       context->ReportError(
-          context, "Does not support type %d, requires bool|float|int|uint8",
+          context,
+          "Does not support type %d, requires bool|float|int|uint8|string",
           input1->type);
       return kTfLiteError;
   }
@@ -177,9 +209,14 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
       ComparisonQuantized<int8_t, reference_ops::NotEqualFn>(
          input1, input2, output, requires_broadcast);
       break;
+    case kTfLiteString:
+      ComparisonString<reference_ops::StringRefNotEqualFn>(
+          input1, input2, output, requires_broadcast);
+      break;
     default:
       context->ReportError(
-          context, "Does not support type %d, requires bool|float|int|uint8",
+          context,
+          "Does not support type %d, requires bool|float|int|uint8|string",
           input1->type);
       return kTfLiteError;
   }
@@ -330,14 +367,15 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace comparisons
 
 TfLiteRegistration* Register_EQUAL() {
-  static TfLiteRegistration r = {
-      nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval};
+  static TfLiteRegistration r = {nullptr, nullptr,
+                                 comparisons::ComparisonPrepareStringAllowed,
+                                 comparisons::EqualEval};
   return &r;
 }
 
 TfLiteRegistration* Register_NOT_EQUAL() {
   static TfLiteRegistration r = {nullptr, nullptr,
-                                 comparisons::ComparisonPrepare,
+                                 comparisons::ComparisonPrepareStringAllowed,
                                  comparisons::NotEqualEval};
   return &r;
 }
@@ -125,6 +125,20 @@ TEST(ComparisonsTest, EqualInt) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
 }
 
+TEST(ComparisonsTest, EqualString) {
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+  ComparisonOpModel model({1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}, TensorType_STRING,
+                          BuiltinOperator_EQUAL);
+  model.PopulateTensor<std::string>(model.input1(), {"A", "B", "C", "D"});
+  model.PopulateTensor<std::string>(model.input2(), {"A", "C", "B", "D"});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4, 1));
+}
+
 TEST(ComparisonsTest, EqualBroadcast) {
   ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
                           BuiltinOperator_EQUAL);
@@ -148,6 +162,20 @@ TEST(ComparisonsTest, EqualBroadcastTwoD) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
 }
 
+TEST(ComparisonsTest, EqualBroadcastString) {
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+  ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_STRING,
+                          BuiltinOperator_EQUAL);
+  model.PopulateTensor<std::string>(model.input1(), {"A", "B", "A", "B"});
+  model.PopulateTensor<std::string>(model.input2(), {"A"});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, false));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
 TEST(ComparisonsTest, NotEqualBool) {
   ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_BOOL,
                           BuiltinOperator_NOT_EQUAL);
@@ -181,6 +209,20 @@ TEST(ComparisonsTest, NotEqualInt) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
 }
 
+TEST(ComparisonsTest, NotEqualString) {
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+  ComparisonOpModel model({1, 1, 1, 1, 4}, {1, 1, 1, 1, 4}, TensorType_STRING,
+                          BuiltinOperator_NOT_EQUAL);
+  model.PopulateTensor<std::string>(model.input1(), {"A", "B", "C", "D"});
+  model.PopulateTensor<std::string>(model.input2(), {"A", "C", "B", "D"});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 1, 4));
+}
+
 TEST(ComparisonsTest, NotEqualBroadcast) {
   ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
                           BuiltinOperator_NOT_EQUAL);
@@ -204,6 +246,20 @@ TEST(ComparisonsTest, NotEqualBroadcastTwoD) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
 }
 
+TEST(ComparisonsTest, NotEqualBroadcastString) {
+  if (SingleOpModel::GetForceUseNnapi()) {
+    return;
+  }
+  ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_STRING,
+                          BuiltinOperator_NOT_EQUAL);
+  model.PopulateTensor<std::string>(model.input1(), {"A", "B", "A", "B"});
+  model.PopulateTensor<std::string>(model.input2(), {"A"});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
 TEST(ComparisonsTest, GreaterFloat) {
   ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
                           BuiltinOperator_GREATER);
@@ -504,6 +504,7 @@ cc_library(
         ":tensor",
        ":tensor_utils",
         ":types",
+        "//tensorflow/lite:string_util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:op_macros",
         "//tensorflow/lite/tools/optimize/sparsity:format_converter",
@@ -566,6 +567,7 @@ cc_library(
         "//tensorflow/lite/kernels:op_macros",
         "@ruy//ruy/profiler:instrumentation",
         "//tensorflow/lite/tools/optimize/sparsity:format_converter",
+        "//tensorflow/lite:string_util",
     ] + select({
         ":haswell": tflite_deps_intel,
         ":ios_x86_64": tflite_deps_intel,
@@ -15,8 +15,10 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
 
+#include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/string_util.h"
 
 namespace tflite {
 
@@ -49,6 +51,18 @@ inline bool LessEqualFn(T lhs, T rhs) {
   return lhs <= rhs;
 }
 
+inline bool StringRefEqualFn(const StringRef& lhs, const StringRef& rhs) {
+  if (lhs.len != rhs.len) return false;
+  for (int i = 0; i < lhs.len; ++i) {
+    if (lhs.str[i] != rhs.str[i]) return false;
+  }
+  return true;
+}
+
+inline bool StringRefNotEqualFn(const StringRef& lhs, const StringRef& rhs) {
+  return !StringRefEqualFn(lhs, rhs);
+}
+
 template <typename T>
 using ComparisonFn = bool (*)(T, T);
 
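The helpers added above define string equality as exact byte equality over the StringRef view (length first, then a byte-by-byte scan), so comparisons are case-sensitive and make no encoding assumptions. A standalone sketch of that behavior, assuming a TensorFlow source checkout for the two headers:

// Illustrative only; not part of the commit.
#include <cassert>

#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/string_util.h"

int main() {
  const char kAbc[] = "abc";
  tflite::StringRef abc{kAbc, 3};
  tflite::StringRef ab{kAbc, 2};         // same bytes, shorter length
  tflite::StringRef abc_copy{"abc", 3};  // same content at another address
  tflite::StringRef upper{"ABC", 3};     // differs only by case

  assert(tflite::reference_ops::StringRefEqualFn(abc, abc_copy));
  assert(!tflite::reference_ops::StringRefEqualFn(abc, ab));
  assert(!tflite::reference_ops::StringRefEqualFn(abc, upper));
  assert(tflite::reference_ops::StringRefNotEqualFn(abc, ab));
  return 0;
}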
@@ -64,6 +78,22 @@ inline void ComparisonImpl(
   }
 }
 
+template <bool (*F)(const StringRef&, const StringRef&)>
+inline void ComparisonStringImpl(const RuntimeShape& input1_shape,
+                                 const TfLiteTensor* input1,
+                                 const RuntimeShape& input2_shape,
+                                 const TfLiteTensor* input2,
+                                 const RuntimeShape& output_shape,
+                                 bool* output_data) {
+  const int64_t flatsize =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int64_t i = 0; i < flatsize; ++i) {
+    const auto lhs = GetString(input1, i);
+    const auto rhs = GetString(input2, i);
+    output_data[i] = F(lhs, rhs);
+  }
+}
+
 template <ComparisonFn<float> F>
 inline void Comparison(const ComparisonParams& op_params,
                        const RuntimeShape& input1_shape,
@@ -105,35 +135,76 @@ inline void ComparisonWithScaling(
     }
   }
 }
 
+struct BroadcastComparison4DSlowCommon {
+  const RuntimeShape output_shape;
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+};
+
+inline BroadcastComparison4DSlowCommon BroadcastComparison4DSlowPreprocess(
+    const RuntimeShape& unextended_input1_shape,
+    const RuntimeShape& unextended_input2_shape,
+    const RuntimeShape& unextended_output_shape) {
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+  return {RuntimeShape::ExtendedShape(4, unextended_output_shape), desc1,
+          desc2};
+}
+
 template <typename T, ComparisonFn<T> F>
 inline void BroadcastComparison4DSlowImpl(
     const ComparisonParams& op_params,
     const RuntimeShape& unextended_input1_shape, const T* input1_data,
     const RuntimeShape& unextended_input2_shape, const T* input2_data,
     const RuntimeShape& unextended_output_shape, bool* output_data) {
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
-
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
-          output_data[Offset(output_shape, b, y, x, c)] =
-              F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
-                input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
+  const BroadcastComparison4DSlowCommon dims =
+      BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
+                                          unextended_input2_shape,
+                                          unextended_output_shape);
+
+  for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
+    for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
+      for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
+        for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
+          output_data[Offset(dims.output_shape, b, y, x, c)] =
+              F(input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)],
+                input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)]);
         }
       }
     }
   }
 }
 
+template <bool (*F)(const StringRef&, const StringRef&)>
+inline void BroadcastComparison4DSlowStringImpl(
+    const RuntimeShape& unextended_input1_shape, const TfLiteTensor* input1,
+    const RuntimeShape& unextended_input2_shape, const TfLiteTensor* input2,
+    const RuntimeShape& unextended_output_shape, bool* output_data) {
+  const BroadcastComparison4DSlowCommon dims =
+      BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
+                                          unextended_input2_shape,
+                                          unextended_output_shape);
+
+  for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
+    for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
+      for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
+        for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
+          const auto lhs =
+              GetString(input1, SubscriptToIndex(dims.desc1, b, y, x, c));
+          const auto rhs =
+              GetString(input2, SubscriptToIndex(dims.desc2, b, y, x, c));
+          output_data[Offset(dims.output_shape, b, y, x, c)] = F(lhs, rhs);
+        }
+      }
+    }
+  }
+}
+
 template <ComparisonFn<float> F>
 inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
                                       const RuntimeShape& input1_shape,
@@ -153,16 +224,10 @@ inline void BroadcastComparison4DSlowWithScaling(
     const RuntimeShape& unextended_input1_shape, const T* input1_data,
     const RuntimeShape& unextended_input2_shape, const T* input2_data,
     const RuntimeShape& unextended_output_shape, bool* output_data) {
-  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
-
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
-                                      unextended_input2_shape, &desc1, &desc2);
+  const BroadcastComparison4DSlowCommon dims =
+      BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
+                                          unextended_input2_shape,
+                                          unextended_output_shape);
 
   int left_shift = op_params.left_shift;
   int32 input1_offset = op_params.input1_offset;
@@ -172,14 +237,16 @@ inline void BroadcastComparison4DSlowWithScaling(
   int32 input2_multiplier = op_params.input2_multiplier;
   int input2_shift = op_params.input2_shift;
 
-  for (int b = 0; b < output_shape.Dims(0); ++b) {
-    for (int y = 0; y < output_shape.Dims(1); ++y) {
-      for (int x = 0; x < output_shape.Dims(2); ++x) {
-        for (int c = 0; c < output_shape.Dims(3); ++c) {
+  for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
+    for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
+      for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
+        for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
           const int32 input1_val =
-              input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
+              input1_offset +
+              input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)];
           const int32 input2_val =
-              input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+              input2_offset +
+              input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)];
           const int32 shifted_input1_val = input1_val * (1 << left_shift);
           const int32 shifted_input2_val = input2_val * (1 << left_shift);
           const int32 scaled_input1_val =
@@ -188,7 +255,7 @@ inline void BroadcastComparison4DSlowWithScaling(
           const int32 scaled_input2_val =
               MultiplyByQuantizedMultiplierSmallerThanOneExp(
                   shifted_input2_val, input2_multiplier, input2_shift);
-          output_data[Offset(output_shape, b, y, x, c)] =
+          output_data[Offset(dims.output_shape, b, y, x, c)] =
               F(scaled_input1_val, scaled_input2_val);
         }
       }
@@ -224,10 +224,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
   AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
   AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
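Applications that build a selective resolver instead of the full BuiltinOpResolver would mirror this version bump when registering the two kernels. A minimal sketch under that assumption; the helper function name is made up, and the Register_* declarations simply mirror the ones shipped with the builtin kernels library:

#include "tensorflow/lite/mutable_op_resolver.h"

namespace tflite {
namespace ops {
namespace builtin {
TfLiteRegistration* Register_EQUAL();
TfLiteRegistration* Register_NOT_EQUAL();
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

// Hypothetical helper: register only the two comparison ops, allowing the new
// version 3 (string inputs) alongside versions 1 and 2.
void AddStringComparisons(tflite::MutableOpResolver* resolver) {
  resolver->AddBuiltin(tflite::BuiltinOperator_EQUAL,
                       tflite::ops::builtin::Register_EQUAL(),
                       /*min_version=*/1, /*max_version=*/3);
  resolver->AddBuiltin(tflite::BuiltinOperator_NOT_EQUAL,
                       tflite::ops::builtin::Register_NOT_EQUAL(),
                       /*min_version=*/1, /*max_version=*/3);
}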
@@ -28,7 +28,7 @@ def make_equal_tests(options):
   """Make a set of tests to do equal."""
 
   test_parameters = [{
-      "input_dtype": [tf.float32, tf.int32, tf.int64],
+      "input_dtype": [tf.float32, tf.int32, tf.int64, tf.string],
       "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]),
                            ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
                            ([5, 5], [1]), ([10], [2, 4, 10])],
@@ -60,4 +60,4 @@ def make_equal_tests(options):
       test_parameters,
       build_graph,
       build_inputs,
-      expected_tf_failures=3)
+      expected_tf_failures=4)
@@ -28,7 +28,7 @@ def make_not_equal_tests(options):
   """Make a set of tests to do not equal."""
 
   test_parameters = [{
-      "input_dtype": [tf.float32, tf.int32, tf.int64],
+      "input_dtype": [tf.float32, tf.int32, tf.int64, tf.string],
       "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
                            ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
                            ([5, 5], [1]), ([10], [2, 4, 10])],
@@ -60,4 +60,4 @@ def make_not_equal_tests(options):
       test_parameters,
       build_graph,
       build_inputs,
-      expected_tf_failures=3)
+      expected_tf_failures=4)
@@ -183,8 +183,10 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
           {{OperatorType::kReverseSequence, 1}, "1.14.0"},
           {{OperatorType::kEqual, 1}, "1.14.0"},
           {{OperatorType::kEqual, 2}, "1.14.0"},
+          {{OperatorType::kEqual, 3}, kPendingReleaseOpVersion},
           {{OperatorType::kNotEqual, 1}, "1.14.0"},
           {{OperatorType::kNotEqual, 2}, "1.14.0"},
+          {{OperatorType::kNotEqual, 3}, kPendingReleaseOpVersion},
           {{OperatorType::kGreater, 1}, "1.14.0"},
           {{OperatorType::kGreater, 2}, "1.14.0"},
           {{OperatorType::kGreaterEqual, 1}, "1.14.0"},
@@ -407,6 +407,18 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
 
+    case BuiltinOperator_EQUAL:
+    case BuiltinOperator_NOT_EQUAL:
+      if (!op_sig.input_types.empty()) {
+        if (op_sig.input_types.at(0) == TensorType_STRING) {
+          return 3;
+        }
+        if (op_sig.input_types.at(0) == TensorType_INT8) {
+          return 2;
+        }
+      }
+      return 1;
+
     case BuiltinOperator_ADD:
     case BuiltinOperator_CONCATENATION:
     case BuiltinOperator_PAD:
@@ -426,8 +438,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_TOPK_V2:
     case BuiltinOperator_ARG_MAX:
     case BuiltinOperator_ARG_MIN:
-    case BuiltinOperator_EQUAL:
-    case BuiltinOperator_NOT_EQUAL:
     case BuiltinOperator_GREATER:
     case BuiltinOperator_GREATER_EQUAL:
     case BuiltinOperator_LESS:
@@ -86,10 +86,20 @@ void SimpleOutputVersioningTest(BuiltinOperator op) {
 
 TEST(OpVersionTest, VersioningEqualTest) {
   SimpleVersioningTest(BuiltinOperator_EQUAL);
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_EQUAL,
+      .input_types = std::vector<TensorType>{TensorType_STRING},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
 }
 
 TEST(OpVersionTest, VersioningNotEqualTest) {
   SimpleVersioningTest(BuiltinOperator_NOT_EQUAL);
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_NOT_EQUAL,
+      .input_types = std::vector<TensorType>{TensorType_STRING},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
 }
 
 TEST(OpVersionTest, VersioningLessTest) {