Update select to support int8.

PiperOrigin-RevId: 232892451
This commit is contained in:
Shashi Shekhar 2019-02-07 09:52:45 -08:00 committed by TensorFlower Gardener
parent b50b7911ec
commit 81b9a95045
5 changed files with 53 additions and 3 deletions

View File

@ -277,7 +277,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
AddBuiltin(BuiltinOperator_CEIL, Register_CEIL());
AddBuiltin(BuiltinOperator_NEG, Register_NEG());
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
AddBuiltin(BuiltinOperator_SELECT, Register_SELECT(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());

View File

@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8: \
TF_LITE_SELECT(uint8_t, op); \
break; \
case kTfLiteInt8: \
TF_LITE_SELECT(int8_t, op); \
break; \
case kTfLiteInt16: \
TF_LITE_SELECT(int16_t, op); \
break; \

View File

@ -96,6 +96,19 @@ TEST(SelectOpTest, SelectUInt8) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(SelectOpTest, SelectInt8) {
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
TensorType_INT8);
model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
model.PopulateTensor<int8_t>(model.input2(), {1, -2, 3, 4});
model.PopulateTensor<int8_t>(model.input3(), {5, 6, 7, -8});
model.Invoke();
EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray({5, -2, 7, -8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
TEST(SelectOpTest, SelectInt16) {
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
TensorType_INT16);

View File

@ -2030,6 +2030,20 @@ class LessEqual : public SimpleOperator<TensorFlowLessEqualOperator> {
}
};
class Select : public SimpleOperator<SelectOperator> {
public:
explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// Version 2 supports signed int8 input types.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};
namespace {
// Build a vector containing all the known operators.
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
@ -2216,8 +2230,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
ops.push_back(MakeUnique<NotEqual>());
ops.push_back(
MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg));
ops.push_back(MakeUnique<SimpleOperator<SelectOperator>>(
"SELECT", OperatorType::kSelect));
ops.push_back(MakeUnique<Select>());
ops.push_back(
MakeUnique<SimpleOperator<SliceOperator>>("SLICE", OperatorType::kSlice));
ops.push_back(

View File

@ -721,6 +721,25 @@ TEST_F(OperatorTest, VersioningGreaterEqualTest) {
VersioningTest<TensorFlowGreaterEqualOperator>();
}
TEST_F(OperatorTest, VersioningSelectTest) {
SelectOperator select_op;
select_op.inputs = {"input1"};
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
const BaseOperator* op = operator_by_type_map.at(select_op.type).get();
Model uint8_model;
Array& uint8_array = uint8_model.GetOrCreateArray(select_op.inputs[0]);
uint8_array.data_type = ArrayDataType::kUint8;
OperatorSignature uint8_signature = {.model = &uint8_model, .op = &select_op};
EXPECT_EQ(op->GetVersion(uint8_signature), 1);
Model int8_model;
Array& int8_array = int8_model.GetOrCreateArray(select_op.inputs[0]);
int8_array.data_type = ArrayDataType::kInt8;
OperatorSignature int8_signature = {.model = &int8_model, .op = &select_op};
EXPECT_EQ(op->GetVersion(int8_signature), 2);
}
} // namespace
} // namespace tflite