Support string in TFLite Squeeze kernel

PiperOrigin-RevId: 340382663
Change-Id: I4ff462f7a66097aaac8a0bf2182c17ce4020b4f9
This commit is contained in:
Thai Nguyen 2020-11-02 22:35:21 -08:00 committed by TensorFlower Gardener
parent cfd834264d
commit ca16be74dc
6 changed files with 64 additions and 1 deletions

View File

@ -152,7 +152,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE(),
/* min_version = */ 1,
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
/* min_version = */ 1,
/* max_version = */ 4);

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -78,6 +79,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
SqueezeContext op_context(context, node);
if (op_context.input->type == kTfLiteString) {
const int input_flat_size = GetTensorShape(op_context.input).FlatSize();
const int output_flat_size = GetTensorShape(op_context.output).FlatSize();
TF_LITE_ENSURE_EQ(context, input_flat_size, output_flat_size);
SequentialTensorWriter<string> writer(op_context.input, op_context.output);
for (int i = 0; i < input_flat_size; i++) {
writer.Write(i);
}
return kTfLiteOk;
}
TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes);
memcpy(op_context.output->data.raw, op_context.input->data.raw,
op_context.input->bytes);

View File

@ -56,7 +56,14 @@ class SqueezeOpModel : public BaseSqueezeOpModel {
void SetInput(std::initializer_list<T> data) { PopulateTensor(input_, data); }
void SetStringInput(std::initializer_list<string> data) {
PopulateStringTensor(input_, data);
}
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<string> GetStringOutput() {
return ExtractVector<string>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
};
@ -122,5 +129,36 @@ TYPED_TEST(SqueezeOpTest, SqueezeAllDims) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
}
TEST(SqueezeOpTest, SqueezeAllString) {
std::initializer_list<std::string> data = {"a", "b"};
SqueezeOpModel<std::string> m({GetTensorType<std::string>(), {1, 2, 1}},
{GetTensorType<std::string>(), {2}}, {});
m.SetStringInput(data);
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"}));
}
TEST(SqueezeOpTest, SqueezeNegativeAxisString) {
std::initializer_list<std::string> data = {"a", "b"};
SqueezeOpModel<std::string> m({GetTensorType<std::string>(), {1, 2, 1}},
{GetTensorType<std::string>(), {24}}, {-1});
m.SetStringInput(data);
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"}));
}
TYPED_TEST(SqueezeOpTest, SqueezeAllDimsString) {
std::initializer_list<std::string> data = {"a"};
SqueezeOpModel<std::string> m(
{GetTensorType<std::string>(), {1, 1, 1, 1, 1, 1, 1}},
{GetTensorType<std::string>(), {1}}, {});
m.SetStringInput(data);
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a"}));
}
} // namespace
} // namespace tflite

View File

@ -65,6 +65,11 @@ def make_squeeze_tests(options):
"input_shape": [[1, 1, 5, 10], [1, 5, 1, 10], [5, 1, 10]],
"axis": [[0], [1], [3, 0], [-2, 0, 3, 2]],
"fully_quantize": [True],
}, {
"dtype": [tf.string],
"input_shape": [[1, 1, 5, 10], [1, 5, 1, 10]],
"axis": [[0], []],
"fully_quantize": [False],
}]
def build_graph(parameters):

View File

@ -447,7 +447,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
if (op_sig.input_types.at(0) == TensorType_STRING) {
return 2;
}
return 1;
case BuiltinOperator_SQUEEZE:
if (op_sig.input_types.at(0) == TensorType_STRING) {
return 2;
}
return 1;
case BuiltinOperator_SPACE_TO_BATCH_ND:

View File

@ -201,6 +201,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_RNN, 3}, "2.3.0"},
{{BuiltinOperator_SKIP_GRAM, 1}, "1.5.0"},
{{BuiltinOperator_SQUEEZE, 1}, "1.6.0"},
{{BuiltinOperator_SQUEEZE, 2}, kPendingReleaseVersion},
{{BuiltinOperator_SPLIT, 1}, "1.5.0"},
{{BuiltinOperator_SPLIT, 2}, "1.14.0"},
{{BuiltinOperator_SPLIT, 3}, "1.14.0"},