Merge pull request #38549 from yongtang:12829-extract_glimpse
PiperOrigin-RevId: 312715239 Change-Id: I4d40eaeed615fceed8f29d380709dad5b9fed216
This commit is contained in:
commit
eaa710abca
13
RELEASE.md
13
RELEASE.md
@ -1,3 +1,16 @@
|
||||
# Release 2.3.0
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||
the output is different from (incorrect) previous versions. Note this
|
||||
breaking change only impacts `tf.image.extract_glimpse` and
|
||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||
existing C++ kernel `ExtractGlimpse` does not change either, so saved
|
||||
models will not be impacted.
|
||||
|
||||
# Release 2.1.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
@ -0,0 +1,86 @@
|
||||
op {
|
||||
graph_op_name: "ExtractGlimpseV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
A 4-D float tensor of shape `[batch_size, height, width, channels]`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "size"
|
||||
description: <<END
|
||||
A 1-D tensor of 2 elements containing the size of the glimpses
|
||||
to extract. The glimpse height must be specified first, followed
|
||||
by the glimpse width.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "offsets"
|
||||
description: <<END
|
||||
A 2-D float tensor of shape `[batch_size, 2]` containing
|
||||
the y, x locations of the center of each window.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "glimpse"
|
||||
description: <<END
|
||||
A tensor representing the glimpses `[batch_size,
|
||||
glimpse_height, glimpse_width, channels]`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "centered"
|
||||
description: <<END
|
||||
indicates if the offset coordinates are centered relative to
|
||||
the image, in which case the (0, 0) offset is relative to the center
|
||||
of the input images. If false, the (0,0) offset corresponds to the
|
||||
upper left corner of the input images.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "normalized"
|
||||
description: <<END
|
||||
indicates if the offset coordinates are normalized.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "uniform_noise"
|
||||
description: <<END
|
||||
indicates if the noise should be generated using a
|
||||
uniform distribution or a Gaussian distribution.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "noise"
|
||||
description: <<END
|
||||
indicates if the noise should be `uniform`, `gaussian`, or
|
||||
`zero`. The default is `uniform` which means the noise type
|
||||
will be decided by `uniform_noise`.
|
||||
END
|
||||
}
|
||||
summary: "Extracts a glimpse from the input tensor."
|
||||
description: <<END
|
||||
Returns a set of windows called glimpses extracted at location
|
||||
`offsets` from the input tensor. If the windows only partially
|
||||
overlap the inputs, the non-overlapping areas will be filled with
|
||||
random noise.
|
||||
|
||||
The result is a 4-D tensor of shape `[batch_size, glimpse_height,
|
||||
glimpse_width, channels]`. The channels and batch dimensions are the
|
||||
same as that of the input tensor. The height and width of the output
|
||||
windows are specified in the `size` parameter.
|
||||
|
||||
The arguments `normalized` and `centered` control how the windows are built:
|
||||
|
||||
* If the coordinates are normalized but not centered, 0.0 and 1.0
|
||||
correspond to the minimum and maximum of each height and width
|
||||
dimension.
|
||||
* If the coordinates are both normalized and centered, they range from
|
||||
-1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper
|
||||
left corner, the lower right corner is located at (1.0, 1.0) and the
|
||||
center is at (0, 0).
|
||||
* If the coordinates are not normalized they are interpreted as
|
||||
numbers of pixels.
|
||||
END
|
||||
}
|
@ -32,6 +32,8 @@ namespace tensorflow {
|
||||
class ExtractGlimpseOp : public OpKernel {
|
||||
public:
|
||||
explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
const string& op = context->def().op();
|
||||
version_ = (op == "ExtractGlimpse") ? 1 : 2;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("centered", &centered_));
|
||||
bool uniform_noise = false;
|
||||
@ -117,21 +119,23 @@ class ExtractGlimpseOp : public OpKernel {
|
||||
// calling TensorFlow operates with (y,x) as indices.
|
||||
offset_vec.push_back(Eigen::IndexPair<float>(offset_x, offset_y));
|
||||
}
|
||||
|
||||
output->tensor<float, 4>().swap_layout().device(
|
||||
context->eigen_cpu_device()) =
|
||||
Eigen::ExtractGlimpses(input.tensor<float, 4>().swap_layout(),
|
||||
output_width, output_height, offset_vec,
|
||||
normalized_, centered_, noise_);
|
||||
normalized_, centered_, noise_, version_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool normalized_;
|
||||
bool centered_;
|
||||
Eigen::ExtractGlimpsesNoiseMode noise_;
|
||||
int32 version_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU),
|
||||
ExtractGlimpseOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("ExtractGlimpseV2").Device(DEVICE_CPU),
|
||||
ExtractGlimpseOp);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -56,13 +56,14 @@ struct GlimpseExtractionOp {
|
||||
GlimpseExtractionOp(const Index width, const Index height,
|
||||
const std::vector<IndexPair<float> >& offsets,
|
||||
const bool normalized, const bool centered,
|
||||
const ExtractGlimpsesNoiseMode noise)
|
||||
const ExtractGlimpsesNoiseMode noise, const int version)
|
||||
: width_(width),
|
||||
height_(height),
|
||||
offsets_(offsets),
|
||||
normalized_(normalized),
|
||||
centered_(centered),
|
||||
noise_(noise) {}
|
||||
noise_(noise),
|
||||
version_(version) {}
|
||||
|
||||
template <typename Input>
|
||||
DSizes<Index, 4> dimensions(const Input& input) const {
|
||||
@ -101,6 +102,7 @@ struct GlimpseExtractionOp {
|
||||
for (Index i = 0; i < batch_size; ++i) {
|
||||
float x = offsets_[i].first, y = offsets_[i].second;
|
||||
|
||||
if (version_ == 1) {
|
||||
// Un-normalize coordinates back to pixel space if normalized.
|
||||
if (normalized_) {
|
||||
x *= input_width;
|
||||
@ -116,6 +118,28 @@ struct GlimpseExtractionOp {
|
||||
// Remove half of the glimpse window.
|
||||
x -= width_ / 2.0f;
|
||||
y -= height_ / 2.0f;
|
||||
} else {
|
||||
if (normalized_) {
|
||||
// Un-normalize coordinates back to pixel space if normalized.
|
||||
x *= input_width;
|
||||
y *= input_height;
|
||||
if (centered_) {
|
||||
// Un-center if coordinates are centered on the image center.
|
||||
x /= 2.0f;
|
||||
y /= 2.0f;
|
||||
x += input_width / 2.0f;
|
||||
y += input_height / 2.0f;
|
||||
// Remove half of the glimpse window.
|
||||
x -= width_ / 2.0f;
|
||||
y -= height_ / 2.0f;
|
||||
}
|
||||
} else {
|
||||
if (centered_) {
|
||||
x += input_width / 2.0f;
|
||||
y += input_height / 2.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Index offset_x = (Index)x;
|
||||
const Index offset_y = (Index)y;
|
||||
@ -243,6 +267,7 @@ struct GlimpseExtractionOp {
|
||||
const bool normalized_;
|
||||
const bool centered_;
|
||||
const ExtractGlimpsesNoiseMode noise_;
|
||||
const int version_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -255,7 +280,8 @@ ExtractGlimpses(
|
||||
const typename internal::traits<Input>::Index height,
|
||||
const std::vector<IndexPair<float> >& offsets, const bool normalized = true,
|
||||
const bool centered = true,
|
||||
const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM) {
|
||||
const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM,
|
||||
const int version = 2) {
|
||||
EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor,
|
||||
YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4,
|
||||
@ -263,7 +289,7 @@ ExtractGlimpses(
|
||||
|
||||
typedef typename internal::traits<Input>::Index Index;
|
||||
const GlimpseExtractionOp<Index> op(width, height, offsets, normalized,
|
||||
centered, noise);
|
||||
centered, noise, version);
|
||||
return input.customOp(op);
|
||||
}
|
||||
|
||||
|
@ -756,6 +756,41 @@ REGISTER_OP("ExtractGlimpse")
|
||||
c->Dim(input, 3));
|
||||
});
|
||||
|
||||
REGISTER_OP("ExtractGlimpseV2")
|
||||
.Input("input: float")
|
||||
.Input("size: int32")
|
||||
.Input("offsets: float")
|
||||
.Output("glimpse: float")
|
||||
.Attr("centered: bool = true")
|
||||
.Attr("normalized: bool = true")
|
||||
.Attr("uniform_noise: bool = true")
|
||||
.Attr("noise: string = 'uniform'")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
|
||||
ShapeHandle offsets;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
|
||||
|
||||
DimensionHandle batch_dim;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
|
||||
DimensionHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
|
||||
|
||||
bool uniform_noise = false;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
|
||||
string noise;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
|
||||
if (uniform_noise && (!noise.empty() && noise != "uniform")) {
|
||||
return errors::InvalidArgument(
|
||||
"The uniform_noise and noise should not be specified at the same "
|
||||
"time");
|
||||
}
|
||||
|
||||
return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
|
||||
c->Dim(input, 3));
|
||||
});
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("CropAndResize")
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_image_ops
|
||||
from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -196,6 +197,55 @@ class ExtractGlimpseTest(test.TestCase):
|
||||
expected_rows=[None, None, None, 1, 2, 3, 4],
|
||||
expected_cols=[56, 57, 58, 59, 60])
|
||||
|
||||
def testGlimpseNoiseZeroV1Compatible(self):
|
||||
# Note: The old versions of extract_glimpse were incorrect in implementation.
|
||||
# This test is for compatibility so that graphs saved in old versions behave
|
||||
# the same. Notice the API uses gen_image_ops.extract_glimpse() on purpose.
|
||||
#
|
||||
# Image:
|
||||
# [ 0. 1. 2. 3. 4.]
|
||||
# [ 5. 6. 7. 8. 9.]
|
||||
# [ 10. 11. 12. 13. 14.]
|
||||
# [ 15. 16. 17. 18. 19.]
|
||||
# [ 20. 21. 22. 23. 24.]
|
||||
img = constant_op.constant(
|
||||
np.arange(25).reshape((1, 5, 5, 1)), dtype=dtypes.float32)
|
||||
with self.test_session():
|
||||
# Result 1:
|
||||
# [ 0. 0. 0.]
|
||||
# [ 0. 0. 0.]
|
||||
# [ 0. 0. 0.]
|
||||
result1 = gen_image_ops.extract_glimpse(
|
||||
img, [3, 3], [[-2, 2]],
|
||||
centered=False,
|
||||
normalized=False,
|
||||
noise='zero',
|
||||
uniform_noise=False)
|
||||
self.assertAllEqual(
|
||||
np.asarray([[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
|
||||
self.evaluate(result1)[0, :, :, 0])
|
||||
|
||||
# Result 2:
|
||||
# [ 0. 0. 0. 0. 0. 0. 0.]
|
||||
# [ 0. 0. 1. 2. 3. 4. 0.]
|
||||
# [ 0. 5. 6. 7. 8. 9. 0.]
|
||||
# [ 0. 10. 11. 12. 13. 14. 0.]
|
||||
# [ 0. 15. 16. 17. 18. 19. 0.]
|
||||
# [ 0. 20. 21. 22. 23. 24. 0.]
|
||||
# [ 0. 0. 0. 0. 0. 0. 0.]
|
||||
result2 = gen_image_ops.extract_glimpse(
|
||||
img, [7, 7], [[0, 0]],
|
||||
normalized=False,
|
||||
noise='zero',
|
||||
uniform_noise=False)
|
||||
self.assertAllEqual(
|
||||
np.asarray([[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 4, 0],
|
||||
[0, 5, 6, 7, 8, 9, 0], [0, 10, 11, 12, 13, 14, 0],
|
||||
[0, 15, 16, 17, 18, 19, 0], [0, 20, 21, 22, 23, 24, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0]]),
|
||||
self.evaluate(result2)[0, :, :, 0])
|
||||
|
||||
|
||||
def testGlimpseNoiseZero(self):
|
||||
# Image:
|
||||
# [ 0. 1. 2. 3. 4.]
|
||||
@ -211,7 +261,7 @@ class ExtractGlimpseTest(test.TestCase):
|
||||
# [ 0. 0. 0.]
|
||||
# [ 0. 0. 0.]
|
||||
result1 = image_ops.extract_glimpse_v2(
|
||||
img, [3, 3], [[-2, 2]],
|
||||
img, [3, 3], [[-2, -2]],
|
||||
centered=False,
|
||||
normalized=False,
|
||||
noise='zero')
|
||||
@ -220,22 +270,37 @@ class ExtractGlimpseTest(test.TestCase):
|
||||
self.evaluate(result1)[0, :, :, 0])
|
||||
|
||||
# Result 2:
|
||||
# [ 12. 13. 14. 0. 0. 0. 0.]
|
||||
# [ 17. 18. 19. 0. 0. 0. 0.]
|
||||
# [ 22. 23. 24. 0. 0. 0. 0.]
|
||||
# [ 0. 0. 0. 0. 0. 0. 0.]
|
||||
# [ 0. 0. 0. 0. 0. 0. 0.]
|
||||
# [ 0. 0. 0. 0. 0. 0. 0.]
|
||||
# [ 0. 0. 1. 2. 3. 4. 0.]
|
||||
# [ 0. 5. 6. 7. 8. 9. 0.]
|
||||
# [ 0. 10. 11. 12. 13. 14. 0.]
|
||||
# [ 0. 15. 16. 17. 18. 19. 0.]
|
||||
# [ 0. 20. 21. 22. 23. 24. 0.]
|
||||
# [ 0. 0. 0. 0. 0. 0. 0.]
|
||||
result2 = image_ops.extract_glimpse_v2(
|
||||
img, [7, 7], [[0, 0]], normalized=False, noise='zero')
|
||||
self.assertAllEqual(
|
||||
np.asarray([[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 4, 0],
|
||||
[0, 5, 6, 7, 8, 9, 0], [0, 10, 11, 12, 13, 14, 0],
|
||||
[0, 15, 16, 17, 18, 19, 0], [0, 20, 21, 22, 23, 24, 0],
|
||||
np.asarray([[12, 13, 14, 0, 0, 0, 0], [17, 18, 19, 0, 0, 0, 0],
|
||||
[22, 23, 24, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0]]),
|
||||
self.evaluate(result2)[0, :, :, 0])
|
||||
|
||||
def testGlimpseNonNormalizedNonCentered(self):
|
||||
img = constant_op.constant(
|
||||
np.arange(25).reshape((1, 5, 5, 1)), dtype=dtypes.float32)
|
||||
with self.test_session():
|
||||
result1 = image_ops.extract_glimpse_v2(
|
||||
img, [3, 3], [[0, 0]], centered=False, normalized=False)
|
||||
result2 = image_ops.extract_glimpse_v2(
|
||||
img, [3, 3], [[1, 0]], centered=False, normalized=False)
|
||||
self.assertAllEqual(
|
||||
np.asarray([[0, 1, 2], [5, 6, 7], [10, 11, 12]]),
|
||||
self.evaluate(result1)[0, :, :, 0])
|
||||
self.assertAllEqual(
|
||||
np.asarray([[5, 6, 7], [10, 11, 12], [15, 16, 17]]),
|
||||
self.evaluate(result2)[0, :, :, 0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -4114,7 +4114,7 @@ def extract_glimpse(
|
||||
... [[6.0],
|
||||
... [7.0],
|
||||
... [8.0]]]]
|
||||
>>> tf.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]],
|
||||
>>> tf.compat.v1.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]],
|
||||
... centered=False, normalized=False)
|
||||
<tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy=
|
||||
array([[[[0.],
|
||||
@ -4203,10 +4203,10 @@ def extract_glimpse_v2(
|
||||
>>> tf.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]],
|
||||
... centered=False, normalized=False)
|
||||
<tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy=
|
||||
array([[[[0.],
|
||||
[1.]],
|
||||
[[3.],
|
||||
[4.]]]], dtype=float32)>
|
||||
array([[[[4.],
|
||||
[5.]],
|
||||
[[7.],
|
||||
[8.]]]], dtype=float32)>
|
||||
|
||||
Args:
|
||||
input: A `Tensor` of type `float32`. A 4-D float tensor of shape
|
||||
@ -4231,7 +4231,7 @@ def extract_glimpse_v2(
|
||||
Returns:
|
||||
A `Tensor` of type `float32`.
|
||||
"""
|
||||
return gen_image_ops.extract_glimpse(
|
||||
return gen_image_ops.extract_glimpse_v2(
|
||||
input=input,
|
||||
size=size,
|
||||
offsets=offsets,
|
||||
|
@ -1476,6 +1476,10 @@ tf_module {
|
||||
name: "ExtractGlimpse"
|
||||
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ExtractGlimpseV2"
|
||||
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ExtractImagePatches"
|
||||
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -1476,6 +1476,10 @@ tf_module {
|
||||
name: "ExtractGlimpse"
|
||||
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ExtractGlimpseV2"
|
||||
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ExtractImagePatches"
|
||||
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user