From fc8c61ccfae0bea7b8675342875b0e5138418c01 Mon Sep 17 00:00:00 2001 From: Yong Tang <yong.tang.github@outlook.com> Date: Tue, 13 Mar 2018 20:02:36 +0000 Subject: [PATCH] Add additional test cases Signed-off-by: Yong Tang <yong.tang.github@outlook.com> --- .../python/kernel_tests/attention_ops_test.py | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py index a90362645e5..2db74157fee 100644 --- a/tensorflow/python/kernel_tests/attention_ops_test.py +++ b/tensorflow/python/kernel_tests/attention_ops_test.py @@ -196,15 +196,47 @@ class ExtractGlimpseTest(test.TestCase): expected_rows=[None, None, None, 1, 2, 3, 4], expected_cols=[56, 57, 58, 59, 60]) - def testGlimpseNoOverlapZero(self): + def testGlimpseNoiseZero(self): + # Image: + # [ 0. 1. 2. 3. 4.] + # [ 5. 6. 7. 8. 9.] + # [ 10. 11. 12. 13. 14.] + # [ 15. 16. 17. 18. 19.] + # [ 20. 21. 22. 23. 24.] img = constant_op.constant(np.arange(25).reshape((1, 5, 5, 1)), dtype=dtypes.float32) with self.test_session(): - result = image_ops.extract_glimpse(img, [3, 3], [[-2, 2]], - centered=False, normalized=False, - uniform_noise=False, noise="zero") - self.assertAllEqual(np.asarray([[0, 0, 0], [0, 0, 0], [0, 0, 0]]), - result.eval()[0, :, :, 0]) + # Result 1: + # [ 0. 0. 0.] + # [ 0. 0. 0.] + # [ 0. 0. 0.] + result1 = image_ops.extract_glimpse(img, [3, 3], [[-2, 2]], + centered=False, normalized=False, + uniform_noise=False, noise="zero") + self.assertAllEqual(np.asarray([[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]), + result1.eval()[0, :, :, 0]) + + # Result 2: + # [ 0. 0. 0. 0. 0. 0. 0.] + # [ 0. 0. 1. 2. 3. 4. 0.] + # [ 0. 5. 6. 7. 8. 9. 0.] + # [ 0. 10. 11. 12. 13. 14. 0.] + # [ 0. 15. 16. 17. 18. 19. 0.] + # [ 0. 20. 21. 22. 23. 24. 0.] + # [ 0. 0. 0. 0. 0. 0. 0.]] + result2 = image_ops.extract_glimpse(img, [7, 7], [[0, 0]], + normalized=False, + uniform_noise=False, noise="zero") + self.assertAllEqual(np.asarray([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 3, 4, 0], + [0, 5, 6, 7, 8, 9, 0], + [0, 10, 11, 12, 13, 14, 0], + [0, 15, 16, 17, 18, 19, 0], + [0, 20, 21, 22, 23, 24, 0], + [0, 0, 0, 0, 0, 0, 0]]), + result2.eval()[0, :, :, 0]) if __name__ == '__main__': test.main()