Merge pull request #38549 from yongtang:12829-extract_glimpse

PiperOrigin-RevId: 312715239
Change-Id: I4d40eaeed615fceed8f29d380709dad5b9fed216
This commit is contained in:
TensorFlower Gardener 2020-05-21 12:14:19 -07:00
commit eaa710abca
9 changed files with 273 additions and 36 deletions

View File

@ -1,3 +1,16 @@
# Release 2.3.0
## Breaking Changes
* `tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
the existing C++ kernel `ExtractGlimpse` does not change either, so saved
models will not be impacted.
# Release 2.1.1
## Bug Fixes and Other Changes

View File

@ -0,0 +1,86 @@
op {
graph_op_name: "ExtractGlimpseV2"
visibility: HIDDEN
in_arg {
name: "input"
description: <<END
A 4-D float tensor of shape `[batch_size, height, width, channels]`.
END
}
in_arg {
name: "size"
description: <<END
A 1-D tensor of 2 elements containing the size of the glimpses
to extract. The glimpse height must be specified first, followed
by the glimpse width.
END
}
in_arg {
name: "offsets"
description: <<END
A 2-D float tensor of shape `[batch_size, 2]` containing
the y, x locations of the center of each window.
END
}
out_arg {
name: "glimpse"
description: <<END
A tensor representing the glimpses `[batch_size,
glimpse_height, glimpse_width, channels]`.
END
}
attr {
name: "centered"
description: <<END
indicates if the offset coordinates are centered relative to
the image, in which case the (0, 0) offset is relative to the center
of the input images. If false, the (0,0) offset corresponds to the
upper left corner of the input images.
END
}
attr {
name: "normalized"
description: <<END
indicates if the offset coordinates are normalized.
END
}
attr {
name: "uniform_noise"
description: <<END
indicates if the noise should be generated using a
uniform distribution or a Gaussian distribution.
END
}
attr {
name: "noise"
description: <<END
indicates if the noise should be `uniform`, `gaussian`, or
`zero`. The default is `uniform` which means the noise type
will be decided by `uniform_noise`.
END
}
summary: "Extracts a glimpse from the input tensor."
description: <<END
Returns a set of windows called glimpses extracted at location
`offsets` from the input tensor. If the windows only partially
overlap the inputs, the non-overlapping areas will be filled with
random noise.
The result is a 4-D tensor of shape `[batch_size, glimpse_height,
glimpse_width, channels]`. The channels and batch dimensions are the
same as that of the input tensor. The height and width of the output
windows are specified in the `size` parameter.
The arguments `normalized` and `centered` control how the windows are built:
* If the coordinates are normalized but not centered, 0.0 and 1.0
correspond to the minimum and maximum of each height and width
dimension.
* If the coordinates are both normalized and centered, they range from
-1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper
left corner, the lower right corner is located at (1.0, 1.0) and the
center is at (0, 0).
* If the coordinates are not normalized they are interpreted as
numbers of pixels.
END
}

View File

@ -32,6 +32,8 @@ namespace tensorflow {
class ExtractGlimpseOp : public OpKernel {
public:
explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) {
const string& op = context->def().op();
version_ = (op == "ExtractGlimpse") ? 1 : 2;
OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_));
OP_REQUIRES_OK(context, context->GetAttr("centered", &centered_));
bool uniform_noise = false;
@ -117,21 +119,23 @@ class ExtractGlimpseOp : public OpKernel {
// calling TensorFlow operates with (y,x) as indices.
offset_vec.push_back(Eigen::IndexPair<float>(offset_x, offset_y));
}
output->tensor<float, 4>().swap_layout().device(
context->eigen_cpu_device()) =
Eigen::ExtractGlimpses(input.tensor<float, 4>().swap_layout(),
output_width, output_height, offset_vec,
normalized_, centered_, noise_);
normalized_, centered_, noise_, version_);
}
private:
bool normalized_;
bool centered_;
Eigen::ExtractGlimpsesNoiseMode noise_;
int32 version_;
};
// Both op names are served by the same CPU kernel class. The kernel's
// constructor inspects the op name it was built for and selects version 1
// for "ExtractGlimpse" and version 2 for any other name, which is then
// forwarded to Eigen::ExtractGlimpses to choose the offset handling.
REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU),
ExtractGlimpseOp);
// V2 registration: same kernel class, resolved to version 2 at construction.
REGISTER_KERNEL_BUILDER(Name("ExtractGlimpseV2").Device(DEVICE_CPU),
ExtractGlimpseOp);
} // end namespace tensorflow

View File

@ -56,13 +56,14 @@ struct GlimpseExtractionOp {
GlimpseExtractionOp(const Index width, const Index height,
const std::vector<IndexPair<float> >& offsets,
const bool normalized, const bool centered,
const ExtractGlimpsesNoiseMode noise)
const ExtractGlimpsesNoiseMode noise, const int version)
: width_(width),
height_(height),
offsets_(offsets),
normalized_(normalized),
centered_(centered),
noise_(noise) {}
noise_(noise),
version_(version) {}
template <typename Input>
DSizes<Index, 4> dimensions(const Input& input) const {
@ -101,6 +102,7 @@ struct GlimpseExtractionOp {
for (Index i = 0; i < batch_size; ++i) {
float x = offsets_[i].first, y = offsets_[i].second;
if (version_ == 1) {
// Un-normalize coordinates back to pixel space if normalized.
if (normalized_) {
x *= input_width;
@ -116,6 +118,28 @@ struct GlimpseExtractionOp {
// Remove half of the glimpse window.
x -= width_ / 2.0f;
y -= height_ / 2.0f;
} else {
if (normalized_) {
// Un-normalize coordinates back to pixel space if normalized.
x *= input_width;
y *= input_height;
if (centered_) {
// Un-center if coordinates are centered on the image center.
x /= 2.0f;
y /= 2.0f;
x += input_width / 2.0f;
y += input_height / 2.0f;
// Remove half of the glimpse window.
x -= width_ / 2.0f;
y -= height_ / 2.0f;
}
} else {
if (centered_) {
x += input_width / 2.0f;
y += input_height / 2.0f;
}
}
}
const Index offset_x = (Index)x;
const Index offset_y = (Index)y;
@ -243,6 +267,7 @@ struct GlimpseExtractionOp {
const bool normalized_;
const bool centered_;
const ExtractGlimpsesNoiseMode noise_;
const int version_;
};
} // namespace
@ -255,7 +280,8 @@ ExtractGlimpses(
const typename internal::traits<Input>::Index height,
const std::vector<IndexPair<float> >& offsets, const bool normalized = true,
const bool centered = true,
const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM) {
const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM,
const int version = 2) {
EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor,
YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4,
@ -263,7 +289,7 @@ ExtractGlimpses(
typedef typename internal::traits<Input>::Index Index;
const GlimpseExtractionOp<Index> op(width, height, offsets, normalized,
centered, noise);
centered, noise, version);
return input.customOp(op);
}

View File

@ -756,6 +756,41 @@ REGISTER_OP("ExtractGlimpse")
c->Dim(input, 3));
});
// Op registration for "ExtractGlimpseV2".
//
// Inputs:  input (float), size (int32), offsets (float).
// Output:  glimpse (float).
// Attrs:   centered, normalized, uniform_noise (all bool, default true) and
//          noise (string, default 'uniform').
//
// The shape function requires a rank-4 `input` and a rank-2 `offsets` whose
// leading dimension merges with that of `input` and whose second dimension
// is 2. It rejects the attr combination where `uniform_noise` is true while
// `noise` names something other than 'uniform', and otherwise sizes the
// output from the `size` input.
REGISTER_OP("ExtractGlimpseV2")
    .Input("input: float")
    .Input("size: int32")
    .Input("offsets: float")
    .Output("glimpse: float")
    .Attr("centered: bool = true")
    .Attr("normalized: bool = true")
    .Attr("uniform_noise: bool = true")
    .Attr("noise: string = 'uniform'")
    .SetShapeFn([](InferenceContext* c) {
      // Rank checks on the image batch and on the per-image offsets.
      ShapeHandle image_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &image_shape));
      ShapeHandle offsets_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets_shape));

      // The leading dimensions must be compatible, and each offset row must
      // hold exactly two values.
      DimensionHandle batch;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(image_shape, 0),
                                  c->Dim(offsets_shape, 0), &batch));
      DimensionHandle ignored;
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets_shape, 1), 2, &ignored));

      // `uniform_noise` may only stay enabled when `noise` is left empty or
      // set to "uniform".
      bool use_uniform_noise = false;
      TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &use_uniform_noise));
      string noise_mode;
      TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise_mode));
      const bool noise_overridden =
          !noise_mode.empty() && noise_mode != "uniform";
      if (use_uniform_noise && noise_overridden) {
        return errors::InvalidArgument(
            "The uniform_noise and noise should not be specified at the same "
            "time");
      }

      // Output shape is derived from the batch dimension, the `size` input
      // (input index 1) and the last dimension of `input`.
      return SetOutputToSizedImage(c, batch, 1 /* size_input_idx */,
                                   c->Dim(image_shape, 3));
    });
// --------------------------------------------------------------------------
REGISTER_OP("CropAndResize")

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.platform import test
@ -196,6 +197,55 @@ class ExtractGlimpseTest(test.TestCase):
expected_rows=[None, None, None, 1, 2, 3, 4],
expected_cols=[56, 57, 58, 59, 60])
def testGlimpseNoiseZeroV1Compatible(self):
  # Compatibility test: older versions of extract_glimpse were implemented
  # incorrectly, and graphs saved with those versions must keep behaving the
  # same way. gen_image_ops.extract_glimpse() is called directly on purpose.
  #
  # Input image (5x5, one channel):
  #   [  0.   1.   2.   3.   4.]
  #   [  5.   6.   7.   8.   9.]
  #   [ 10.  11.  12.  13.  14.]
  #   [ 15.  16.  17.  18.  19.]
  #   [ 20.  21.  22.  23.  24.]
  image = constant_op.constant(
      np.arange(25).reshape((1, 5, 5, 1)), dtype=dtypes.float32)

  # Expected 3x3 glimpse for offset (-2, 2): all zero noise.
  expected_small = np.asarray([
      [0, 0, 0],
      [0, 0, 0],
      [0, 0, 0],
  ])
  # Expected 7x7 glimpse for offset (0, 0): the whole image surrounded by a
  # one-pixel border of zero noise.
  expected_large = np.asarray([
      [0, 0, 0, 0, 0, 0, 0],
      [0, 0, 1, 2, 3, 4, 0],
      [0, 5, 6, 7, 8, 9, 0],
      [0, 10, 11, 12, 13, 14, 0],
      [0, 15, 16, 17, 18, 19, 0],
      [0, 20, 21, 22, 23, 24, 0],
      [0, 0, 0, 0, 0, 0, 0],
  ])

  with self.test_session():
    small_glimpse = gen_image_ops.extract_glimpse(
        image, [3, 3], [[-2, 2]],
        centered=False,
        normalized=False,
        noise='zero',
        uniform_noise=False)
    small_values = self.evaluate(small_glimpse)[0, :, :, 0]
    self.assertAllEqual(expected_small, small_values)

    large_glimpse = gen_image_ops.extract_glimpse(
        image, [7, 7], [[0, 0]],
        normalized=False,
        noise='zero',
        uniform_noise=False)
    large_values = self.evaluate(large_glimpse)[0, :, :, 0]
    self.assertAllEqual(expected_large, large_values)
def testGlimpseNoiseZero(self):
# Image:
# [ 0. 1. 2. 3. 4.]
@ -211,7 +261,7 @@ class ExtractGlimpseTest(test.TestCase):
# [ 0. 0. 0.]
# [ 0. 0. 0.]
result1 = image_ops.extract_glimpse_v2(
img, [3, 3], [[-2, 2]],
img, [3, 3], [[-2, -2]],
centered=False,
normalized=False,
noise='zero')
@ -220,22 +270,37 @@ class ExtractGlimpseTest(test.TestCase):
self.evaluate(result1)[0, :, :, 0])
# Result 2:
# [ 12. 13. 14. 0. 0. 0. 0.]
# [ 17. 18. 19. 0. 0. 0. 0.]
# [ 22. 23. 24. 0. 0. 0. 0.]
# [ 0. 0. 0. 0. 0. 0. 0.]
# [ 0. 0. 0. 0. 0. 0. 0.]
# [ 0. 0. 0. 0. 0. 0. 0.]
# [ 0. 0. 1. 2. 3. 4. 0.]
# [ 0. 5. 6. 7. 8. 9. 0.]
# [ 0. 10. 11. 12. 13. 14. 0.]
# [ 0. 15. 16. 17. 18. 19. 0.]
# [ 0. 20. 21. 22. 23. 24. 0.]
# [ 0. 0. 0. 0. 0. 0. 0.]
result2 = image_ops.extract_glimpse_v2(
img, [7, 7], [[0, 0]], normalized=False, noise='zero')
self.assertAllEqual(
np.asarray([[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 4, 0],
[0, 5, 6, 7, 8, 9, 0], [0, 10, 11, 12, 13, 14, 0],
[0, 15, 16, 17, 18, 19, 0], [0, 20, 21, 22, 23, 24, 0],
np.asarray([[12, 13, 14, 0, 0, 0, 0], [17, 18, 19, 0, 0, 0, 0],
[22, 23, 24, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0]]),
self.evaluate(result2)[0, :, :, 0])
def testGlimpseNonNormalizedNonCentered(self):
  # With centered=False and normalized=False the expectations below treat the
  # offset as the (y, x) pixel position of the window's upper-left corner in a
  # 5x5 image holding the values 0..24 in row-major order.
  image = constant_op.constant(
      np.arange(25).reshape((1, 5, 5, 1)), dtype=dtypes.float32)

  expected_at_origin = np.asarray([
      [0, 1, 2],
      [5, 6, 7],
      [10, 11, 12],
  ])
  expected_one_row_down = np.asarray([
      [5, 6, 7],
      [10, 11, 12],
      [15, 16, 17],
  ])

  with self.test_session():
    # Build both glimpse ops first, then evaluate and compare each in turn.
    glimpse_at_origin = image_ops.extract_glimpse_v2(
        image, [3, 3], [[0, 0]], centered=False, normalized=False)
    glimpse_one_row_down = image_ops.extract_glimpse_v2(
        image, [3, 3], [[1, 0]], centered=False, normalized=False)

    checks = (
        (expected_at_origin, glimpse_at_origin),
        (expected_one_row_down, glimpse_one_row_down),
    )
    for expected, glimpse in checks:
      self.assertAllEqual(expected, self.evaluate(glimpse)[0, :, :, 0])
if __name__ == '__main__':
test.main()

View File

@ -4114,7 +4114,7 @@ def extract_glimpse(
... [[6.0],
... [7.0],
... [8.0]]]]
>>> tf.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]],
>>> tf.compat.v1.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]],
... centered=False, normalized=False)
<tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy=
array([[[[0.],
@ -4203,10 +4203,10 @@ def extract_glimpse_v2(
>>> tf.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]],
... centered=False, normalized=False)
<tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy=
array([[[[0.],
[1.]],
[[3.],
[4.]]]], dtype=float32)>
array([[[[4.],
[5.]],
[[7.],
[8.]]]], dtype=float32)>
Args:
input: A `Tensor` of type `float32`. A 4-D float tensor of shape
@ -4231,7 +4231,7 @@ def extract_glimpse_v2(
Returns:
A `Tensor` of type `float32`.
"""
return gen_image_ops.extract_glimpse(
return gen_image_ops.extract_glimpse_v2(
input=input,
size=size,
offsets=offsets,

View File

@ -1476,6 +1476,10 @@ tf_module {
name: "ExtractGlimpse"
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], "
}
member_method {
name: "ExtractGlimpseV2"
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], "
}
member_method {
name: "ExtractImagePatches"
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -1476,6 +1476,10 @@ tf_module {
name: "ExtractGlimpse"
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], "
}
member_method {
name: "ExtractGlimpseV2"
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], "
}
member_method {
name: "ExtractImagePatches"
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "