String input support on TFLite Equal/Not_Equal op
PiperOrigin-RevId: 307754184 Change-Id: Ib46faf4aa8b37a698f6ba67ada4477cab1f1658e
This commit is contained in:
parent
79abfee5c3
commit
599c8d34ef
@ -1228,8 +1228,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape,
|
||||
|
||||
let arguments = (
|
||||
ins
|
||||
TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$x,
|
||||
TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$y
|
||||
TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$x,
|
||||
TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$y
|
||||
);
|
||||
|
||||
let results = (outs TFL_BoolTensor:$output);
|
||||
|
@ -27,7 +27,8 @@ constexpr int kInputTensor1 = 0;
|
||||
constexpr int kInputTensor2 = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node,
|
||||
bool is_string_allowed) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
@ -36,7 +37,9 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
// Don't support string.
|
||||
TF_LITE_ENSURE(context, input1->type != kTfLiteString);
|
||||
if (!is_string_allowed) {
|
||||
TF_LITE_ENSURE(context, input1->type != kTfLiteString);
|
||||
}
|
||||
// Currently only support tensors have the same type.
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
|
||||
output->type = kTfLiteBool;
|
||||
@ -54,6 +57,15 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return context->ResizeTensor(context, output, output_size);
|
||||
}
|
||||
|
||||
TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return ComparisonPrepareCommon(context, node, false);
|
||||
}
|
||||
|
||||
TfLiteStatus ComparisonPrepareStringAllowed(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
return ComparisonPrepareCommon(context, node, true);
|
||||
}
|
||||
|
||||
template <typename input_dtype, reference_ops::ComparisonFn<int32> opname>
|
||||
void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
|
||||
TfLiteTensor* output, bool requires_broadcast) {
|
||||
@ -108,6 +120,21 @@ void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2,
|
||||
GetTensorShape(output), GetTensorData<bool>(output));
|
||||
}
|
||||
|
||||
template <bool (*opname)(const StringRef&, const StringRef&)>
|
||||
void ComparisonString(const TfLiteTensor* input1, const TfLiteTensor* input2,
|
||||
TfLiteTensor* output, bool requires_broadcast) {
|
||||
bool* output_data = GetTensorData<bool>(output);
|
||||
if (requires_broadcast) {
|
||||
reference_ops::BroadcastComparison4DSlowStringImpl<opname>(
|
||||
GetTensorShape(input1), input1, GetTensorShape(input2), input2,
|
||||
GetTensorShape(output), output_data);
|
||||
} else {
|
||||
reference_ops::ComparisonStringImpl<opname>(
|
||||
GetTensorShape(input1), input1, GetTensorShape(input2), input2,
|
||||
GetTensorShape(output), output_data);
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
|
||||
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
|
||||
@ -138,9 +165,14 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
ComparisonQuantized<int8_t, reference_ops::EqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteString:
|
||||
ComparisonString<reference_ops::StringRefEqualFn>(input1, input2, output,
|
||||
requires_broadcast);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Does not support type %d, requires bool|float|int|uint8",
|
||||
context,
|
||||
"Does not support type %d, requires bool|float|int|uint8|string",
|
||||
input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
@ -177,9 +209,14 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
ComparisonQuantized<int8_t, reference_ops::NotEqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
break;
|
||||
case kTfLiteString:
|
||||
ComparisonString<reference_ops::StringRefNotEqualFn>(
|
||||
input1, input2, output, requires_broadcast);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Does not support type %d, requires bool|float|int|uint8",
|
||||
context,
|
||||
"Does not support type %d, requires bool|float|int|uint8|string",
|
||||
input1->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
@ -330,14 +367,15 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace comparisons
|
||||
|
||||
TfLiteRegistration* Register_EQUAL() {
|
||||
static TfLiteRegistration r = {
|
||||
nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval};
|
||||
static TfLiteRegistration r = {nullptr, nullptr,
|
||||
comparisons::ComparisonPrepareStringAllowed,
|
||||
comparisons::EqualEval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_NOT_EQUAL() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr,
|
||||
comparisons::ComparisonPrepare,
|
||||
comparisons::ComparisonPrepareStringAllowed,
|
||||
comparisons::NotEqualEval};
|
||||
return &r;
|
||||
}
|
||||
|
@ -125,6 +125,20 @@ TEST(ComparisonsTest, EqualInt) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, EqualString) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
ComparisonOpModel model({1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}, TensorType_STRING,
|
||||
BuiltinOperator_EQUAL);
|
||||
model.PopulateTensor<std::string>(model.input1(), {"A", "B", "C", "D"});
|
||||
model.PopulateTensor<std::string>(model.input2(), {"A", "C", "B", "D"});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4, 1));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, EqualBroadcast) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
|
||||
BuiltinOperator_EQUAL);
|
||||
@ -148,6 +162,20 @@ TEST(ComparisonsTest, EqualBroadcastTwoD) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, EqualBroadcastString) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_STRING,
|
||||
BuiltinOperator_EQUAL);
|
||||
model.PopulateTensor<std::string>(model.input1(), {"A", "B", "A", "B"});
|
||||
model.PopulateTensor<std::string>(model.input2(), {"A"});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, false));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, NotEqualBool) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_BOOL,
|
||||
BuiltinOperator_NOT_EQUAL);
|
||||
@ -181,6 +209,20 @@ TEST(ComparisonsTest, NotEqualInt) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, NotEqualString) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
ComparisonOpModel model({1, 1, 1, 1, 4}, {1, 1, 1, 1, 4}, TensorType_STRING,
|
||||
BuiltinOperator_NOT_EQUAL);
|
||||
model.PopulateTensor<std::string>(model.input1(), {"A", "B", "C", "D"});
|
||||
model.PopulateTensor<std::string>(model.input2(), {"A", "C", "B", "D"});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, NotEqualBroadcast) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
|
||||
BuiltinOperator_NOT_EQUAL);
|
||||
@ -204,6 +246,20 @@ TEST(ComparisonsTest, NotEqualBroadcastTwoD) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, NotEqualBroadcastString) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_STRING,
|
||||
BuiltinOperator_NOT_EQUAL);
|
||||
model.PopulateTensor<std::string>(model.input1(), {"A", "B", "A", "B"});
|
||||
model.PopulateTensor<std::string>(model.input2(), {"A"});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterFloat) {
|
||||
ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
|
||||
BuiltinOperator_GREATER);
|
||||
|
@ -504,6 +504,7 @@ cc_library(
|
||||
":tensor",
|
||||
":tensor_utils",
|
||||
":types",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/kernels:op_macros",
|
||||
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
||||
@ -566,6 +567,7 @@ cc_library(
|
||||
"//tensorflow/lite/kernels:op_macros",
|
||||
"@ruy//ruy/profiler:instrumentation",
|
||||
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
|
||||
"//tensorflow/lite:string_util",
|
||||
] + select({
|
||||
":haswell": tflite_deps_intel,
|
||||
":ios_x86_64": tflite_deps_intel,
|
||||
|
@ -15,8 +15,10 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
@ -49,6 +51,18 @@ inline bool LessEqualFn(T lhs, T rhs) {
|
||||
return lhs <= rhs;
|
||||
}
|
||||
|
||||
inline bool StringRefEqualFn(const StringRef& lhs, const StringRef& rhs) {
|
||||
if (lhs.len != rhs.len) return false;
|
||||
for (int i = 0; i < lhs.len; ++i) {
|
||||
if (lhs.str[i] != rhs.str[i]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool StringRefNotEqualFn(const StringRef& lhs, const StringRef& rhs) {
|
||||
return !StringRefEqualFn(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using ComparisonFn = bool (*)(T, T);
|
||||
|
||||
@ -64,6 +78,22 @@ inline void ComparisonImpl(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool (*F)(const StringRef&, const StringRef&)>
|
||||
inline void ComparisonStringImpl(const RuntimeShape& input1_shape,
|
||||
const TfLiteTensor* input1,
|
||||
const RuntimeShape& input2_shape,
|
||||
const TfLiteTensor* input2,
|
||||
const RuntimeShape& output_shape,
|
||||
bool* output_data) {
|
||||
const int64_t flatsize =
|
||||
MatchingFlatSize(input1_shape, input2_shape, output_shape);
|
||||
for (int64_t i = 0; i < flatsize; ++i) {
|
||||
const auto lhs = GetString(input1, i);
|
||||
const auto rhs = GetString(input2, i);
|
||||
output_data[i] = F(lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
template <ComparisonFn<float> F>
|
||||
inline void Comparison(const ComparisonParams& op_params,
|
||||
const RuntimeShape& input1_shape,
|
||||
@ -105,35 +135,76 @@ inline void ComparisonWithScaling(
|
||||
}
|
||||
}
|
||||
|
||||
struct BroadcastComparison4DSlowCommon {
|
||||
const RuntimeShape output_shape;
|
||||
NdArrayDesc<4> desc1;
|
||||
NdArrayDesc<4> desc2;
|
||||
};
|
||||
|
||||
inline BroadcastComparison4DSlowCommon BroadcastComparison4DSlowPreprocess(
|
||||
const RuntimeShape& unextended_input1_shape,
|
||||
const RuntimeShape& unextended_input2_shape,
|
||||
const RuntimeShape& unextended_output_shape) {
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
NdArrayDesc<4> desc1;
|
||||
NdArrayDesc<4> desc2;
|
||||
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
|
||||
unextended_input2_shape, &desc1, &desc2);
|
||||
return {RuntimeShape::ExtendedShape(4, unextended_output_shape), desc1,
|
||||
desc2};
|
||||
}
|
||||
|
||||
template <typename T, ComparisonFn<T> F>
|
||||
inline void BroadcastComparison4DSlowImpl(
|
||||
const ComparisonParams& op_params,
|
||||
const RuntimeShape& unextended_input1_shape, const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape, const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
const BroadcastComparison4DSlowCommon dims =
|
||||
BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
|
||||
unextended_input2_shape,
|
||||
unextended_output_shape);
|
||||
|
||||
NdArrayDesc<4> desc1;
|
||||
NdArrayDesc<4> desc2;
|
||||
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
|
||||
unextended_input2_shape, &desc1, &desc2);
|
||||
|
||||
for (int b = 0; b < output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < output_shape.Dims(3); ++c) {
|
||||
output_data[Offset(output_shape, b, y, x, c)] =
|
||||
F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
|
||||
input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
|
||||
for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
|
||||
output_data[Offset(dims.output_shape, b, y, x, c)] =
|
||||
F(input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)],
|
||||
input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool (*F)(const StringRef&, const StringRef&)>
|
||||
inline void BroadcastComparison4DSlowStringImpl(
|
||||
const RuntimeShape& unextended_input1_shape, const TfLiteTensor* input1,
|
||||
const RuntimeShape& unextended_input2_shape, const TfLiteTensor* input2,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
const BroadcastComparison4DSlowCommon dims =
|
||||
BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
|
||||
unextended_input2_shape,
|
||||
unextended_output_shape);
|
||||
|
||||
for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
|
||||
const auto lhs =
|
||||
GetString(input1, SubscriptToIndex(dims.desc1, b, y, x, c));
|
||||
const auto rhs =
|
||||
GetString(input2, SubscriptToIndex(dims.desc2, b, y, x, c));
|
||||
output_data[Offset(dims.output_shape, b, y, x, c)] = F(lhs, rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <ComparisonFn<float> F>
|
||||
inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
|
||||
const RuntimeShape& input1_shape,
|
||||
@ -153,16 +224,10 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const RuntimeShape& unextended_input1_shape, const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape, const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
|
||||
NdArrayDesc<4> desc1;
|
||||
NdArrayDesc<4> desc2;
|
||||
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
|
||||
unextended_input2_shape, &desc1, &desc2);
|
||||
const BroadcastComparison4DSlowCommon dims =
|
||||
BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
|
||||
unextended_input2_shape,
|
||||
unextended_output_shape);
|
||||
|
||||
int left_shift = op_params.left_shift;
|
||||
int32 input1_offset = op_params.input1_offset;
|
||||
@ -172,14 +237,16 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
int32 input2_multiplier = op_params.input2_multiplier;
|
||||
int input2_shift = op_params.input2_shift;
|
||||
|
||||
for (int b = 0; b < output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < output_shape.Dims(3); ++c) {
|
||||
for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
|
||||
const int32 input1_val =
|
||||
input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
|
||||
input1_offset +
|
||||
input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)];
|
||||
const int32 input2_val =
|
||||
input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
|
||||
input2_offset +
|
||||
input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)];
|
||||
const int32 shifted_input1_val = input1_val * (1 << left_shift);
|
||||
const int32 shifted_input2_val = input2_val * (1 << left_shift);
|
||||
const int32 scaled_input1_val =
|
||||
@ -188,7 +255,7 @@ inline void BroadcastComparison4DSlowWithScaling(
|
||||
const int32 scaled_input2_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input2_val, input2_multiplier, input2_shift);
|
||||
output_data[Offset(output_shape, b, y, x, c)] =
|
||||
output_data[Offset(dims.output_shape, b, y, x, c)] =
|
||||
F(scaled_input1_val, scaled_input2_val);
|
||||
}
|
||||
}
|
||||
|
@ -224,10 +224,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
|
||||
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
|
||||
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
|
||||
|
@ -28,7 +28,7 @@ def make_equal_tests(options):
|
||||
"""Make a set of tests to do equal."""
|
||||
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"input_dtype": [tf.float32, tf.int32, tf.int64, tf.string],
|
||||
"input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]),
|
||||
([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
|
||||
([5, 5], [1]), ([10], [2, 4, 10])],
|
||||
@ -60,4 +60,4 @@ def make_equal_tests(options):
|
||||
test_parameters,
|
||||
build_graph,
|
||||
build_inputs,
|
||||
expected_tf_failures=3)
|
||||
expected_tf_failures=4)
|
||||
|
@ -28,7 +28,7 @@ def make_not_equal_tests(options):
|
||||
"""Make a set of tests to do not equal."""
|
||||
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"input_dtype": [tf.float32, tf.int32, tf.int64, tf.string],
|
||||
"input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
|
||||
([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
|
||||
([5, 5], [1]), ([10], [2, 4, 10])],
|
||||
@ -60,4 +60,4 @@ def make_not_equal_tests(options):
|
||||
test_parameters,
|
||||
build_graph,
|
||||
build_inputs,
|
||||
expected_tf_failures=3)
|
||||
expected_tf_failures=4)
|
||||
|
@ -183,8 +183,10 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
||||
{{OperatorType::kReverseSequence, 1}, "1.14.0"},
|
||||
{{OperatorType::kEqual, 1}, "1.14.0"},
|
||||
{{OperatorType::kEqual, 2}, "1.14.0"},
|
||||
{{OperatorType::kEqual, 3}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kNotEqual, 1}, "1.14.0"},
|
||||
{{OperatorType::kNotEqual, 2}, "1.14.0"},
|
||||
{{OperatorType::kNotEqual, 3}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kGreater, 1}, "1.14.0"},
|
||||
{{OperatorType::kGreater, 2}, "1.14.0"},
|
||||
{{OperatorType::kGreaterEqual, 1}, "1.14.0"},
|
||||
|
@ -407,6 +407,18 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_EQUAL:
|
||||
case BuiltinOperator_NOT_EQUAL:
|
||||
if (!op_sig.input_types.empty()) {
|
||||
if (op_sig.input_types.at(0) == TensorType_STRING) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_ADD:
|
||||
case BuiltinOperator_CONCATENATION:
|
||||
case BuiltinOperator_PAD:
|
||||
@ -426,8 +438,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
case BuiltinOperator_TOPK_V2:
|
||||
case BuiltinOperator_ARG_MAX:
|
||||
case BuiltinOperator_ARG_MIN:
|
||||
case BuiltinOperator_EQUAL:
|
||||
case BuiltinOperator_NOT_EQUAL:
|
||||
case BuiltinOperator_GREATER:
|
||||
case BuiltinOperator_GREATER_EQUAL:
|
||||
case BuiltinOperator_LESS:
|
||||
|
@ -86,10 +86,20 @@ void SimpleOutputVersioningTest(BuiltinOperator op) {
|
||||
|
||||
TEST(OpVersionTest, VersioningEqualTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_EQUAL);
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_EQUAL,
|
||||
.input_types = std::vector<TensorType>{TensorType_STRING},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningNotEqualTest) {
|
||||
SimpleVersioningTest(BuiltinOperator_NOT_EQUAL);
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_NOT_EQUAL,
|
||||
.input_types = std::vector<TensorType>{TensorType_STRING},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
}
|
||||
|
||||
TEST(OpVersionTest, VersioningLessTest) {
|
||||
|
Loading…
Reference in New Issue
Block a user