Add crop_and_decode_jpeg_op that combines the crop and decode for better

performance.

PiperOrigin-RevId: 168493125
This commit is contained in:
Mingxing Tan 2017-09-12 21:48:48 -07:00 committed by TensorFlower Gardener
parent 48ddf64d0e
commit 9d56f419cf
8 changed files with 272 additions and 54 deletions

View File

@ -4675,6 +4675,7 @@ filegroup(
"encode_jpeg_op.*",
"extract_jpeg_shape_op.*",
"decode_jpeg_op.*",
"decode_and_crop_jpeg_op.*",
"decode_gif_op.*",
"identity_reader_op.*",
"remote_fused_graph_execute_op.*",

View File

@ -71,6 +71,9 @@ class DecodeImageOp : public OpKernel {
// Determine which op we are: jpeg, png, gif, or any
if (type_string() == "DecodeJpeg") {
format_ = kJpgFormat;
} else if (type_string() == "DecodeAndCropJpeg") {
format_ = kJpgFormat;
flags_.crop = true;
} else if (type_string() == "DecodePng") {
format_ = kPngFormat;
} else if (type_string() == "DecodeGif") {
@ -185,12 +188,31 @@ class DecodeImageOp : public OpKernel {
errors::InvalidArgument(
"channels must be 0, 1, or 3 for JPEG, got ", channels_));
// Decode jpeg, allocating tensor once the size is known
// Use local copy of flags to avoid race condition as the class member is
// shared among different invocations.
jpeg::UncompressFlags flags = flags_;
if (flags.crop) {
// Update flags to include crop window.
const Tensor& crop_window = context->input(1);
OP_REQUIRES(context, crop_window.dims() == 1,
errors::InvalidArgument("crop_window must be 1-D, got shape ",
crop_window.shape().DebugString()));
OP_REQUIRES(context, crop_window.dim_size(0) == 4,
errors::InvalidArgument("crop_size must have four elements ",
crop_window.shape().DebugString()));
auto crop_window_vec = crop_window.vec<int32>();
flags.crop_y = crop_window_vec(0);
flags.crop_x = crop_window_vec(1);
flags.crop_height = crop_window_vec(2);
flags.crop_width = crop_window_vec(3);
}
// Decode jpeg, allocating tensor once the size is known.
Tensor* output = nullptr;
OP_REQUIRES(
context,
jpeg::Uncompress(
input.data(), input.size(), flags_, nullptr /* nwarn */,
input.data(), input.size(), flags, nullptr /* nwarn */,
[=, &output](int width, int height, int channels) -> uint8* {
Status status(context->allocate_output(
0,
@ -205,7 +227,8 @@ class DecodeImageOp : public OpKernel {
}
return output->flat<uint8>().data();
}),
errors::InvalidArgument("Invalid JPEG data, size ", input.size()));
errors::InvalidArgument("Invalid JPEG data or crop window, data size ",
input.size()));
}
void DecodePng(OpKernelContext* context, StringPiece input) {
@ -311,6 +334,8 @@ class DecodeImageOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageOp);
REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageOp);
REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageOp);
REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU),
DecodeImageOp);
} // namespace
} // namespace tensorflow

View File

@ -25,6 +25,42 @@ using shape_inference::ShapeHandle;
namespace {
const char kDecodeJpegCommonDocStr[] = R"doc(
The attr `channels` indicates the desired number of color channels for the
decoded image.
Accepted values are:
* 0: Use the number of channels in the JPEG-encoded image.
* 1: output a grayscale image.
* 3: output an RGB image.
If needed, the JPEG-encoded image is transformed to match the requested number
of color channels.
The attr `ratio` allows downscaling the image by an integer factor during
decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
downscaling the image later.
)doc";
const char kDecodeJpegCommonParamsDocStr[] = R"doc(
channels: Number of color channels for the decoded image.
ratio: Downscaling ratio.
fancy_upscaling: If true use a slower but nicer upscaling of the
chroma planes (yuv420/422 only).
try_recover_truncated: If true try to recover an image from truncated input.
acceptable_fraction: The minimum required fraction of lines before a truncated
input is accepted.
dct_method: string specifying a hint about the algorithm used for
decompression. Defaults to "" which maps to a system-specific
default. Currently valid values are ["INTEGER_FAST",
"INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
jpeg library changes to a version that does not have that specific
option.)
image: 3-D with shape `[height, width, channels]`..
)doc";
// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
// height and width come from the size_tensor.
Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
@ -370,44 +406,40 @@ REGISTER_OP("DecodeJpeg")
.Attr("dct_method: string = ''")
.Output("image: uint8")
.SetShapeFn(DecodeImageShapeFn)
.Doc(R"doc(
.Doc(strings::StrCat(R"doc(
Decode a JPEG-encoded image to a uint8 tensor.
The attr `channels` indicates the desired number of color channels for the
decoded image.
Accepted values are:
* 0: Use the number of channels in the JPEG-encoded image.
* 1: output a grayscale image.
* 3: output an RGB image.
If needed, the JPEG-encoded image is transformed to match the requested number
of color channels.
The attr `ratio` allows downscaling the image by an integer factor during
decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
downscaling the image later.
)doc",
kDecodeJpegCommonDocStr, R"doc(
This op also supports decoding PNGs and non-animated GIFs since the interface is
the same, though it is cleaner to use `tf.image.decode_image`.
contents: 0-D. The JPEG-encoded image.
channels: Number of color channels for the decoded image.
ratio: Downscaling ratio.
fancy_upscaling: If true use a slower but nicer upscaling of the
chroma planes (yuv420/422 only).
try_recover_truncated: If true try to recover an image from truncated input.
acceptable_fraction: The minimum required fraction of lines before a truncated
input is accepted.
dct_method: string specifying a hint about the algorithm used for
decompression. Defaults to "" which maps to a system-specific
default. Currently valid values are ["INTEGER_FAST",
"INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
jpeg library changes to a version that does not have that specific
option.)
image: 3-D with shape `[height, width, channels]`..
)doc");
)doc",
kDecodeJpegCommonParamsDocStr));
// --------------------------------------------------------------------------
REGISTER_OP("DecodeAndCropJpeg")
.Input("contents: string")
.Input("crop_window: int32")
.Attr("channels: int = 0")
.Attr("ratio: int = 1")
.Attr("fancy_upscaling: bool = true")
.Attr("try_recover_truncated: bool = false")
.Attr("acceptable_fraction: float = 1.0")
.Attr("dct_method: string = ''")
.Output("image: uint8")
.SetShapeFn(DecodeImageShapeFn)
.Doc(strings::StrCat(R"doc(
Decode and Crop a JPEG-encoded image to a uint8 tensor.
)doc",
kDecodeJpegCommonDocStr, R"doc(
It is equivalent to a combination of decode and crop, but much faster by only
decoding partial jpeg image.
contents: 0-D. The JPEG-encoded image.
crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width].
)doc",
kDecodeJpegCommonParamsDocStr));
// --------------------------------------------------------------------------
REGISTER_OP("EncodeJpeg")

View File

@ -90,6 +90,58 @@ TEST(ImageOpsTest, DecodeImage_ShapeFn) {
}
}
TEST(ImageOpsTest, DecodeAndCropJpeg_ShapeFn) {
const char* op_name = "DecodeAndCropJpeg";
ShapeInferenceTestOp op(op_name);
// Check the number of inputs.
INFER_ERROR("Wrong number of inputs passed: 1 while 2 expected", op, "[1]");
// Rank check.
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?");
// Set the channel to zero - output is not known.
TF_ASSERT_OK(NodeDefBuilder("test", op_name)
.Input({"img", 0, DT_STRING})
.Input({"crop_window", 1, DT_INT32})
.Finalize(&op.node_def));
INFER_OK(op, "[];[]", "[?,?,?]");
// Set the channel, so that part of output shape is known.
TF_ASSERT_OK(NodeDefBuilder("test", op_name)
.Input({"img", 0, DT_STRING})
.Input({"crop_window", 1, DT_INT32})
.Attr("channels", 4)
.Finalize(&op.node_def));
INFER_OK(op, "[];[]", "[?,?,4]");
// Negative channel value is rejected.
TF_ASSERT_OK(NodeDefBuilder("test", op_name)
.Input({"img", 0, DT_STRING})
.Input({"crop_window", 1, DT_INT32})
.Attr("channels", -1)
.Finalize(&op.node_def));
INFER_ERROR("channels must be non-negative, got -1", op, "[];[]");
}
TEST(ImageOpsTest, DecodeAndCropJpeg_InvalidCropWindow) {
const char* op_name = "DecodeAndCropJpeg";
ShapeInferenceTestOp op(op_name);
// Check the number of inputs.
INFER_ERROR("Wrong number of inputs passed: 1 while 2 expected", op, "[1]");
// Rank check.
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?");
// Set the channel to zero - output is not known.
TF_ASSERT_OK(NodeDefBuilder("test", op_name)
.Input({"img", 0, DT_STRING})
.Input({"crop_window", 1, DT_INT32})
.Finalize(&op.node_def));
INFER_OK(op, "[];[]", "[?,?,?]");
}
TEST(ImageOpsTest, EncodeImage_ShapeFn) {
for (const char* op_name : {"EncodeJpeg", "EncodePng"}) {
ShapeInferenceTestOp op(op_name);

View File

@ -37,7 +37,13 @@ prefix_path = 'third_party/tensorflow/core/lib/jpeg/testdata'
class DecodeJpegBenchmark(test.Benchmark):
"""Evaluate tensorflow DecodeJpegOp performance."""
def _evalDecodeJpeg(self, image_name, parallelism, num_iters, tile=None):
def _evalDecodeJpeg(self,
image_name,
parallelism,
num_iters,
crop_during_decode=None,
crop_window=None,
tile=None):
"""Evaluate DecodeJpegOp for the given image.
TODO(tanmingxing): add decoding+cropping as well.
@ -46,6 +52,10 @@ class DecodeJpegBenchmark(test.Benchmark):
image_name: a string of image file name (without suffix).
parallelism: the number of concurrent decode_jpeg ops to be run.
num_iters: number of iterations for evaluation.
crop_during_decode: If true, use fused DecodeAndCropJpeg instead of
separate decode and crop ops. It is ignored if crop_window is None.
crop_window: if not None, crop the decoded image. Depending on
crop_during_decode, cropping could happen during or after decoding.
tile: if not None, tile the image to composite a larger fake image.
Returns:
@ -71,11 +81,25 @@ class DecodeJpegBenchmark(test.Benchmark):
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
images = []
for i in xrange(parallelism):
images.append(
image_ops.decode_jpeg(
image_content, channels=3, name='image_%d' % (i)))
for _ in xrange(parallelism):
if crop_window is None:
# No crop.
image = image_ops.decode_jpeg(image_content, channels=3)
elif crop_during_decode:
# combined decode and crop.
image = image_ops.decode_and_crop_jpeg(
image_content, crop_window, channels=3)
else:
# separate decode and crop.
image = image_ops.decode_jpeg(image_content, channels=3)
image = image_ops.crop_to_bounding_box(
image,
offset_height=crop_window[0],
offset_width=crop_window[1],
target_height=crop_window[2],
target_width=crop_window[3])
images.append(image)
r = control_flow_ops.group(*images)
for _ in xrange(3):
@ -89,38 +113,77 @@ class DecodeJpegBenchmark(test.Benchmark):
def benchmarkDecodeJpegSmall(self):
"""Evaluate single DecodeImageOp for small size image."""
parallelism = 1
num_iters = 10
for parallelism in [1, 10, 100]:
duration = self._evalDecodeJpeg('small.jpg', parallelism, num_iters)
crop_window = [10, 10, 50, 50]
for parallelism in [1, 100]:
duration_decode = self._evalDecodeJpeg('small.jpg', parallelism,
num_iters)
duration_decode_crop = self._evalDecodeJpeg('small.jpg', parallelism,
num_iters, False, crop_window)
duration_decode_after_crop = self._evalDecodeJpeg(
'small.jpg', parallelism, num_iters, True, crop_window)
self.report_benchmark(
name='decode_jpeg_small_p%d' % (parallelism),
iters=num_iters,
wall_time=duration)
wall_time=duration_decode)
self.report_benchmark(
name='decode_crop_jpeg_small_p%d' % (parallelism),
iters=num_iters,
wall_time=duration_decode_crop)
self.report_benchmark(
name='decode_after_crop_jpeg_small_p%d' % (parallelism),
iters=num_iters,
wall_time=duration_decode_after_crop)
def benchmarkDecodeJpegMedium(self):
"""Evaluate single DecodeImageOp for medium size image."""
parallelism = 1
num_iters = 10
for parallelism in [1, 10, 100]:
duration = self._evalDecodeJpeg('medium.jpg', parallelism, num_iters)
crop_window = [10, 10, 50, 50]
for parallelism in [1, 100]:
duration_decode = self._evalDecodeJpeg('medium.jpg', parallelism,
num_iters)
duration_decode_crop = self._evalDecodeJpeg('medium.jpg', parallelism,
num_iters, False, crop_window)
duration_decode_after_crop = self._evalDecodeJpeg(
'medium.jpg', parallelism, num_iters, True, crop_window)
self.report_benchmark(
name='decode_jpeg_medium_p%d' % (parallelism),
iters=num_iters,
wall_time=duration)
wall_time=duration_decode)
self.report_benchmark(
name='decode_crop_jpeg_medium_p%d' % (parallelism),
iters=num_iters,
wall_time=duration_decode_crop)
self.report_benchmark(
name='decode_after_crop_jpeg_medium_p%d' % (parallelism),
iters=num_iters,
wall_time=duration_decode_after_crop)
def benchmarkDecodeJpegLarge(self):
"""Evaluate single DecodeImageOp for large size image."""
parallelism = 1
num_iters = 10
for parallelism in [1, 10, 100]:
crop_window = [10, 10, 50, 50]
tile = [4, 4, 1]
for parallelism in [1, 100]:
# Tile the medium size image to composite a larger fake image.
duration = self._evalDecodeJpeg(
'medium.jpg', parallelism, num_iters, tile=[4, 4, 1])
duration_decode = self._evalDecodeJpeg('medium.jpg', parallelism,
num_iters, tile)
duration_decode_crop = self._evalDecodeJpeg(
'medium.jpg', parallelism, num_iters, False, crop_window, tile)
duration_decode_after_crop = self._evalDecodeJpeg(
'medium.jpg', parallelism, num_iters, True, crop_window, tile)
self.report_benchmark(
name='decode_jpeg_large_p%d' % (parallelism),
iters=num_iters,
wall_time=duration)
wall_time=duration_decode)
self.report_benchmark(
name='decode_crop_jpeg_large_p%d' % (parallelism),
iters=num_iters,
wall_time=duration_decode_crop)
self.report_benchmark(
name='decode_after_crop_jpeg_large_p%d' % (parallelism),
iters=num_iters,
wall_time=duration_decode_after_crop)
if __name__ == '__main__':

View File

@ -21,6 +21,7 @@ See the @{$python/image} guide.
@@decode_bmp
@@decode_gif
@@decode_jpeg
@@decode_and_crop_jpeg
@@encode_jpeg
@@extract_jpeg_shape
@@decode_png

View File

@ -2391,6 +2391,46 @@ class JpegTest(test_util.TensorFlowTestCase):
error = self.averageError(rgb, cmyk)
self.assertLess(error, 4)
def testCropAndDecodeJpeg(self):
with self.test_session() as sess:
# Encode it, then decode it, then encode it
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
h, w, _ = 256, 128, 3
crop_windows = [[0, 0, 5, 5], [0, 0, 5, w], [0, 0, h, 5],
[h - 6, w - 5, 6, 5], [6, 5, 15, 10], [0, 0, h, w]]
for crop_window in crop_windows:
# Explicit two stages: decode + crop.
image1 = image_ops.decode_jpeg(jpeg0)
y, x, h, w = crop_window
image1_crop = image_ops.crop_to_bounding_box(image1, y, x, h, w)
# Combined crop+decode.
image2 = image_ops.decode_and_crop_jpeg(jpeg0, crop_window)
# CropAndDecode should be equal to DecodeJpeg+Crop.
image1_crop, image2 = sess.run([image1_crop, image2])
self.assertAllEqual(image1_crop, image2)
def testCropAndDecodeJpegWithInvalidCropWindow(self):
with self.test_session() as sess:
# Encode it, then decode it, then encode it
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
h, w, _ = 256, 128, 3
# Invalid crop windows.
crop_windows = [[-1, 11, 11, 11], [11, -1, 11, 11], [11, 11, -1, 11],
[11, 11, 11, -1], [11, 11, 0, 11], [11, 11, 11, 0],
[0, 0, h + 1, w], [0, 0, h, w + 1]]
for crop_window in crop_windows:
result = image_ops.decode_and_crop_jpeg(jpeg0, crop_window)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
lambda e: "Invalid JPEG data or crop window" in str(e)):
sess.run(result)
def testSynthetic(self):
with self.test_session(use_gpu=True) as sess:
# Encode it, then decode it, then encode it

View File

@ -40,6 +40,10 @@ tf_module {
name: "crop_to_bounding_box"
argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "decode_and_crop_jpeg"
argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
}
member_method {
name: "decode_bmp"
argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "