Allow the output to have a different shape from the input in image.transform (#17011).
PiperOrigin-RevId: 193564222
This commit is contained in:
parent
55706e693a
commit
7f1e64eb94
@ -70,6 +70,7 @@ class ImageProjectiveTransform : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& images_t = ctx->input(0);
|
||||
const Tensor& transform_t = ctx->input(1);
|
||||
const Tensor& output_dim = ctx->input(2);
|
||||
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
|
||||
errors::InvalidArgument("Input images must have rank 4"));
|
||||
OP_REQUIRES(ctx,
|
||||
@ -83,7 +84,11 @@ class ImageProjectiveTransform : public OpKernel {
|
||||
auto images = images_t.tensor<T, 4>();
|
||||
auto transform = transform_t.matrix<float>();
|
||||
Tensor* output_t;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
|
||||
// Image is NHWC format.
|
||||
auto output_shape = images_t.shape();
|
||||
output_shape.set_dim(1, output_dim.vec<int>()(0));
|
||||
output_shape.set_dim(2, output_dim.vec<int>()(1));
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t));
|
||||
auto output = output_t->tensor<T, 4>();
|
||||
(FillProjectiveTransform<Device, T>(interpolation_))(
|
||||
ctx->eigen_device<Device>(), &output, images, transform);
|
||||
|
||||
@ -161,7 +161,7 @@ struct FillProjectiveTransform {
|
||||
void operator()(const Device& device, OutputType* output,
|
||||
const InputType& images,
|
||||
const TransformsType& transform) const {
|
||||
output->device(device) = images.generate(
|
||||
output->device(device) = output->generate(
|
||||
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
|
||||
}
|
||||
};
|
||||
|
||||
@ -19,9 +19,55 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::DimensionHandle;
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
namespace {
|
||||
|
||||
// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
|
||||
// height and width come from the size_tensor.
|
||||
Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
|
||||
int size_input_idx, DimensionHandle channel_dim) {
|
||||
// Verify shape of size input.
|
||||
ShapeHandle size;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
|
||||
DimensionHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
|
||||
|
||||
// Get size values from the size tensor.
|
||||
const Tensor* size_tensor = c->input_tensor(size_input_idx);
|
||||
DimensionHandle width;
|
||||
DimensionHandle height;
|
||||
if (size_tensor == nullptr) {
|
||||
width = c->UnknownDim();
|
||||
height = c->UnknownDim();
|
||||
} else {
|
||||
// TODO(petewarden) - Remove once we have constant evaluation in C++ only.
|
||||
if (size_tensor->dtype() != DT_INT32) {
|
||||
return errors::InvalidArgument(
|
||||
"Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
|
||||
"but got ",
|
||||
DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
|
||||
" in ", c->DebugString());
|
||||
}
|
||||
auto vec = size_tensor->vec<int32>();
|
||||
height = c->MakeDim(vec(0));
|
||||
width = c->MakeDim(vec(1));
|
||||
}
|
||||
c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Shape function for ops whose input 0 is a rank-4 tensor and whose input 2
// supplies the output height and width.
//
// Dimensions 0 and 3 of input 0 are carried over unchanged as the first and
// last output dimensions; the two middle dimensions are resolved by
// SetOutputToSizedImage from input 2.
Status ResizeShapeFn(InferenceContext* c) {
  ShapeHandle images_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images_shape));

  const DimensionHandle leading_dim = c->Dim(images_shape, 0);
  const DimensionHandle trailing_dim = c->Dim(images_shape, 3);
  constexpr int kSizeInputIdx = 2;
  return SetOutputToSizedImage(c, leading_dim, kSizeInputIdx, trailing_dim);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
|
||||
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
|
||||
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
|
||||
@ -29,13 +75,11 @@ using shape_inference::ShapeHandle;
|
||||
REGISTER_OP("ImageProjectiveTransform")
|
||||
.Input("images: dtype")
|
||||
.Input("transforms: float32")
|
||||
.Input("output_shape: int32")
|
||||
.Attr("dtype: {uint8, int32, int64, float32, float64}")
|
||||
.Attr("interpolation: string")
|
||||
.Output("transformed_images: dtype")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
return Status::OK();
|
||||
})
|
||||
.SetShapeFn(ResizeShapeFn)
|
||||
.Doc(R"doc(
|
||||
Applies the given transform to each of the images.
|
||||
|
||||
|
||||
@ -195,10 +195,40 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
|
||||
x_init_value=test_image)
|
||||
self.assertLess(left_err, 1e-10)
|
||||
|
||||
def _test_grad_different_shape(self, input_shape, output_shape):
  """Checks the gradient of `image_ops.transform` when the output is resized.

  A random image of shape `input_shape` is transformed with a fixed
  projective transform into an output of shape `output_shape`, and the
  gradient error reported by the gradient checker must be below 1e-10.

  Args:
    input_shape: Shape of the randomly generated test image.
    output_shape: Full shape of the transformed output, with 2, 3 or 4
      entries. The [height, width] pair handed to `transform` is
      `output_shape` itself for 2 entries, its first two entries for 3
      entries, and entries 1 and 2 for 4 entries.

  Raises:
    ValueError: If `output_shape` does not have 2, 3 or 4 entries.
  """
  # Pick the [height, width] slice up front so that an unsupported rank fails
  # with a clear message instead of an unbound local further down.
  num_dims = len(output_shape)
  if num_dims == 2:
    resize_shape = output_shape
  elif num_dims == 3:
    resize_shape = output_shape[0:2]
  elif num_dims == 4:
    resize_shape = output_shape[1:3]
  else:
    raise ValueError(
        "output_shape must have 2, 3 or 4 entries, got %d." % num_dims)

  with self.test_session():
    test_image_shape = input_shape
    test_image = np.random.randn(*test_image_shape)
    test_image_tensor = constant_op.constant(
        test_image, shape=test_image_shape)
    test_transform = image_ops.angles_to_projective_transforms(
        np.pi / 2, 4, 4)

    output = image_ops.transform(
        images=test_image_tensor,
        transforms=test_transform,
        output_shape=resize_shape)
    left_err = gradient_checker.compute_gradient_error(
        test_image_tensor,
        test_image_shape,
        output,
        output_shape,
        x_init_value=test_image)
    self.assertLess(left_err, 1e-10)
|
||||
|
||||
def test_grad(self):
|
||||
self._test_grad([16, 16])
|
||||
self._test_grad([4, 12, 12])
|
||||
self._test_grad([3, 4, 12, 12])
|
||||
self._test_grad_different_shape([16, 16], [8, 8])
|
||||
self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
|
||||
self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
|
||||
|
||||
|
||||
class BipartiteMatchTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@ -212,7 +212,11 @@ def translations_to_projective_transforms(translations, name=None):
|
||||
axis=1)
|
||||
|
||||
|
||||
def transform(images, transforms, interpolation="NEAREST", name=None):
|
||||
def transform(images,
|
||||
transforms,
|
||||
output_shape=None,
|
||||
interpolation="NEAREST",
|
||||
name=None):
|
||||
"""Applies the given transform(s) to the image(s).
|
||||
|
||||
Args:
|
||||
@ -228,7 +232,10 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
|
||||
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
|
||||
the transform mapping input points to output points. Note that gradients
|
||||
are not backpropagated into transformation parameters.
|
||||
output_shape: Output dimension after the transform, [height, width].
|
||||
If None, output is the same size as input image.
|
||||
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||
name: The name of the op.
|
||||
|
||||
Returns:
|
||||
Image(s) with the same type as `images` (and the same shape unless `output_shape` is given), with the given
|
||||
@ -255,6 +262,14 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
|
||||
else:
|
||||
raise TypeError("Images should have rank between 2 and 4.")
|
||||
|
||||
if output_shape is None:
|
||||
output_shape = images.get_shape()[1:3]
|
||||
elif len(output_shape) != 2:
|
||||
raise TypeError(
|
||||
"output_shape must either be None or a vector of 2 elements.")
|
||||
output_shape = ops.convert_to_tensor(
|
||||
output_shape, name="output_shape", dtype=dtypes.int32)
|
||||
|
||||
if len(transform_or_transforms.get_shape()) == 1:
|
||||
transforms = transform_or_transforms[None]
|
||||
elif transform_or_transforms.get_shape().ndims is None:
|
||||
@ -265,7 +280,7 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
|
||||
else:
|
||||
raise TypeError("Transforms should have rank 1 or 2.")
|
||||
output = gen_image_ops.image_projective_transform(
|
||||
images, transforms, interpolation=interpolation.upper())
|
||||
images, transforms, output_shape, interpolation=interpolation.upper())
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
return output[0, :, :, 0]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
@ -375,14 +390,6 @@ def _image_projective_transform_grad(op, grad):
|
||||
|
||||
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
|
||||
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
images = image_or_images[None, :, :, None]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
images = image_or_images[None, :, :, :]
|
||||
elif len(image_or_images.get_shape()) == 4:
|
||||
images = image_or_images
|
||||
else:
|
||||
raise TypeError("Images should have rank between 2 and 4")
|
||||
if len(transform_or_transforms.get_shape()) == 1:
|
||||
transforms = transform_or_transforms[None]
|
||||
elif len(transform_or_transforms.get_shape()) == 2:
|
||||
@ -395,13 +402,11 @@ def _image_projective_transform_grad(op, grad):
|
||||
inverse = linalg_ops.matrix_inverse(transforms)
|
||||
transforms = matrices_to_flat_transforms(inverse)
|
||||
output = gen_image_ops.image_projective_transform(
|
||||
grad, transforms, interpolation=interpolation)
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
return [output[0, :, :, 0], None]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
return [output[0, :, :, :], None]
|
||||
else:
|
||||
return [output, None]
|
||||
images=grad,
|
||||
transforms=transforms,
|
||||
output_shape=image_or_images.get_shape()[1:3],
|
||||
interpolation=interpolation)
|
||||
return [output, None, None]
|
||||
|
||||
|
||||
def bipartite_match(distance_mat,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user