Merge pull request #41278 from wwwind:16x8_gather

PiperOrigin-RevId: 323451276
This commit is contained in:
TensorFlower Gardener 2020-07-27 15:23:03 -07:00
commit f4d7bffce8
8 changed files with 33 additions and 2 deletions

View File

@ -214,6 +214,7 @@ TypesGatherOpTest/Float32Int32,29
TypesGatherOpTest/Int32Int32,29
TypesGatherOpTest/Uint8Int32,29
TypesGatherOpTest/Int8Int32,29
-TypesGatherOpTest/.*Int16.*
# hashtable_lookup_test
# All test excepted the string one should be accelerated

View File

@ -61,6 +61,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteFloat32:
case kTfLiteUInt8:
case kTfLiteInt8:
case kTfLiteInt16:
case kTfLiteInt64:
case kTfLiteInt32:
case kTfLiteBool:
@ -143,6 +144,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return Gather<uint8_t, int32_t>(*params, input, positions, output);
case kTfLiteInt8:
return Gather<int8_t, int32_t>(*params, input, positions, output);
case kTfLiteInt16:
return Gather<int16_t, int32_t>(*params, input, positions, output);
case kTfLiteInt32:
return Gather<int32_t, int32_t>(*params, input, positions, output);
case kTfLiteInt64:
@ -165,6 +168,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return Gather<uint8_t, int64_t>(*params, input, positions, output);
case kTfLiteInt8:
return Gather<int8_t, int64_t>(*params, input, positions, output);
case kTfLiteInt16:
return Gather<int16_t, int64_t>(*params, input, positions, output);
case kTfLiteInt32:
return Gather<int32_t, int64_t>(*params, input, positions, output);
case kTfLiteInt64:

View File

@ -272,6 +272,26 @@ TEST(TypesGatherOpTest, Int8Int64) {
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({14, 15, -13, -120}));
}
TEST(TypesGatherOpTest, Int16Int32) {
GatherOpModel m({TensorType_INT16, {2, 2}}, {TensorType_INT32, {2}});
m.SetInput<int16_t>({-13, -32000, 0, 32500});
m.SetPositions<int32_t>({1, 0});
m.Invoke();
EXPECT_THAT(m.GetOutput<int16_t>(),
ElementsAreArray({0, 32500, -13, -32000}));
}
TEST(TypesGatherOpTest, Int16Int64) {
GatherOpModel m({TensorType_INT16, {2, 2}}, {TensorType_INT64, {2}});
m.SetInput<int16_t>({-13, -32000, 0, 32500});
m.SetPositions<int64_t>({1LL, 0LL});
m.Invoke();
EXPECT_THAT(m.GetOutput<int16_t>(),
ElementsAreArray({0, 32500, -13, -32000}));
}
TEST(TypesGatherOpTest, Int64Int32) {
GatherOpModel m({TensorType_INT64, {2, 2}}, {TensorType_INT32, {2}});
m.SetInput<int64_t>({-(1LL << 34), 134LL, 14LL, 15LL});

View File

@ -131,7 +131,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE, Register_DEPTH_TO_SPACE());
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
/* min_version = */ 1,
/* max_version = */ 3);
/* max_version = */ 4);
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
/* min_version = */ 1,
/* max_version = */ 4);

View File

@ -80,6 +80,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kGather, 1}, "1.6.0"},
{{OperatorType::kGather, 2}, "1.14.0"},
{{OperatorType::kGather, 3}, "1.15.0"},
{{OperatorType::kGather, 4}, kPendingReleaseOpVersion},
{{OperatorType::kGatherNd, 1}, "1.14.0"},
{{OperatorType::kGatherNd, 2}, kPendingReleaseOpVersion},
{{OperatorType::kSvdf, 1}, "1.5.0"},

View File

@ -190,8 +190,8 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.quantize_input_as_activations = true;
property.version = 2;
property.quantizable_int16 = false;
break;
case BuiltinOperator_HARD_SWISH: {
property.inputs = {{0, {}}};

View File

@ -176,6 +176,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
case BuiltinOperator_GATHER:
if (op_sig.input_types.at(0) == TensorType_INT16) {
return 4;
}
// If the op takes bool input, it is version 3.
if (op_sig.input_types.at(0) == TensorType_BOOL) {
return 3;

View File

@ -109,6 +109,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_GATHER, 1}, "1.6.0"},
{{BuiltinOperator_GATHER, 2}, "1.14.0"},
{{BuiltinOperator_GATHER, 3}, "1.15.0"},
{{BuiltinOperator_GATHER, 4}, kPendingReleaseVersion},
{{BuiltinOperator_GATHER_ND, 1}, "1.14.0"},
{{BuiltinOperator_GATHER_ND, 2}, "2.3.0"},
{{BuiltinOperator_HASHTABLE_LOOKUP, 1}, "1.5.0"},