Use ExtractGlimpseV2 and ExtractGlimpse to make sure C++ kernel is backward compatible

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-04-27 17:09:52 +00:00
parent 353d22eb43
commit 756b7ed2d6
4 changed files with 82 additions and 21 deletions

View File

@ -32,6 +32,8 @@ namespace tensorflow {
class ExtractGlimpseOp : public OpKernel {
public:
explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) {
const string& op = context->def().op();
version_ = (op == "ExtractGlimpse") ? 1 : 2;
OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_));
OP_REQUIRES_OK(context, context->GetAttr("centered", &centered_));
bool uniform_noise = false;
@ -117,21 +119,23 @@ class ExtractGlimpseOp : public OpKernel {
// calling TensorFlow operates with (y,x) as indices.
offset_vec.push_back(Eigen::IndexPair<float>(offset_x, offset_y));
}
output->tensor<float, 4>().swap_layout().device(
context->eigen_cpu_device()) =
Eigen::ExtractGlimpses(input.tensor<float, 4>().swap_layout(),
output_width, output_height, offset_vec,
normalized_, centered_, noise_);
normalized_, centered_, noise_, version_);
}
private:
bool normalized_;
bool centered_;
Eigen::ExtractGlimpsesNoiseMode noise_;
int32 version_;
};
REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU),
ExtractGlimpseOp);
REGISTER_KERNEL_BUILDER(Name("ExtractGlimpseV2").Device(DEVICE_CPU),
ExtractGlimpseOp);
} // end namespace tensorflow

View File

@ -56,13 +56,15 @@ struct GlimpseExtractionOp {
GlimpseExtractionOp(const Index width, const Index height,
const std::vector<IndexPair<float> >& offsets,
const bool normalized, const bool centered,
const ExtractGlimpsesNoiseMode noise)
const ExtractGlimpsesNoiseMode noise,
const int version)
: width_(width),
height_(height),
offsets_(offsets),
normalized_(normalized),
centered_(centered),
noise_(noise) {}
noise_(noise),
version_(version) {}
template <typename Input>
DSizes<Index, 4> dimensions(const Input& input) const {
@ -101,24 +103,42 @@ struct GlimpseExtractionOp {
for (Index i = 0; i < batch_size; ++i) {
float x = offsets_[i].first, y = offsets_[i].second;
if (normalized_) {
if (version_ == 1) {
// Un-normalize coordinates back to pixel space if normalized.
x *= input_width;
y *= input_height;
if (normalized_) {
x *= input_width;
y *= input_height;
}
// Un-center if coordinates are centered on the image center.
if (centered_) {
// Un-center if coordinates are centered on the image center.
x /= 2.0f;
y /= 2.0f;
x += input_width / 2.0f;
y += input_height / 2.0f;
// Remove half of the glimpse window.
x -= width_ / 2.0f;
y -= height_ / 2.0f;
}
// Remove half of the glimpse window.
x -= width_ / 2.0f;
y -= height_ / 2.0f;
} else {
if (centered_) {
x += input_width / 2.0f;
y += input_height / 2.0f;
if (normalized_) {
// Un-normalize coordinates back to pixel space if normalized.
x *= input_width;
y *= input_height;
if (centered_) {
// Un-center if coordinates are centered on the image center.
x /= 2.0f;
y /= 2.0f;
x += input_width / 2.0f;
y += input_height / 2.0f;
// Remove half of the glimpse window.
x -= width_ / 2.0f;
y -= height_ / 2.0f;
}
} else {
if (centered_) {
x += input_width / 2.0f;
y += input_height / 2.0f;
}
}
}
@ -248,6 +268,7 @@ struct GlimpseExtractionOp {
const bool normalized_;
const bool centered_;
const ExtractGlimpsesNoiseMode noise_;
const int version_;
};
} // namespace
@ -260,7 +281,8 @@ ExtractGlimpses(
const typename internal::traits<Input>::Index height,
const std::vector<IndexPair<float> >& offsets, const bool normalized = true,
const bool centered = true,
const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM) {
const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM,
const int version = 2) {
EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor,
YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4,
@ -268,7 +290,7 @@ ExtractGlimpses(
typedef typename internal::traits<Input>::Index Index;
const GlimpseExtractionOp<Index> op(width, height, offsets, normalized,
centered, noise);
centered, noise, version);
return input.customOp(op);
}

View File

@ -756,6 +756,41 @@ REGISTER_OP("ExtractGlimpse")
c->Dim(input, 3));
});
REGISTER_OP("ExtractGlimpseV2")
.Input("input: float")
.Input("size: int32")
.Input("offsets: float")
.Output("glimpse: float")
.Attr("centered: bool = true")
.Attr("normalized: bool = true")
.Attr("uniform_noise: bool = true")
.Attr("noise: string = 'uniform'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
ShapeHandle offsets;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
DimensionHandle batch_dim;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
bool uniform_noise = false;
TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
string noise;
TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
if (uniform_noise && (!noise.empty() && noise != "uniform")) {
return errors::InvalidArgument(
"The uniform_noise and noise should not be specified at the same "
"time");
}
return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
c->Dim(input, 3));
});
// --------------------------------------------------------------------------
REGISTER_OP("CropAndResize")

View File

@ -4063,10 +4063,10 @@ def extract_glimpse(
>>> tf.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]],
... centered=False, normalized=False)
<tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy=
array([[[[4.],
[5.]],
[[7.],
[8.]]]], dtype=float32)>
array([[[[0.],
[1.]],
[[3.],
[4.]]]], dtype=float32)>
Args:
input: A `Tensor` of type `float32`. A 4-D float tensor of shape
@ -4176,7 +4176,7 @@ def extract_glimpse_v2(
Returns:
A `Tensor` of type `float32`.
"""
return gen_image_ops.extract_glimpse(
return gen_image_ops.extract_glimpse_v2(
input=input,
size=size,
offsets=offsets,