Fix issue in tf.image.extract_glimpse (#12829)

* Fix issue in tf.image.extract_glimpse

This fix tries to fix the issue raised in 2134 where
`tf.image.extract_glimpse` does not work as expected
when `centered=False` and `normalized=False`

This fix fixes 2134.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add test cases for tf.image.extract_glimpse

Add test cases for tf.image.extract_glimpse with
centered=False and normalized=False

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2017-09-18 06:18:53 -07:00 committed by Benoit Steiner
parent 74cfc64734
commit 3303168ed6
2 changed files with 29 additions and 11 deletions

View File

@ -81,21 +81,26 @@ struct GlimpseExtractionOp {
for (Index i = 0; i < batch_size; ++i) {
float x = offsets_[i].first, y = offsets_[i].second;
// Un-normalize coordinates back to pixel space if normalized.
if (normalized_) {
// Un-normalize coordinates back to pixel space if normalized.
x *= input_width;
y *= input_height;
if (centered_) {
// Un-center if coordinates are centered on the image center.
x /= 2.0f;
y /= 2.0f;
x += input_width / 2.0f;
y += input_height / 2.0f;
// Remove half of the glimpse window.
x -= width_ / 2.0f;
y -= height_ / 2.0f;
}
} else {
if (centered_) {
x += input_width / 2.0f;
y += input_height / 2.0f;
}
}
// Un-center if coordinates are centered on the image center.
if (centered_) {
x /= 2.0f;
y /= 2.0f;
x += input_width / 2.0f;
y += input_height / 2.0f;
}
// Remove half of the glimpse window.
x -= width_ / 2.0f;
y -= height_ / 2.0f;
const Index offset_x = (Index) x;
const Index offset_y = (Index) y;

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.platform import test
@ -196,6 +197,18 @@ class ExtractGlimpseTest(test.TestCase):
expected_rows=[None, None, None, 1, 2, 3, 4],
expected_cols=[56, 57, 58, 59, 60])
def testGlimpseNonNormalizedNonCentered(self):
img = constant_op.constant(np.arange(25).reshape((1, 5, 5, 1)),
dtype=dtypes.float32)
with self.test_session():
result1 = image_ops.extract_glimpse(img, [3, 3], [[0, 0]],
centered=False, normalized=False)
result2 = image_ops.extract_glimpse(img, [3, 3], [[1, 0]],
centered=False, normalized=False)
self.assertAllEqual(np.asarray([[0, 1, 2], [5, 6, 7], [10, 11, 12]]),
result1.eval()[0, :, :, 0])
self.assertAllEqual(np.asarray([[5, 6, 7], [10, 11, 12], [15, 16, 17]]),
result2.eval()[0, :, :, 0])
if __name__ == '__main__':
test.main()