diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 53d4c5c5e38..a4e960dea81 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -152,7 +152,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V(), /* min_version = */ 1, /* max_version = */ 2); - AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); + AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE(), + /* min_version = */ 1, + /* max_version = */ 2); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(), /* min_version = */ 1, /* max_version = */ 4); diff --git a/tensorflow/lite/kernels/squeeze.cc b/tensorflow/lite/kernels/squeeze.cc index c4dc51026a6..ac282fd0959 100644 --- a/tensorflow/lite/kernels/squeeze.cc +++ b/tensorflow/lite/kernels/squeeze.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/portable_tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -78,6 +79,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { SqueezeContext op_context(context, node); + if (op_context.input->type == kTfLiteString) { + const int input_flat_size = GetTensorShape(op_context.input).FlatSize(); + const int output_flat_size = GetTensorShape(op_context.output).FlatSize(); + TF_LITE_ENSURE_EQ(context, input_flat_size, output_flat_size); + SequentialTensorWriter<string> writer(op_context.input, op_context.output); + for (int i = 0; i < input_flat_size; i++) { + writer.Write(i); + } + return kTfLiteOk; + } + TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes); memcpy(op_context.output->data.raw, op_context.input->data.raw, op_context.input->bytes); diff --git a/tensorflow/lite/kernels/squeeze_test.cc b/tensorflow/lite/kernels/squeeze_test.cc index 4239ae43e1c..9aac56cf2ef 100644 --- a/tensorflow/lite/kernels/squeeze_test.cc +++ b/tensorflow/lite/kernels/squeeze_test.cc @@ -56,7 +56,14 @@ class SqueezeOpModel : public BaseSqueezeOpModel { void SetInput(std::initializer_list<T> data) { PopulateTensor(input_, data); } + void SetStringInput(std::initializer_list<string> data) { + PopulateStringTensor(input_, data); + } + std::vector<T> GetOutput() { return ExtractVector<T>(output_); } + std::vector<string> GetStringOutput() { + return ExtractVector<string>(output_); + } std::vector<int> GetOutputShape() { return GetTensorShape(output_); } }; @@ -122,5 +129,36 @@ TYPED_TEST(SqueezeOpTest, SqueezeAllDims) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); } +TEST(SqueezeOpTest, SqueezeAllString) { + std::initializer_list<std::string> data = {"a", "b"}; + SqueezeOpModel<std::string> m({GetTensorType<std::string>(), {1, 2, 1}}, + {GetTensorType<std::string>(), {2}}, {}); + m.SetStringInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"})); +} + +TEST(SqueezeOpTest, SqueezeNegativeAxisString) { + std::initializer_list<std::string> data = {"a", "b"}; + SqueezeOpModel<std::string> m({GetTensorType<std::string>(), {1, 2, 1}}, + {GetTensorType<std::string>(), {24}}, {-1}); + m.SetStringInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"})); +} + +TYPED_TEST(SqueezeOpTest, SqueezeAllDimsString) { + std::initializer_list<std::string> data = {"a"}; + SqueezeOpModel<std::string> m( + {GetTensorType<std::string>(), {1, 1, 1, 1, 1, 1, 1}}, + {GetTensorType<std::string>(), {1}}, {}); + m.SetStringInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a"})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/testing/op_tests/squeeze.py b/tensorflow/lite/testing/op_tests/squeeze.py index 481dfd7612c..00726869892 100644 --- a/tensorflow/lite/testing/op_tests/squeeze.py +++ b/tensorflow/lite/testing/op_tests/squeeze.py @@ -65,6 +65,11 @@ def make_squeeze_tests(options): "input_shape": [[1, 1, 5, 10], [1, 5, 1, 10], [5, 1, 10]], "axis": [[0], [1], [3, 0], [-2, 0, 3, 2]], "fully_quantize": [True], + }, { + "dtype": [tf.string], + "input_shape": [[1, 1, 5, 10], [1, 5, 1, 10]], + "axis": [[0], []], + "fully_quantize": [False], }] def build_graph(parameters): diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 8627c492c70..aff9a3cbde2 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -447,7 +447,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { if (op_sig.input_types.at(0) == TensorType_STRING) { return 2; } + return 1; + case BuiltinOperator_SQUEEZE: + if (op_sig.input_types.at(0) == TensorType_STRING) { + return 2; + } return 1; case BuiltinOperator_SPACE_TO_BATCH_ND: diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 5366b46ca2b..edcfb70c3ab 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -201,6 +201,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_RNN, 3}, "2.3.0"}, {{BuiltinOperator_SKIP_GRAM, 1}, "1.5.0"}, {{BuiltinOperator_SQUEEZE, 1}, "1.6.0"}, + {{BuiltinOperator_SQUEEZE, 2}, kPendingReleaseVersion}, {{BuiltinOperator_SPLIT, 1}, "1.5.0"}, {{BuiltinOperator_SPLIT, 2}, "1.14.0"}, {{BuiltinOperator_SPLIT, 3}, "1.14.0"},