Merge pull request #26737 from amitsrivastava78:fully
PiperOrigin-RevId: 238734455
This commit is contained in:
commit
e1172009b5
@ -132,7 +132,6 @@ static float fully_connected_golden_output[] = {
|
|||||||
|
|
||||||
class BaseFullyConnectedOpModel : public SingleOpModel {
|
class BaseFullyConnectedOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
// TODO(ahentz): test different activation types too.
|
|
||||||
BaseFullyConnectedOpModel(
|
BaseFullyConnectedOpModel(
|
||||||
TfLiteRegistration* registration, int units, int batches,
|
TfLiteRegistration* registration, int units, int batches,
|
||||||
const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
|
const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
|
||||||
@ -428,6 +427,99 @@ TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) {
|
|||||||
EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8));
|
EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(FloatFullyConnectedOpTest, ActivationRelu6) {
|
||||||
|
// The optimized kernel assumes that the bias is specified.
|
||||||
|
FloatFullyConnectedOpModel m(
|
||||||
|
ops::builtin::Register_FULLY_CONNECTED_PIE(),
|
||||||
|
/*units=*/1, /*batches=*/2,
|
||||||
|
/*input=*/{TensorType_FLOAT32, {2, 2}},
|
||||||
|
/*output=*/{TensorType_FLOAT32},
|
||||||
|
/*bias_tensor_optional=*/false,
|
||||||
|
/*ActivationFunctionType*/ ActivationFunctionType_RELU6);
|
||||||
|
m.SetWeights({
|
||||||
|
2, 4, // u = 0
|
||||||
|
});
|
||||||
|
|
||||||
|
m.SetInput({
|
||||||
|
1, 2, // b = 0
|
||||||
|
2, 1, // b = 1
|
||||||
|
});
|
||||||
|
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAre(6, 6));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FloatFullyConnectedOpTest, ActivationTanh) {
|
||||||
|
// The optimized kernel assumes that the bias is specified.
|
||||||
|
FloatFullyConnectedOpModel m(
|
||||||
|
ops::builtin::Register_FULLY_CONNECTED_PIE(),
|
||||||
|
/*units=*/1, /*batches=*/2,
|
||||||
|
/*input=*/{TensorType_FLOAT32, {2, 2}},
|
||||||
|
/*output=*/{TensorType_FLOAT32},
|
||||||
|
/*bias_tensor_optional=*/false,
|
||||||
|
/*ActivationFunctionType*/ ActivationFunctionType_TANH);
|
||||||
|
m.SetWeights({
|
||||||
|
-2, 1, // u = 0
|
||||||
|
});
|
||||||
|
|
||||||
|
m.SetInput({
|
||||||
|
1, 4, // b = 0
|
||||||
|
2, 1, // b = 1
|
||||||
|
});
|
||||||
|
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetOutput(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({0.964028, -0.995055})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FloatFullyConnectedOpTest, ActivationSign) {
|
||||||
|
FloatFullyConnectedOpModel m(
|
||||||
|
ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT(),
|
||||||
|
/*units=*/1, /*batches=*/2,
|
||||||
|
/*input=*/{TensorType_FLOAT32, {2, 2}},
|
||||||
|
/*output=*/{TensorType_FLOAT32},
|
||||||
|
/*bias_tensor_optional=*/false,
|
||||||
|
/*ActivationFunctionType*/ ActivationFunctionType_SIGN_BIT);
|
||||||
|
m.SetWeights({
|
||||||
|
2, 4, // u = 0
|
||||||
|
});
|
||||||
|
m.SetBias({1});
|
||||||
|
|
||||||
|
m.SetInput({
|
||||||
|
1, -2, // b = 0
|
||||||
|
-2, 1, // b = 1
|
||||||
|
});
|
||||||
|
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAre(-5, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FloatFullyConnectedOpTest, ActivationN1) {
|
||||||
|
FloatFullyConnectedOpModel m(
|
||||||
|
ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT(),
|
||||||
|
/*units=*/1, /*batches=*/2,
|
||||||
|
/*input=*/{TensorType_FLOAT32, {2, 2}},
|
||||||
|
/*output=*/{TensorType_FLOAT32},
|
||||||
|
/*bias_tensor_optional=*/false,
|
||||||
|
/*ActivationFunctionType*/ ActivationFunctionType_RELU_N1_TO_1);
|
||||||
|
m.SetWeights({
|
||||||
|
2, 4, // u = 0
|
||||||
|
});
|
||||||
|
m.SetBias({1});
|
||||||
|
|
||||||
|
m.SetInput({
|
||||||
|
1, -2, // b = 0
|
||||||
|
-2, 1, // b = 1
|
||||||
|
});
|
||||||
|
|
||||||
|
m.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAre(-1, 1));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) {
|
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) {
|
||||||
QuantizedFullyConnectedOpModel m(
|
QuantizedFullyConnectedOpModel m(
|
||||||
GetRegistration(), /*units=*/3, /*batches*/ 2,
|
GetRegistration(), /*units=*/3, /*batches*/ 2,
|
||||||
|
Loading…
Reference in New Issue
Block a user