diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index 7a69b999f1f..114d30ebf40 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -132,7 +132,6 @@ static float fully_connected_golden_output[] = { class BaseFullyConnectedOpModel : public SingleOpModel { public: - // TODO(ahentz): test different activation types too. BaseFullyConnectedOpModel( TfLiteRegistration* registration, int units, int batches, const TensorData& input, const TensorData& output = {TensorType_FLOAT32}, @@ -428,6 +427,99 @@ TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) { EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8)); } +TEST(FloatFullyConnectedOpTest, ActivationRelu6) { + // The optimized kernel assumes that the bias is specified. + FloatFullyConnectedOpModel m( + ops::builtin::Register_FULLY_CONNECTED_PIE(), + /*units=*/1, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 2}}, + /*output=*/{TensorType_FLOAT32}, + /*bias_tensor_optional=*/false, + /*ActivationFunctionType*/ ActivationFunctionType_RELU6); + m.SetWeights({ + 2, 4, // u = 0 + }); + + m.SetInput({ + 1, 2, // b = 0 + 2, 1, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(6, 6)); +} + +TEST(FloatFullyConnectedOpTest, ActivationTanh) { + // The optimized kernel assumes that the bias is specified. + FloatFullyConnectedOpModel m( + ops::builtin::Register_FULLY_CONNECTED_PIE(), + /*units=*/1, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 2}}, + /*output=*/{TensorType_FLOAT32}, + /*bias_tensor_optional=*/false, + /*ActivationFunctionType*/ ActivationFunctionType_TANH); + m.SetWeights({ + -2, 1, // u = 0 + }); + + m.SetInput({ + 1, 4, // b = 0 + 2, 1, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({0.964028, -0.995055}))); +} + +TEST(FloatFullyConnectedOpTest, ActivationSign) { + FloatFullyConnectedOpModel m( + ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT(), + /*units=*/1, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 2}}, + /*output=*/{TensorType_FLOAT32}, + /*bias_tensor_optional=*/false, + /*ActivationFunctionType*/ ActivationFunctionType_SIGN_BIT); + m.SetWeights({ + 2, 4, // u = 0 + }); + m.SetBias({1}); + + m.SetInput({ + 1, -2, // b = 0 + -2, 1, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(-5, 1)); +} + +TEST(FloatFullyConnectedOpTest, ActivationN1) { + FloatFullyConnectedOpModel m( + ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT(), + /*units=*/1, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 2}}, + /*output=*/{TensorType_FLOAT32}, + /*bias_tensor_optional=*/false, + /*ActivationFunctionType*/ ActivationFunctionType_RELU_N1_TO_1); + m.SetWeights({ + 2, 4, // u = 0 + }); + m.SetBias({1}); + + m.SetInput({ + 1, -2, // b = 0 + -2, 1, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(-1, 1)); +} + TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2,