Update int8 logistic test case.
PiperOrigin-RevId: 234895201
This commit is contained in:
parent
8b228c17d8
commit
2127948229
@ -602,26 +602,6 @@ TEST_F(OperatorTest, BuiltinLeakyRelu) {
|
||||
EXPECT_EQ(op.alpha, output_toco_op->alpha);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningLogisticTest) {
|
||||
LogisticOperator logistic_op;
|
||||
logistic_op.inputs = {"input1"};
|
||||
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||
const BaseOperator* op = operator_by_type_map.at(logistic_op.type).get();
|
||||
|
||||
Model uint8_model;
|
||||
Array& uint8_array = uint8_model.GetOrCreateArray(logistic_op.inputs[0]);
|
||||
uint8_array.data_type = ArrayDataType::kUint8;
|
||||
OperatorSignature uint8_signature = {.model = &uint8_model,
|
||||
.op = &logistic_op};
|
||||
EXPECT_EQ(op->GetVersion(uint8_signature), 1);
|
||||
|
||||
Model int8_model;
|
||||
Array& int8_array = int8_model.GetOrCreateArray(logistic_op.inputs[0]);
|
||||
int8_array.data_type = ArrayDataType::kInt8;
|
||||
OperatorSignature int8_signature = {.model = &int8_model, .op = &logistic_op};
|
||||
EXPECT_EQ(op->GetVersion(int8_signature), 2);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinSquaredDifference) {
|
||||
SquaredDifferenceOperator op;
|
||||
auto output_toco_op = SerializeAndDeserialize(
|
||||
@ -829,6 +809,10 @@ TEST_F(OperatorTest, VersioningSliceTest) {
|
||||
SimpleVersioningTest<SliceOperator>();
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningLogisticTest) {
|
||||
SimpleVersioningTest<LogisticOperator>();
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); }
|
||||
|
||||
TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); }
|
||||
|
Loading…
Reference in New Issue
Block a user