Update select to support int8.
PiperOrigin-RevId: 232892451
This commit is contained in:
parent
b50b7911ec
commit
81b9a95045
@ -277,7 +277,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
|
||||
AddBuiltin(BuiltinOperator_CEIL, Register_CEIL());
|
||||
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
|
||||
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
|
||||
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
|
||||
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
||||
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
|
||||
|
@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteUInt8: \
|
||||
TF_LITE_SELECT(uint8_t, op); \
|
||||
break; \
|
||||
case kTfLiteInt8: \
|
||||
TF_LITE_SELECT(int8_t, op); \
|
||||
break; \
|
||||
case kTfLiteInt16: \
|
||||
TF_LITE_SELECT(int16_t, op); \
|
||||
break; \
|
||||
|
@ -96,6 +96,19 @@ TEST(SelectOpTest, SelectUInt8) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(SelectOpTest, SelectInt8) {
|
||||
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
|
||||
TensorType_INT8);
|
||||
|
||||
model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
|
||||
model.PopulateTensor<int8_t>(model.input2(), {1, -2, 3, 4});
|
||||
model.PopulateTensor<int8_t>(model.input3(), {5, 6, 7, -8});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray({5, -2, 7, -8}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||
}
|
||||
|
||||
TEST(SelectOpTest, SelectInt16) {
|
||||
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
|
||||
TensorType_INT16);
|
||||
|
@ -2030,6 +2030,20 @@ class LessEqual : public SimpleOperator<TensorFlowLessEqualOperator> {
|
||||
}
|
||||
};
|
||||
|
||||
class Select : public SimpleOperator<SelectOperator> {
|
||||
public:
|
||||
explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {}
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
// Version 2 supports signed int8 input types.
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
// Build a vector containing all the known operators.
|
||||
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
@ -2216,8 +2230,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
ops.push_back(MakeUnique<NotEqual>());
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg));
|
||||
ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
|
||||
"SELECT", OperatorType::kSelect));
|
||||
ops.push_back(MakeUnique<Select>());
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<SliceOperator>>("SLICE", OperatorType::kSlice));
|
||||
ops.push_back(
|
||||
|
@ -721,6 +721,25 @@ TEST_F(OperatorTest, VersioningGreaterEqualTest) {
|
||||
VersioningTest<TensorFlowGreaterEqualOperator>();
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningSelectTest) {
|
||||
SelectOperator select_op;
|
||||
select_op.inputs = {"input1"};
|
||||
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||
const BaseOperator* op = operator_by_type_map.at(select_op.type).get();
|
||||
|
||||
Model uint8_model;
|
||||
Array& uint8_array = uint8_model.GetOrCreateArray(select_op.inputs[0]);
|
||||
uint8_array.data_type = ArrayDataType::kUint8;
|
||||
OperatorSignature uint8_signature = {.model = &uint8_model, .op = &select_op};
|
||||
EXPECT_EQ(op->GetVersion(uint8_signature), 1);
|
||||
|
||||
Model int8_model;
|
||||
Array& int8_array = int8_model.GetOrCreateArray(select_op.inputs[0]);
|
||||
int8_array.data_type = ArrayDataType::kInt8;
|
||||
OperatorSignature int8_signature = {.model = &int8_model, .op = &select_op};
|
||||
EXPECT_EQ(op->GetVersion(int8_signature), 2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user