Merge pull request #41278 from wwwind:16x8_gather
PiperOrigin-RevId: 323451276
This commit is contained in:
commit
f4d7bffce8
@ -214,6 +214,7 @@ TypesGatherOpTest/Float32Int32,29
|
||||
TypesGatherOpTest/Int32Int32,29
|
||||
TypesGatherOpTest/Uint8Int32,29
|
||||
TypesGatherOpTest/Int8Int32,29
|
||||
-TypesGatherOpTest/.*Int16.*
|
||||
|
||||
# hashtable_lookup_test
|
||||
# All test excepted the string one should be accelerated
|
||||
|
@ -61,6 +61,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteInt16:
|
||||
case kTfLiteInt64:
|
||||
case kTfLiteInt32:
|
||||
case kTfLiteBool:
|
||||
@ -143,6 +144,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Gather<uint8_t, int32_t>(*params, input, positions, output);
|
||||
case kTfLiteInt8:
|
||||
return Gather<int8_t, int32_t>(*params, input, positions, output);
|
||||
case kTfLiteInt16:
|
||||
return Gather<int16_t, int32_t>(*params, input, positions, output);
|
||||
case kTfLiteInt32:
|
||||
return Gather<int32_t, int32_t>(*params, input, positions, output);
|
||||
case kTfLiteInt64:
|
||||
@ -165,6 +168,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Gather<uint8_t, int64_t>(*params, input, positions, output);
|
||||
case kTfLiteInt8:
|
||||
return Gather<int8_t, int64_t>(*params, input, positions, output);
|
||||
case kTfLiteInt16:
|
||||
return Gather<int16_t, int64_t>(*params, input, positions, output);
|
||||
case kTfLiteInt32:
|
||||
return Gather<int32_t, int64_t>(*params, input, positions, output);
|
||||
case kTfLiteInt64:
|
||||
|
@ -272,6 +272,26 @@ TEST(TypesGatherOpTest, Int8Int64) {
|
||||
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({14, 15, -13, -120}));
|
||||
}
|
||||
|
||||
TEST(TypesGatherOpTest, Int16Int32) {
|
||||
GatherOpModel m({TensorType_INT16, {2, 2}}, {TensorType_INT32, {2}});
|
||||
m.SetInput<int16_t>({-13, -32000, 0, 32500});
|
||||
m.SetPositions<int32_t>({1, 0});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput<int16_t>(),
|
||||
ElementsAreArray({0, 32500, -13, -32000}));
|
||||
}
|
||||
|
||||
TEST(TypesGatherOpTest, Int16Int64) {
|
||||
GatherOpModel m({TensorType_INT16, {2, 2}}, {TensorType_INT64, {2}});
|
||||
m.SetInput<int16_t>({-13, -32000, 0, 32500});
|
||||
m.SetPositions<int64_t>({1LL, 0LL});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput<int16_t>(),
|
||||
ElementsAreArray({0, 32500, -13, -32000}));
|
||||
}
|
||||
|
||||
TEST(TypesGatherOpTest, Int64Int32) {
|
||||
GatherOpModel m({TensorType_INT64, {2, 2}}, {TensorType_INT32, {2}});
|
||||
m.SetInput<int64_t>({-(1LL << 34), 134LL, 14LL, 15LL});
|
||||
|
@ -131,7 +131,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE());
|
||||
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
/* max_version = */ 4);
|
||||
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 4);
|
||||
|
@ -80,6 +80,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
|
||||
{{OperatorType::kGather, 1}, "1.6.0"},
|
||||
{{OperatorType::kGather, 2}, "1.14.0"},
|
||||
{{OperatorType::kGather, 3}, "1.15.0"},
|
||||
{{OperatorType::kGather, 4}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kGatherNd, 1}, "1.14.0"},
|
||||
{{OperatorType::kGatherNd, 2}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kSvdf, 1}, "1.5.0"},
|
||||
|
@ -190,8 +190,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
|
||||
property.inputs = {{0, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.quantize_input_as_activations = true;
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_HARD_SWISH: {
|
||||
property.inputs = {{0, {}}};
|
||||
|
@ -176,6 +176,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_GATHER:
|
||||
if (op_sig.input_types.at(0) == TensorType_INT16) {
|
||||
return 4;
|
||||
}
|
||||
// If the op takes bool input, it is version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
||||
return 3;
|
||||
|
@ -109,6 +109,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
||||
{{BuiltinOperator_GATHER, 1}, "1.6.0"},
|
||||
{{BuiltinOperator_GATHER, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_GATHER, 3}, "1.15.0"},
|
||||
{{BuiltinOperator_GATHER, 4}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_GATHER_ND, 1}, "1.14.0"},
|
||||
{{BuiltinOperator_GATHER_ND, 2}, "2.3.0"},
|
||||
{{BuiltinOperator_HASHTABLE_LOOKUP, 1}, "1.5.0"},
|
||||
|
Loading…
Reference in New Issue
Block a user