From 599c8d34ef2b2e71e11c084a26d31fb98b8e1c6c Mon Sep 17 00:00:00 2001 From: Hyeonjong Ryu Date: Tue, 21 Apr 2020 23:12:21 -0700 Subject: [PATCH] String input support on TFLite Equal/Not_Equal op PiperOrigin-RevId: 307754184 Change-Id: Ib46faf4aa8b37a698f6ba67ada4477cab1f1658e --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 4 +- tensorflow/lite/kernels/comparisons.cc | 52 ++++++- tensorflow/lite/kernels/comparisons_test.cc | 56 ++++++++ tensorflow/lite/kernels/internal/BUILD | 2 + .../kernels/internal/reference/comparisons.h | 135 +++++++++++++----- tensorflow/lite/kernels/register.cc | 4 +- tensorflow/lite/testing/op_tests/equal.py | 4 +- tensorflow/lite/testing/op_tests/not_equal.py | 4 +- tensorflow/lite/toco/tflite/op_version.cc | 2 + .../lite/tools/versioning/op_version.cc | 14 +- .../lite/tools/versioning/op_version_test.cc | 10 ++ 11 files changed, 236 insertions(+), 51 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index faaf5fa6812..ed2e6cd129f 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1228,8 +1228,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape, let arguments = ( ins - TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$x, - TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$y + TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$x, + TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$y ); let results = (outs TFL_BoolTensor:$output); diff --git a/tensorflow/lite/kernels/comparisons.cc b/tensorflow/lite/kernels/comparisons.cc index a8e3148464c..4e20efc20e3 100644 --- a/tensorflow/lite/kernels/comparisons.cc +++ b/tensorflow/lite/kernels/comparisons.cc @@ -27,7 +27,8 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; -TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node, + bool is_string_allowed) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -36,7 +37,9 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Don't support string. - TF_LITE_ENSURE(context, input1->type != kTfLiteString); + if (!is_string_allowed) { + TF_LITE_ENSURE(context, input1->type != kTfLiteString); + } // Currently only support tensors have the same type. TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type); output->type = kTfLiteBool; @@ -54,6 +57,15 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } +TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { + return ComparisonPrepareCommon(context, node, false); +} + +TfLiteStatus ComparisonPrepareStringAllowed(TfLiteContext* context, + TfLiteNode* node) { + return ComparisonPrepareCommon(context, node, true); +} + template opname> void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output, bool requires_broadcast) { @@ -108,6 +120,21 @@ void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2, GetTensorShape(output), GetTensorData(output)); } +template +void ComparisonString(const TfLiteTensor* input1, const TfLiteTensor* input2, + TfLiteTensor* output, bool requires_broadcast) { + bool* output_data = GetTensorData(output); + if (requires_broadcast) { + reference_ops::BroadcastComparison4DSlowStringImpl( + GetTensorShape(input1), input1, GetTensorShape(input2), input2, + GetTensorShape(output), output_data); + } else { + reference_ops::ComparisonStringImpl( + GetTensorShape(input1), input1, GetTensorShape(input2), input2, + GetTensorShape(output), output_data); + } +} + TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); @@ -138,9 +165,14 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { ComparisonQuantized( input1, input2, output, requires_broadcast); break; + case kTfLiteString: + ComparisonString(input1, input2, output, + requires_broadcast); + break; default: context->ReportError( - context, "Does not support type %d, requires bool|float|int|uint8", + context, + "Does not support type %d, requires bool|float|int|uint8|string", input1->type); return kTfLiteError; } @@ -177,9 +209,14 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { ComparisonQuantized( input1, input2, output, requires_broadcast); break; + case kTfLiteString: + ComparisonString( + input1, input2, output, requires_broadcast); + break; default: context->ReportError( - context, "Does not support type %d, requires bool|float|int|uint8", + context, + "Does not support type %d, requires bool|float|int|uint8|string", input1->type); return kTfLiteError; } @@ -330,14 +367,15 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { } // namespace comparisons TfLiteRegistration* Register_EQUAL() { - static TfLiteRegistration r = { - nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval}; + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepareStringAllowed, + comparisons::EqualEval}; return &r; } TfLiteRegistration* Register_NOT_EQUAL() { static TfLiteRegistration r = {nullptr, nullptr, - comparisons::ComparisonPrepare, + comparisons::ComparisonPrepareStringAllowed, comparisons::NotEqualEval}; return &r; } diff --git a/tensorflow/lite/kernels/comparisons_test.cc b/tensorflow/lite/kernels/comparisons_test.cc index 0fc49ea5c88..986600ccd1a 100644 --- a/tensorflow/lite/kernels/comparisons_test.cc +++ b/tensorflow/lite/kernels/comparisons_test.cc @@ -125,6 +125,20 @@ TEST(ComparisonsTest, EqualInt) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } +TEST(ComparisonsTest, EqualString) { + if (SingleOpModel::GetForceUseNnapi()) { + return; + } + ComparisonOpModel model({1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}, TensorType_STRING, + BuiltinOperator_EQUAL); + model.PopulateTensor(model.input1(), {"A", "B", "C", "D"}); + model.PopulateTensor(model.input2(), {"A", "C", "B", "D"}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4, 1)); +} + TEST(ComparisonsTest, EqualBroadcast) { ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, BuiltinOperator_EQUAL); @@ -148,6 +162,20 @@ TEST(ComparisonsTest, EqualBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } +TEST(ComparisonsTest, EqualBroadcastString) { + if (SingleOpModel::GetForceUseNnapi()) { + return; + } + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_STRING, + BuiltinOperator_EQUAL); + model.PopulateTensor(model.input1(), {"A", "B", "A", "B"}); + model.PopulateTensor(model.input2(), {"A"}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + TEST(ComparisonsTest, NotEqualBool) { ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_BOOL, BuiltinOperator_NOT_EQUAL); @@ -181,6 +209,20 @@ TEST(ComparisonsTest, NotEqualInt) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } +TEST(ComparisonsTest, NotEqualString) { + if (SingleOpModel::GetForceUseNnapi()) { + return; + } + ComparisonOpModel model({1, 1, 1, 1, 4}, {1, 1, 1, 1, 4}, TensorType_STRING, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {"A", "B", "C", "D"}); + model.PopulateTensor(model.input2(), {"A", "C", "B", "D"}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 1, 4)); +} + TEST(ComparisonsTest, NotEqualBroadcast) { ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, BuiltinOperator_NOT_EQUAL); @@ -204,6 +246,20 @@ TEST(ComparisonsTest, NotEqualBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } +TEST(ComparisonsTest, NotEqualBroadcastString) { + if (SingleOpModel::GetForceUseNnapi()) { + return; + } + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_STRING, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {"A", "B", "A", "B"}); + model.PopulateTensor(model.input2(), {"A"}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + TEST(ComparisonsTest, GreaterFloat) { ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, BuiltinOperator_GREATER); diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 48c53805532..7d2836fbe11 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -504,6 +504,7 @@ cc_library( ":tensor", ":tensor_utils", ":types", + "//tensorflow/lite:string_util", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/tools/optimize/sparsity:format_converter", @@ -566,6 +567,7 @@ cc_library( "//tensorflow/lite/kernels:op_macros", "@ruy//ruy/profiler:instrumentation", "//tensorflow/lite/tools/optimize/sparsity:format_converter", + "//tensorflow/lite:string_util", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, diff --git a/tensorflow/lite/kernels/internal/reference/comparisons.h b/tensorflow/lite/kernels/internal/reference/comparisons.h index 19a968e4670..379a20f5065 100644 --- a/tensorflow/lite/kernels/internal/reference/comparisons.h +++ b/tensorflow/lite/kernels/internal/reference/comparisons.h @@ -15,8 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/string_util.h" namespace tflite { @@ -49,6 +51,18 @@ inline bool LessEqualFn(T lhs, T rhs) { return lhs <= rhs; } +inline bool StringRefEqualFn(const StringRef& lhs, const StringRef& rhs) { + if (lhs.len != rhs.len) return false; + for (int i = 0; i < lhs.len; ++i) { + if (lhs.str[i] != rhs.str[i]) return false; + } + return true; +} + +inline bool StringRefNotEqualFn(const StringRef& lhs, const StringRef& rhs) { + return !StringRefEqualFn(lhs, rhs); +} + template using ComparisonFn = bool (*)(T, T); @@ -64,6 +78,22 @@ inline void ComparisonImpl( } } +template +inline void ComparisonStringImpl(const RuntimeShape& input1_shape, + const TfLiteTensor* input1, + const RuntimeShape& input2_shape, + const TfLiteTensor* input2, + const RuntimeShape& output_shape, + bool* output_data) { + const int64_t flatsize = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + for (int64_t i = 0; i < flatsize; ++i) { + const auto lhs = GetString(input1, i); + const auto rhs = GetString(input2, i); + output_data[i] = F(lhs, rhs); + } +} + template F> inline void Comparison(const ComparisonParams& op_params, const RuntimeShape& input1_shape, @@ -105,35 +135,76 @@ inline void ComparisonWithScaling( } } +struct BroadcastComparison4DSlowCommon { + const RuntimeShape output_shape; + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; +}; + +inline BroadcastComparison4DSlowCommon BroadcastComparison4DSlowPreprocess( + const RuntimeShape& unextended_input1_shape, + const RuntimeShape& unextended_input2_shape, + const RuntimeShape& unextended_output_shape) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); + return {RuntimeShape::ExtendedShape(4, unextended_output_shape), desc1, + desc2}; +} + template F> inline void BroadcastComparison4DSlowImpl( const ComparisonParams& op_params, const RuntimeShape& unextended_input1_shape, const T* input1_data, const RuntimeShape& unextended_input2_shape, const T* input2_data, const RuntimeShape& unextended_output_shape, bool* output_data) { - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); + const BroadcastComparison4DSlowCommon dims = + BroadcastComparison4DSlowPreprocess(unextended_input1_shape, + unextended_input2_shape, + unextended_output_shape); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, - unextended_input2_shape, &desc1, &desc2); - - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - output_data[Offset(output_shape, b, y, x, c)] = - F(input1_data[SubscriptToIndex(desc1, b, y, x, c)], - input2_data[SubscriptToIndex(desc2, b, y, x, c)]); + for (int b = 0; b < dims.output_shape.Dims(0); ++b) { + for (int y = 0; y < dims.output_shape.Dims(1); ++y) { + for (int x = 0; x < dims.output_shape.Dims(2); ++x) { + for (int c = 0; c < dims.output_shape.Dims(3); ++c) { + output_data[Offset(dims.output_shape, b, y, x, c)] = + F(input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)], + input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)]); } } } } } + +template +inline void BroadcastComparison4DSlowStringImpl( + const RuntimeShape& unextended_input1_shape, const TfLiteTensor* input1, + const RuntimeShape& unextended_input2_shape, const TfLiteTensor* input2, + const RuntimeShape& unextended_output_shape, bool* output_data) { + const BroadcastComparison4DSlowCommon dims = + BroadcastComparison4DSlowPreprocess(unextended_input1_shape, + unextended_input2_shape, + unextended_output_shape); + + for (int b = 0; b < dims.output_shape.Dims(0); ++b) { + for (int y = 0; y < dims.output_shape.Dims(1); ++y) { + for (int x = 0; x < dims.output_shape.Dims(2); ++x) { + for (int c = 0; c < dims.output_shape.Dims(3); ++c) { + const auto lhs = + GetString(input1, SubscriptToIndex(dims.desc1, b, y, x, c)); + const auto rhs = + GetString(input2, SubscriptToIndex(dims.desc2, b, y, x, c)); + output_data[Offset(dims.output_shape, b, y, x, c)] = F(lhs, rhs); + } + } + } + } +} + template F> inline void BroadcastComparison4DSlow(const ComparisonParams& op_params, const RuntimeShape& input1_shape, @@ -153,16 +224,10 @@ inline void BroadcastComparison4DSlowWithScaling( const RuntimeShape& unextended_input1_shape, const T* input1_data, const RuntimeShape& unextended_input2_shape, const T* input2_data, const RuntimeShape& unextended_output_shape, bool* output_data) { - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, - unextended_input2_shape, &desc1, &desc2); + const BroadcastComparison4DSlowCommon dims = + BroadcastComparison4DSlowPreprocess(unextended_input1_shape, + unextended_input2_shape, + unextended_output_shape); int left_shift = op_params.left_shift; int32 input1_offset = op_params.input1_offset; @@ -172,14 +237,16 @@ inline void BroadcastComparison4DSlowWithScaling( int32 input2_multiplier = op_params.input2_multiplier; int input2_shift = op_params.input2_shift; - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { + for (int b = 0; b < dims.output_shape.Dims(0); ++b) { + for (int y = 0; y < dims.output_shape.Dims(1); ++y) { + for (int x = 0; x < dims.output_shape.Dims(2); ++x) { + for (int c = 0; c < dims.output_shape.Dims(3); ++c) { const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; + input1_offset + + input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)]; const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + input2_offset + + input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = @@ -188,7 +255,7 @@ inline void BroadcastComparison4DSlowWithScaling( const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( shifted_input2_val, input2_multiplier, input2_shift); - output_data[Offset(output_shape, b, y, x, c)] = + output_data[Offset(dims.output_shape, b, y, x, c)] = F(scaled_input1_val, scaled_input2_val); } } diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 392369eaea0..f5608b1a820 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -224,10 +224,10 @@ BuiltinOpResolver::BuiltinOpResolver() { /* max_version = */ 3); AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); diff --git a/tensorflow/lite/testing/op_tests/equal.py b/tensorflow/lite/testing/op_tests/equal.py index 76a3fed1456..ddbece129d3 100644 --- a/tensorflow/lite/testing/op_tests/equal.py +++ b/tensorflow/lite/testing/op_tests/equal.py @@ -28,7 +28,7 @@ def make_equal_tests(options): """Make a set of tests to do equal.""" test_parameters = [{ - "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_dtype": [tf.float32, tf.int32, tf.int64, tf.string], "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]), ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), ([5, 5], [1]), ([10], [2, 4, 10])], @@ -60,4 +60,4 @@ def make_equal_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=3) + expected_tf_failures=4) diff --git a/tensorflow/lite/testing/op_tests/not_equal.py b/tensorflow/lite/testing/op_tests/not_equal.py index 7ecf6e2ffb6..e0f9d3c0735 100644 --- a/tensorflow/lite/testing/op_tests/not_equal.py +++ b/tensorflow/lite/testing/op_tests/not_equal.py @@ -28,7 +28,7 @@ def make_not_equal_tests(options): """Make a set of tests to do not equal.""" test_parameters = [{ - "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_dtype": [tf.float32, tf.int32, tf.int64, tf.string], "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), ([5, 5], [1]), ([10], [2, 4, 10])], @@ -60,4 +60,4 @@ def make_not_equal_tests(options): test_parameters, build_graph, build_inputs, - expected_tf_failures=3) + expected_tf_failures=4) diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index c5b5e3c71ab..9a2842a6046 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -183,8 +183,10 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kReverseSequence, 1}, "1.14.0"}, {{OperatorType::kEqual, 1}, "1.14.0"}, {{OperatorType::kEqual, 2}, "1.14.0"}, + {{OperatorType::kEqual, 3}, kPendingReleaseOpVersion}, {{OperatorType::kNotEqual, 1}, "1.14.0"}, {{OperatorType::kNotEqual, 2}, "1.14.0"}, + {{OperatorType::kNotEqual, 3}, kPendingReleaseOpVersion}, {{OperatorType::kGreater, 1}, "1.14.0"}, {{OperatorType::kGreater, 2}, "1.14.0"}, {{OperatorType::kGreaterEqual, 1}, "1.14.0"}, diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 2b7d3f7d316..622cb134198 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -407,6 +407,18 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; + case BuiltinOperator_EQUAL: + case BuiltinOperator_NOT_EQUAL: + if (!op_sig.input_types.empty()) { + if (op_sig.input_types.at(0) == TensorType_STRING) { + return 3; + } + if (op_sig.input_types.at(0) == TensorType_INT8) { + return 2; + } + } + return 1; + case BuiltinOperator_ADD: case BuiltinOperator_CONCATENATION: case BuiltinOperator_PAD: @@ -426,8 +438,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { case BuiltinOperator_TOPK_V2: case BuiltinOperator_ARG_MAX: case BuiltinOperator_ARG_MIN: - case BuiltinOperator_EQUAL: - case BuiltinOperator_NOT_EQUAL: case BuiltinOperator_GREATER: case BuiltinOperator_GREATER_EQUAL: case BuiltinOperator_LESS: diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc index 5dde260241e..2a48ddd6714 100644 --- a/tensorflow/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/lite/tools/versioning/op_version_test.cc @@ -86,10 +86,20 @@ void SimpleOutputVersioningTest(BuiltinOperator op) { TEST(OpVersionTest, VersioningEqualTest) { SimpleVersioningTest(BuiltinOperator_EQUAL); + OpSignature fake_op_sig = { + .op = BuiltinOperator_EQUAL, + .input_types = std::vector{TensorType_STRING}, + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); } TEST(OpVersionTest, VersioningNotEqualTest) { SimpleVersioningTest(BuiltinOperator_NOT_EQUAL); + OpSignature fake_op_sig = { + .op = BuiltinOperator_NOT_EQUAL, + .input_types = std::vector{TensorType_STRING}, + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3); } TEST(OpVersionTest, VersioningLessTest) {