Update select to support int8.
PiperOrigin-RevId: 232892451
This commit is contained in:
parent
b50b7911ec
commit
81b9a95045
@ -277,7 +277,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
|
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
|
||||||
AddBuiltin(BuiltinOperator_CEIL, Register_CEIL());
|
AddBuiltin(BuiltinOperator_CEIL, Register_CEIL());
|
||||||
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
|
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
|
||||||
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
|
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
|
||||||
|
/* min_version */ 1,
|
||||||
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
|
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
|
||||||
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
||||||
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
|
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
|
||||||
|
@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
case kTfLiteUInt8: \
|
case kTfLiteUInt8: \
|
||||||
TF_LITE_SELECT(uint8_t, op); \
|
TF_LITE_SELECT(uint8_t, op); \
|
||||||
break; \
|
break; \
|
||||||
|
case kTfLiteInt8: \
|
||||||
|
TF_LITE_SELECT(int8_t, op); \
|
||||||
|
break; \
|
||||||
case kTfLiteInt16: \
|
case kTfLiteInt16: \
|
||||||
TF_LITE_SELECT(int16_t, op); \
|
TF_LITE_SELECT(int16_t, op); \
|
||||||
break; \
|
break; \
|
||||||
|
@ -96,6 +96,19 @@ TEST(SelectOpTest, SelectUInt8) {
|
|||||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SelectOpTest, SelectInt8) {
|
||||||
|
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
|
||||||
|
TensorType_INT8);
|
||||||
|
|
||||||
|
model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
|
||||||
|
model.PopulateTensor<int8_t>(model.input2(), {1, -2, 3, 4});
|
||||||
|
model.PopulateTensor<int8_t>(model.input3(), {5, 6, 7, -8});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray({5, -2, 7, -8}));
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(SelectOpTest, SelectInt16) {
|
TEST(SelectOpTest, SelectInt16) {
|
||||||
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
|
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
|
||||||
TensorType_INT16);
|
TensorType_INT16);
|
||||||
|
@ -2030,6 +2030,20 @@ class LessEqual : public SimpleOperator<TensorFlowLessEqualOperator> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Select : public SimpleOperator<SelectOperator> {
|
||||||
|
public:
|
||||||
|
explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {}
|
||||||
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
|
const string& input_name = op_signature.op->inputs[0];
|
||||||
|
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||||
|
// Version 2 supports signed int8 input types.
|
||||||
|
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Build a vector containing all the known operators.
|
// Build a vector containing all the known operators.
|
||||||
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||||
@ -2216,8 +2230,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
|||||||
ops.push_back(MakeUnique<NotEqual>());
|
ops.push_back(MakeUnique<NotEqual>());
|
||||||
ops.push_back(
|
ops.push_back(
|
||||||
MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg));
|
MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg));
|
||||||
ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
|
ops.push_back(MakeUnique<Select>());
|
||||||
"SELECT", OperatorType::kSelect));
|
|
||||||
ops.push_back(
|
ops.push_back(
|
||||||
MakeUnique<SimpleOperator<SliceOperator>>("SLICE", OperatorType::kSlice));
|
MakeUnique<SimpleOperator<SliceOperator>>("SLICE", OperatorType::kSlice));
|
||||||
ops.push_back(
|
ops.push_back(
|
||||||
|
@ -721,6 +721,25 @@ TEST_F(OperatorTest, VersioningGreaterEqualTest) {
|
|||||||
VersioningTest<TensorFlowGreaterEqualOperator>();
|
VersioningTest<TensorFlowGreaterEqualOperator>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, VersioningSelectTest) {
|
||||||
|
SelectOperator select_op;
|
||||||
|
select_op.inputs = {"input1"};
|
||||||
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||||
|
const BaseOperator* op = operator_by_type_map.at(select_op.type).get();
|
||||||
|
|
||||||
|
Model uint8_model;
|
||||||
|
Array& uint8_array = uint8_model.GetOrCreateArray(select_op.inputs[0]);
|
||||||
|
uint8_array.data_type = ArrayDataType::kUint8;
|
||||||
|
OperatorSignature uint8_signature = {.model = &uint8_model, .op = &select_op};
|
||||||
|
EXPECT_EQ(op->GetVersion(uint8_signature), 1);
|
||||||
|
|
||||||
|
Model int8_model;
|
||||||
|
Array& int8_array = int8_model.GetOrCreateArray(select_op.inputs[0]);
|
||||||
|
int8_array.data_type = ArrayDataType::kInt8;
|
||||||
|
OperatorSignature int8_signature = {.model = &int8_model, .op = &select_op};
|
||||||
|
EXPECT_EQ(op->GetVersion(int8_signature), 2);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user