diff --git a/tensorflow/core/kernels/eigen_attention.h b/tensorflow/core/kernels/eigen_attention.h index f4c42372b18..887b9b72218 100644 --- a/tensorflow/core/kernels/eigen_attention.h +++ b/tensorflow/core/kernels/eigen_attention.h @@ -81,21 +81,26 @@ struct GlimpseExtractionOp { for (Index i = 0; i < batch_size; ++i) { float x = offsets_[i].first, y = offsets_[i].second; - // Un-normalize coordinates back to pixel space if normalized. if (normalized_) { + // Un-normalize coordinates back to pixel space if normalized. x *= input_width; y *= input_height; + if (centered_) { + // Un-center if coordinates are centered on the image center. + x /= 2.0f; + y /= 2.0f; + x += input_width / 2.0f; + y += input_height / 2.0f; + // Remove half of the glimpse window. + x -= width_ / 2.0f; + y -= height_ / 2.0f; + } + } else { + if (centered_) { + x += input_width / 2.0f; + y += input_height / 2.0f; + } } - // Un-center if coordinates are centered on the image center. - if (centered_) { - x /= 2.0f; - y /= 2.0f; - x += input_width / 2.0f; - y += input_height / 2.0f; - } - // Remove half of the glimpse window. - x -= width_ / 2.0f; - y -= height_ / 2.0f; const Index offset_x = (Index) x; const Index offset_y = (Index) y; diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py index f9c1727309e..9e8a4f17068 100644 --- a/tensorflow/python/kernel_tests/attention_ops_test.py +++ b/tensorflow/python/kernel_tests/attention_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import test @@ -196,6 +197,18 @@ class ExtractGlimpseTest(test.TestCase): expected_rows=[None, None, None, 1, 2, 3, 4], expected_cols=[56, 57, 58, 59, 60]) + def testGlimpseNonNormalizedNonCentered(self): + img = constant_op.constant(np.arange(25).reshape((1, 5, 5, 1)), + dtype=dtypes.float32) + with self.test_session(): + result1 = image_ops.extract_glimpse(img, [3, 3], [[0, 0]], + centered=False, normalized=False) + result2 = image_ops.extract_glimpse(img, [3, 3], [[1, 0]], + centered=False, normalized=False) + self.assertAllEqual(np.asarray([[0, 1, 2], [5, 6, 7], [10, 11, 12]]), + result1.eval()[0, :, :, 0]) + self.assertAllEqual(np.asarray([[5, 6, 7], [10, 11, 12], [15, 16, 17]]), + result2.eval()[0, :, :, 0]) if __name__ == '__main__': test.main()