Fix issue in tf.image.extract_glimpse (#12829)
* Fix issue in tf.image.extract_glimpse This fix tries to fix the issue raised in 2134 where `tf.image.extract_glimpse` does not work as expected when `centered=False` and `normalized=False` This fix fixes 2134. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for tf.image.extract_glimpse Add test cases for tf.image.extract_glimpse with centered=False and normalized=False Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
74cfc64734
commit
3303168ed6
@ -81,21 +81,26 @@ struct GlimpseExtractionOp {
|
||||
for (Index i = 0; i < batch_size; ++i) {
|
||||
float x = offsets_[i].first, y = offsets_[i].second;
|
||||
|
||||
// Un-normalize coordinates back to pixel space if normalized.
|
||||
if (normalized_) {
|
||||
// Un-normalize coordinates back to pixel space if normalized.
|
||||
x *= input_width;
|
||||
y *= input_height;
|
||||
}
|
||||
// Un-center if coordinates are centered on the image center.
|
||||
if (centered_) {
|
||||
// Un-center if coordinates are centered on the image center.
|
||||
x /= 2.0f;
|
||||
y /= 2.0f;
|
||||
x += input_width / 2.0f;
|
||||
y += input_height / 2.0f;
|
||||
}
|
||||
// Remove half of the glimpse window.
|
||||
x -= width_ / 2.0f;
|
||||
y -= height_ / 2.0f;
|
||||
}
|
||||
} else {
|
||||
if (centered_) {
|
||||
x += input_width / 2.0f;
|
||||
y += input_height / 2.0f;
|
||||
}
|
||||
}
|
||||
|
||||
const Index offset_x = (Index) x;
|
||||
const Index offset_y = (Index) y;
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -196,6 +197,18 @@ class ExtractGlimpseTest(test.TestCase):
|
||||
expected_rows=[None, None, None, 1, 2, 3, 4],
|
||||
expected_cols=[56, 57, 58, 59, 60])
|
||||
|
||||
def testGlimpseNonNormalizedNonCentered(self):
|
||||
img = constant_op.constant(np.arange(25).reshape((1, 5, 5, 1)),
|
||||
dtype=dtypes.float32)
|
||||
with self.test_session():
|
||||
result1 = image_ops.extract_glimpse(img, [3, 3], [[0, 0]],
|
||||
centered=False, normalized=False)
|
||||
result2 = image_ops.extract_glimpse(img, [3, 3], [[1, 0]],
|
||||
centered=False, normalized=False)
|
||||
self.assertAllEqual(np.asarray([[0, 1, 2], [5, 6, 7], [10, 11, 12]]),
|
||||
result1.eval()[0, :, :, 0])
|
||||
self.assertAllEqual(np.asarray([[5, 6, 7], [10, 11, 12], [15, 16, 17]]),
|
||||
result2.eval()[0, :, :, 0])
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user