Support string in TFLite Squeeze kernel
PiperOrigin-RevId: 340382663 Change-Id: I4ff462f7a66097aaac8a0bf2182c17ce4020b4f9
This commit is contained in:
parent
cfd834264d
commit
ca16be74dc
@ -152,7 +152,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
|
||||
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 4);
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
@ -78,6 +79,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
SqueezeContext op_context(context, node);
|
||||
if (op_context.input->type == kTfLiteString) {
|
||||
const int input_flat_size = GetTensorShape(op_context.input).FlatSize();
|
||||
const int output_flat_size = GetTensorShape(op_context.output).FlatSize();
|
||||
TF_LITE_ENSURE_EQ(context, input_flat_size, output_flat_size);
|
||||
SequentialTensorWriter<string> writer(op_context.input, op_context.output);
|
||||
for (int i = 0; i < input_flat_size; i++) {
|
||||
writer.Write(i);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes);
|
||||
memcpy(op_context.output->data.raw, op_context.input->data.raw,
|
||||
op_context.input->bytes);
|
||||
|
@ -56,7 +56,14 @@ class SqueezeOpModel : public BaseSqueezeOpModel {
|
||||
|
||||
void SetInput(std::initializer_list<T> data) { PopulateTensor(input_, data); }
|
||||
|
||||
void SetStringInput(std::initializer_list<string> data) {
|
||||
PopulateStringTensor(input_, data);
|
||||
}
|
||||
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||
std::vector<string> GetStringOutput() {
|
||||
return ExtractVector<string>(output_);
|
||||
}
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
};
|
||||
|
||||
@ -122,5 +129,36 @@ TYPED_TEST(SqueezeOpTest, SqueezeAllDims) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
|
||||
}
|
||||
|
||||
TEST(SqueezeOpTest, SqueezeAllString) {
|
||||
std::initializer_list<std::string> data = {"a", "b"};
|
||||
SqueezeOpModel<std::string> m({GetTensorType<std::string>(), {1, 2, 1}},
|
||||
{GetTensorType<std::string>(), {2}}, {});
|
||||
m.SetStringInput(data);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
|
||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"}));
|
||||
}
|
||||
|
||||
TEST(SqueezeOpTest, SqueezeNegativeAxisString) {
|
||||
std::initializer_list<std::string> data = {"a", "b"};
|
||||
SqueezeOpModel<std::string> m({GetTensorType<std::string>(), {1, 2, 1}},
|
||||
{GetTensorType<std::string>(), {24}}, {-1});
|
||||
m.SetStringInput(data);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
|
||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a", "b"}));
|
||||
}
|
||||
|
||||
TYPED_TEST(SqueezeOpTest, SqueezeAllDimsString) {
|
||||
std::initializer_list<std::string> data = {"a"};
|
||||
SqueezeOpModel<std::string> m(
|
||||
{GetTensorType<std::string>(), {1, 1, 1, 1, 1, 1, 1}},
|
||||
{GetTensorType<std::string>(), {1}}, {});
|
||||
m.SetStringInput(data);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
|
||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"a"}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -65,6 +65,11 @@ def make_squeeze_tests(options):
|
||||
"input_shape": [[1, 1, 5, 10], [1, 5, 1, 10], [5, 1, 10]],
|
||||
"axis": [[0], [1], [3, 0], [-2, 0, 3, 2]],
|
||||
"fully_quantize": [True],
|
||||
}, {
|
||||
"dtype": [tf.string],
|
||||
"input_shape": [[1, 1, 5, 10], [1, 5, 1, 10]],
|
||||
"axis": [[0], []],
|
||||
"fully_quantize": [False],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
|
@ -447,7 +447,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
if (op_sig.input_types.at(0) == TensorType_STRING) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_SQUEEZE:
|
||||
if (op_sig.input_types.at(0) == TensorType_STRING) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_SPACE_TO_BATCH_ND:
|
||||
|
@ -201,6 +201,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
||||
{{BuiltinOperator_RNN, 3}, "2.3.0"},
|
||||
{{BuiltinOperator_SKIP_GRAM, 1}, "1.5.0"},
|
||||
{{BuiltinOperator_SQUEEZE, 1}, "1.6.0"},
|
||||
{{BuiltinOperator_SQUEEZE, 2}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_SPLIT, 1}, "1.5.0"},
|
||||
{{BuiltinOperator_SPLIT, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_SPLIT, 3}, "1.14.0"},
|
||||
|
Loading…
Reference in New Issue
Block a user