From ca16be74dc05f7d9663359daf3efef134a21b9bc Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Mon, 2 Nov 2020 22:35:21 -0800 Subject: [PATCH] Support string in TFLite Squeeze kernel PiperOrigin-RevId: 340382663 Change-Id: I4ff462f7a66097aaac8a0bf2182c17ce4020b4f9 --- tensorflow/lite/kernels/register.cc | 4 +- tensorflow/lite/kernels/squeeze.cc | 12 ++++++ tensorflow/lite/kernels/squeeze_test.cc | 38 +++++++++++++++++++ tensorflow/lite/testing/op_tests/squeeze.py | 5 +++ .../lite/tools/versioning/op_version.cc | 5 +++ .../lite/tools/versioning/runtime_version.cc | 1 + 6 files changed, 64 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 53d4c5c5e38..a4e960dea81 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -152,7 +152,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V(), /* min_version = */ 1, /* max_version = */ 2); - AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); + AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE(), + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(), /* min_version = */ 1, /* max_version = */ 4); diff --git a/tensorflow/lite/kernels/squeeze.cc b/tensorflow/lite/kernels/squeeze.cc index c4dc51026a6..ac282fd0959 100644 --- a/tensorflow/lite/kernels/squeeze.cc +++ b/tensorflow/lite/kernels/squeeze.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/portable_tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -78,6 +79,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { SqueezeContext op_context(context, node); + if (op_context.input->type == kTfLiteString) { + const int input_flat_size = GetTensorShape(op_context.input).FlatSize(); + const int output_flat_size = GetTensorShape(op_context.output).FlatSize(); + TF_LITE_ENSURE_EQ(context, input_flat_size, output_flat_size); + SequentialTensorWriter writer(op_context.input, op_context.output); + for (int i = 0; i < input_flat_size; i++) { + writer.Write(i); + } + return kTfLiteOk; + } + TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes); memcpy(op_context.output->data.raw, op_context.input->data.raw, op_context.input->bytes); diff --git a/tensorflow/lite/kernels/squeeze_test.cc b/tensorflow/lite/kernels/squeeze_test.cc index 4239ae43e1c..9aac56cf2ef 100644 --- a/tensorflow/lite/kernels/squeeze_test.cc +++ b/tensorflow/lite/kernels/squeeze_test.cc @@ -56,7 +56,14 @@ class SqueezeOpModel : public BaseSqueezeOpModel { void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetStringInput(std::initializer_list data) { + PopulateStringTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetStringOutput() { + return ExtractVector(output_); + } std::vector GetOutputShape() { return GetTensorShape(output_); } }; @@ -122,5 +129,36 @@ TYPED_TEST(SqueezeOpTest, SqueezeAllDims) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); } +TEST(SqueezeOpTest, SqueezeAllString) { + std::initializer_list data = {"a", "b"}; + SqueezeOpModel m({GetTensorType(), {1, 2, 1}}, + {GetTensorType(), {2}}, {}); + m.SetStringInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"})); +} + +TEST(SqueezeOpTest, SqueezeNegativeAxisString) { + std::initializer_list data = {"a", "b"}; + SqueezeOpModel m({GetTensorType(), {1, 2, 1}}, + {GetTensorType(), {24}}, {-1}); + m.SetStringInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"})); +} + +TYPED_TEST(SqueezeOpTest, SqueezeAllDimsString) { + std::initializer_list data = {"a"}; + SqueezeOpModel m( + {GetTensorType(), {1, 1, 1, 1, 1, 1, 1}}, + {GetTensorType(), {1}}, {}); + m.SetStringInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a"})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/testing/op_tests/squeeze.py b/tensorflow/lite/testing/op_tests/squeeze.py index 481dfd7612c..00726869892 100644 --- a/tensorflow/lite/testing/op_tests/squeeze.py +++ b/tensorflow/lite/testing/op_tests/squeeze.py @@ -65,6 +65,11 @@ def make_squeeze_tests(options): "input_shape": [[1, 1, 5, 10], [1, 5, 1, 10], [5, 1, 10]], "axis": [[0], [1], [3, 0], [-2, 0, 3, 2]], "fully_quantize": [True], + }, { + "dtype": [tf.string], + "input_shape": [[1, 1, 5, 10], [1, 5, 1, 10]], + "axis": [[0], []], + "fully_quantize": [False], }] def build_graph(parameters): diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 8627c492c70..aff9a3cbde2 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -447,7 +447,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { if (op_sig.input_types.at(0) == TensorType_STRING) { return 2; } + return 1; + case BuiltinOperator_SQUEEZE: + if (op_sig.input_types.at(0) == TensorType_STRING) { + return 2; + } return 1; case BuiltinOperator_SPACE_TO_BATCH_ND: diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 5366b46ca2b..edcfb70c3ab 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -201,6 +201,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_RNN, 3}, "2.3.0"}, {{BuiltinOperator_SKIP_GRAM, 1}, "1.5.0"}, {{BuiltinOperator_SQUEEZE, 1}, "1.6.0"}, + {{BuiltinOperator_SQUEEZE, 2}, kPendingReleaseVersion}, {{BuiltinOperator_SPLIT, 1}, "1.5.0"}, {{BuiltinOperator_SPLIT, 2}, "1.14.0"}, {{BuiltinOperator_SPLIT, 3}, "1.14.0"},