Add crop_and_decode_jpeg_op that combines the crop and decode for better performance.

PiperOrigin-RevId: 168493125
commit 9d56f419cf (parent 48ddf64d0e)
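For orientation, a minimal usage sketch of the new fused op from Python (not part of this change set): it assumes a TF 1.x build that already contains this commit, a placeholder JPEG path, and uses the `tf.image.decode_and_crop_jpeg` endpoint that the API golden below registers. The fused call should match decode-then-crop pixel for pixel, which is what the Python test in this commit asserts.

import tensorflow as tf

jpeg_bytes = tf.read_file('example.jpg')   # hypothetical input file
crop_window = [10, 20, 100, 150]           # [crop_y, crop_x, crop_height, crop_width]

# Fused: only the cropped region of the JPEG is decoded.
cropped = tf.image.decode_and_crop_jpeg(jpeg_bytes, crop_window, channels=3)

# Separate: decode the full image, then crop; should produce the same pixels.
full = tf.image.decode_jpeg(jpeg_bytes, channels=3)
cropped_ref = tf.image.crop_to_bounding_box(full, 10, 20, 100, 150)

with tf.Session() as sess:
  a, b = sess.run([cropped, cropped_ref])
  assert (a == b).all()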
@@ -4675,6 +4675,7 @@ filegroup(
         "encode_jpeg_op.*",
         "extract_jpeg_shape_op.*",
         "decode_jpeg_op.*",
+        "decode_and_crop_jpeg_op.*",
         "decode_gif_op.*",
         "identity_reader_op.*",
         "remote_fused_graph_execute_op.*",
@@ -71,6 +71,9 @@ class DecodeImageOp : public OpKernel {
     // Determine which op we are: jpeg, png, gif, or any
     if (type_string() == "DecodeJpeg") {
       format_ = kJpgFormat;
+    } else if (type_string() == "DecodeAndCropJpeg") {
+      format_ = kJpgFormat;
+      flags_.crop = true;
     } else if (type_string() == "DecodePng") {
       format_ = kPngFormat;
     } else if (type_string() == "DecodeGif") {
@@ -185,12 +188,31 @@ class DecodeImageOp : public OpKernel {
                 errors::InvalidArgument(
                     "channels must be 0, 1, or 3 for JPEG, got ", channels_));

-    // Decode jpeg, allocating tensor once the size is known
+    // Use local copy of flags to avoid race condition as the class member is
+    // shared among different invocations.
+    jpeg::UncompressFlags flags = flags_;
+    if (flags.crop) {
+      // Update flags to include crop window.
+      const Tensor& crop_window = context->input(1);
+      OP_REQUIRES(context, crop_window.dims() == 1,
+                  errors::InvalidArgument("crop_window must be 1-D, got shape ",
+                                          crop_window.shape().DebugString()));
+      OP_REQUIRES(context, crop_window.dim_size(0) == 4,
+                  errors::InvalidArgument("crop_size must have four elements ",
+                                          crop_window.shape().DebugString()));
+      auto crop_window_vec = crop_window.vec<int32>();
+      flags.crop_y = crop_window_vec(0);
+      flags.crop_x = crop_window_vec(1);
+      flags.crop_height = crop_window_vec(2);
+      flags.crop_width = crop_window_vec(3);
+    }
+
+    // Decode jpeg, allocating tensor once the size is known.
     Tensor* output = nullptr;
     OP_REQUIRES(
         context,
         jpeg::Uncompress(
-            input.data(), input.size(), flags_, nullptr /* nwarn */,
+            input.data(), input.size(), flags, nullptr /* nwarn */,
             [=, &output](int width, int height, int channels) -> uint8* {
               Status status(context->allocate_output(
                   0,
@@ -205,7 +227,8 @@ class DecodeImageOp : public OpKernel {
               }
               return output->flat<uint8>().data();
             }),
-        errors::InvalidArgument("Invalid JPEG data, size ", input.size()));
+        errors::InvalidArgument("Invalid JPEG data or crop window, data size ",
+                                input.size()));
   }

   void DecodePng(OpKernelContext* context, StringPiece input) {
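A rough Python/NumPy analogue (not the kernel code above) of what the 4-element crop window means once the kernel copies it into the jpeg::UncompressFlags fields: the function name and shapes are illustrative only.

import numpy as np

def crop_like_decode_and_crop(full_image, crop_window):
  """full_image: HxWxC array; crop_window: [crop_y, crop_x, crop_height, crop_width]."""
  crop_window = np.asarray(crop_window)
  if crop_window.ndim != 1 or crop_window.shape[0] != 4:
    raise ValueError("crop_window must be 1-D with four elements")
  y, x, h, w = crop_window
  # The fused op decodes only this region instead of the whole image.
  return full_image[y:y + h, x:x + w, :]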
@@ -311,6 +334,8 @@ class DecodeImageOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeImageOp);
 REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodeImageOp);
 REGISTER_KERNEL_BUILDER(Name("DecodeGif").Device(DEVICE_CPU), DecodeImageOp);
+REGISTER_KERNEL_BUILDER(Name("DecodeAndCropJpeg").Device(DEVICE_CPU),
+                        DecodeImageOp);

 }  // namespace
 }  // namespace tensorflow
@@ -25,6 +25,42 @@ using shape_inference::ShapeHandle;

 namespace {

+const char kDecodeJpegCommonDocStr[] = R"doc(
+The attr `channels` indicates the desired number of color channels for the
+decoded image.
+
+Accepted values are:
+
+* 0: Use the number of channels in the JPEG-encoded image.
+* 1: output a grayscale image.
+* 3: output an RGB image.
+
+If needed, the JPEG-encoded image is transformed to match the requested number
+of color channels.
+
+The attr `ratio` allows downscaling the image by an integer factor during
+decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
+downscaling the image later.
+
+)doc";
+
+const char kDecodeJpegCommonParamsDocStr[] = R"doc(
+channels: Number of color channels for the decoded image.
+ratio: Downscaling ratio.
+fancy_upscaling: If true use a slower but nicer upscaling of the
+  chroma planes (yuv420/422 only).
+try_recover_truncated: If true try to recover an image from truncated input.
+acceptable_fraction: The minimum required fraction of lines before a truncated
+  input is accepted.
+dct_method: string specifying a hint about the algorithm used for
+  decompression. Defaults to "" which maps to a system-specific
+  default. Currently valid values are ["INTEGER_FAST",
+  "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
+  jpeg library changes to a version that does not have that specific
+  option.)
+image: 3-D with shape `[height, width, channels]`..
+)doc";
+
 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where
 // height and width come from the size_tensor.
 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
@@ -370,44 +406,40 @@ REGISTER_OP("DecodeJpeg")
     .Attr("dct_method: string = ''")
     .Output("image: uint8")
     .SetShapeFn(DecodeImageShapeFn)
-    .Doc(R"doc(
+    .Doc(strings::StrCat(R"doc(
 Decode a JPEG-encoded image to a uint8 tensor.
-
-The attr `channels` indicates the desired number of color channels for the
-decoded image.
-
-Accepted values are:
-
-* 0: Use the number of channels in the JPEG-encoded image.
-* 1: output a grayscale image.
-* 3: output an RGB image.
-
-If needed, the JPEG-encoded image is transformed to match the requested number
-of color channels.
-
-The attr `ratio` allows downscaling the image by an integer factor during
-decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
-downscaling the image later.
-
+)doc",
+                         kDecodeJpegCommonDocStr, R"doc(
 This op also supports decoding PNGs and non-animated GIFs since the interface is
 the same, though it is cleaner to use `tf.image.decode_image`.

 contents: 0-D. The JPEG-encoded image.
-channels: Number of color channels for the decoded image.
-ratio: Downscaling ratio.
-fancy_upscaling: If true use a slower but nicer upscaling of the
-  chroma planes (yuv420/422 only).
-try_recover_truncated: If true try to recover an image from truncated input.
-acceptable_fraction: The minimum required fraction of lines before a truncated
-  input is accepted.
-dct_method: string specifying a hint about the algorithm used for
-  decompression. Defaults to "" which maps to a system-specific
-  default. Currently valid values are ["INTEGER_FAST",
-  "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
-  jpeg library changes to a version that does not have that specific
-  option.)
-image: 3-D with shape `[height, width, channels]`..
-)doc");
+)doc",
+                         kDecodeJpegCommonParamsDocStr));

 // --------------------------------------------------------------------------
+REGISTER_OP("DecodeAndCropJpeg")
+    .Input("contents: string")
+    .Input("crop_window: int32")
+    .Attr("channels: int = 0")
+    .Attr("ratio: int = 1")
+    .Attr("fancy_upscaling: bool = true")
+    .Attr("try_recover_truncated: bool = false")
+    .Attr("acceptable_fraction: float = 1.0")
+    .Attr("dct_method: string = ''")
+    .Output("image: uint8")
+    .SetShapeFn(DecodeImageShapeFn)
+    .Doc(strings::StrCat(R"doc(
+Decode and Crop a JPEG-encoded image to a uint8 tensor.
+)doc",
+                         kDecodeJpegCommonDocStr, R"doc(
+It is equivalent to a combination of decode and crop, but much faster by only
+decoding partial jpeg image.
+
+contents: 0-D. The JPEG-encoded image.
+crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width].
+)doc",
+                         kDecodeJpegCommonParamsDocStr));
+
+// --------------------------------------------------------------------------
 REGISTER_OP("EncodeJpeg")
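As a hedged sketch of the Python call surface implied by the DecodeAndCropJpeg registration above (inputs contents and crop_window, plus the shared JPEG attrs), in TF 1.x graph style; the keyword names come from the API golden further down, the path and attr values are placeholders:

import tensorflow as tf

contents = tf.read_file('photo.jpg')  # placeholder path
image = tf.image.decode_and_crop_jpeg(
    contents,
    crop_window=[0, 0, 224, 224],  # [crop_y, crop_x, crop_height, crop_width]
    channels=3,                    # 0 = keep source channels, 1 = grayscale, 3 = RGB
    ratio=1,                       # integer downscaling factor: 1, 2, 4, or 8
    fancy_upscaling=True,
    try_recover_truncated=False,
    acceptable_fraction=1.0,
    dct_method='')                 # '' = system default, or 'INTEGER_FAST' / 'INTEGER_ACCURATE'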
@@ -90,6 +90,58 @@ TEST(ImageOpsTest, DecodeImage_ShapeFn) {
   }
 }

+TEST(ImageOpsTest, DecodeAndCropJpeg_ShapeFn) {
+  const char* op_name = "DecodeAndCropJpeg";
+  ShapeInferenceTestOp op(op_name);
+
+  // Check the number of inputs.
+  INFER_ERROR("Wrong number of inputs passed: 1 while 2 expected", op, "[1]");
+
+  // Rank check.
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?");
+
+  // Set the channel to zero - output is not known.
+  TF_ASSERT_OK(NodeDefBuilder("test", op_name)
+                   .Input({"img", 0, DT_STRING})
+                   .Input({"crop_window", 1, DT_INT32})
+                   .Finalize(&op.node_def));
+  INFER_OK(op, "[];[]", "[?,?,?]");
+
+  // Set the channel, so that part of output shape is known.
+  TF_ASSERT_OK(NodeDefBuilder("test", op_name)
+                   .Input({"img", 0, DT_STRING})
+                   .Input({"crop_window", 1, DT_INT32})
+                   .Attr("channels", 4)
+                   .Finalize(&op.node_def));
+  INFER_OK(op, "[];[]", "[?,?,4]");
+
+  // Negative channel value is rejected.
+  TF_ASSERT_OK(NodeDefBuilder("test", op_name)
+                   .Input({"img", 0, DT_STRING})
+                   .Input({"crop_window", 1, DT_INT32})
+                   .Attr("channels", -1)
+                   .Finalize(&op.node_def));
+  INFER_ERROR("channels must be non-negative, got -1", op, "[];[]");
+}
+
+TEST(ImageOpsTest, DecodeAndCropJpeg_InvalidCropWindow) {
+  const char* op_name = "DecodeAndCropJpeg";
+  ShapeInferenceTestOp op(op_name);
+
+  // Check the number of inputs.
+  INFER_ERROR("Wrong number of inputs passed: 1 while 2 expected", op, "[1]");
+
+  // Rank check.
+  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?");
+
+  // Set the channel to zero - output is not known.
+  TF_ASSERT_OK(NodeDefBuilder("test", op_name)
+                   .Input({"img", 0, DT_STRING})
+                   .Input({"crop_window", 1, DT_INT32})
+                   .Finalize(&op.node_def));
+  INFER_OK(op, "[];[]", "[?,?,?]");
+}
+
 TEST(ImageOpsTest, EncodeImage_ShapeFn) {
   for (const char* op_name : {"EncodeJpeg", "EncodePng"}) {
     ShapeInferenceTestOp op(op_name);
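A hedged sketch of what the shape-inference tests above mean for the Python API (TF 1.x graph mode, matching the style used elsewhere in this change): the op's static output shape only pins down the channel count when the `channels` attr is set.

import tensorflow as tf

contents = tf.placeholder(tf.string, shape=[])
crop_window = tf.placeholder(tf.int32, shape=[4])
image = tf.image.decode_and_crop_jpeg(contents, crop_window, channels=3)
print(image.get_shape())  # (?, ?, 3): height and width unknown statically, channels known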
@@ -37,7 +37,13 @@ prefix_path = 'third_party/tensorflow/core/lib/jpeg/testdata'
 class DecodeJpegBenchmark(test.Benchmark):
   """Evaluate tensorflow DecodeJpegOp performance."""

-  def _evalDecodeJpeg(self, image_name, parallelism, num_iters, tile=None):
+  def _evalDecodeJpeg(self,
+                      image_name,
+                      parallelism,
+                      num_iters,
+                      crop_during_decode=None,
+                      crop_window=None,
+                      tile=None):
     """Evaluate DecodeJpegOp for the given image.

     TODO(tanmingxing): add decoding+cropping as well.
@@ -46,6 +52,10 @@ class DecodeJpegBenchmark(test.Benchmark):
       image_name: a string of image file name (without suffix).
       parallelism: the number of concurrent decode_jpeg ops to be run.
       num_iters: number of iterations for evaluation.
+      crop_during_decode: If true, use fused DecodeAndCropJpeg instead of
+          separate decode and crop ops. It is ignored if crop_window is None.
+      crop_window: if not None, crop the decoded image. Depending on
+          crop_during_decode, cropping could happen during or after decoding.
       tile: if not None, tile the image to composite a larger fake image.

     Returns:
@@ -71,11 +81,25 @@ class DecodeJpegBenchmark(test.Benchmark):
     with session.Session() as sess:
       sess.run(variables.global_variables_initializer())
       images = []
-      for i in xrange(parallelism):
-        images.append(
-            image_ops.decode_jpeg(
-                image_content, channels=3, name='image_%d' % (i)))
+      for _ in xrange(parallelism):
+        if crop_window is None:
+          # No crop.
+          image = image_ops.decode_jpeg(image_content, channels=3)
+        elif crop_during_decode:
+          # combined decode and crop.
+          image = image_ops.decode_and_crop_jpeg(
+              image_content, crop_window, channels=3)
+        else:
+          # separate decode and crop.
+          image = image_ops.decode_jpeg(image_content, channels=3)
+          image = image_ops.crop_to_bounding_box(
+              image,
+              offset_height=crop_window[0],
+              offset_width=crop_window[1],
+              target_height=crop_window[2],
+              target_width=crop_window[3])
+
+        images.append(image)
       r = control_flow_ops.group(*images)

       for _ in xrange(3):
@@ -89,38 +113,77 @@ class DecodeJpegBenchmark(test.Benchmark):

   def benchmarkDecodeJpegSmall(self):
     """Evaluate single DecodeImageOp for small size image."""
-    parallelism = 1
     num_iters = 10
-    for parallelism in [1, 10, 100]:
-      duration = self._evalDecodeJpeg('small.jpg', parallelism, num_iters)
+    crop_window = [10, 10, 50, 50]
+    for parallelism in [1, 100]:
+      duration_decode = self._evalDecodeJpeg('small.jpg', parallelism,
+                                             num_iters)
+      duration_decode_crop = self._evalDecodeJpeg('small.jpg', parallelism,
+                                                  num_iters, False, crop_window)
+      duration_decode_after_crop = self._evalDecodeJpeg(
+          'small.jpg', parallelism, num_iters, True, crop_window)
       self.report_benchmark(
           name='decode_jpeg_small_p%d' % (parallelism),
           iters=num_iters,
-          wall_time=duration)
+          wall_time=duration_decode)
+      self.report_benchmark(
+          name='decode_crop_jpeg_small_p%d' % (parallelism),
+          iters=num_iters,
+          wall_time=duration_decode_crop)
+      self.report_benchmark(
+          name='decode_after_crop_jpeg_small_p%d' % (parallelism),
+          iters=num_iters,
+          wall_time=duration_decode_after_crop)

   def benchmarkDecodeJpegMedium(self):
     """Evaluate single DecodeImageOp for medium size image."""
-    parallelism = 1
     num_iters = 10
-    for parallelism in [1, 10, 100]:
-      duration = self._evalDecodeJpeg('medium.jpg', parallelism, num_iters)
+    crop_window = [10, 10, 50, 50]
+    for parallelism in [1, 100]:
+      duration_decode = self._evalDecodeJpeg('medium.jpg', parallelism,
+                                             num_iters)
+      duration_decode_crop = self._evalDecodeJpeg('medium.jpg', parallelism,
+                                                  num_iters, False, crop_window)
+      duration_decode_after_crop = self._evalDecodeJpeg(
+          'medium.jpg', parallelism, num_iters, True, crop_window)
       self.report_benchmark(
          name='decode_jpeg_medium_p%d' % (parallelism),
           iters=num_iters,
-          wall_time=duration)
+          wall_time=duration_decode)
+      self.report_benchmark(
+          name='decode_crop_jpeg_medium_p%d' % (parallelism),
+          iters=num_iters,
+          wall_time=duration_decode_crop)
+      self.report_benchmark(
+          name='decode_after_crop_jpeg_medium_p%d' % (parallelism),
+          iters=num_iters,
+          wall_time=duration_decode_after_crop)

   def benchmarkDecodeJpegLarge(self):
     """Evaluate single DecodeImageOp for large size image."""
-    parallelism = 1
     num_iters = 10
-    for parallelism in [1, 10, 100]:
+    crop_window = [10, 10, 50, 50]
+    tile = [4, 4, 1]
+    for parallelism in [1, 100]:
       # Tile the medium size image to composite a larger fake image.
-      duration = self._evalDecodeJpeg(
-          'medium.jpg', parallelism, num_iters, tile=[4, 4, 1])
+      duration_decode = self._evalDecodeJpeg('medium.jpg', parallelism,
+                                             num_iters, tile)
+      duration_decode_crop = self._evalDecodeJpeg(
+          'medium.jpg', parallelism, num_iters, False, crop_window, tile)
+      duration_decode_after_crop = self._evalDecodeJpeg(
+          'medium.jpg', parallelism, num_iters, True, crop_window, tile)
       self.report_benchmark(
           name='decode_jpeg_large_p%d' % (parallelism),
           iters=num_iters,
-          wall_time=duration)
+          wall_time=duration_decode)
+      self.report_benchmark(
+          name='decode_crop_jpeg_large_p%d' % (parallelism),
+          iters=num_iters,
+          wall_time=duration_decode_crop)
+      self.report_benchmark(
+          name='decode_after_crop_jpeg_large_p%d' % (parallelism),
+          iters=num_iters,
+          wall_time=duration_decode_after_crop)


 if __name__ == '__main__':
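The benchmark above compares three variants per image size: decode only, fused decode+crop, and decode followed by crop. A hedged standalone sketch of the same measurement outside the benchmark harness (TF 1.x style; 'input.jpg' and the iteration count are placeholders):

import time
import tensorflow as tf

jpeg_bytes = tf.read_file('input.jpg')
crop_window = [10, 10, 50, 50]  # [crop_y, crop_x, crop_height, crop_width]

fused = tf.image.decode_and_crop_jpeg(jpeg_bytes, crop_window, channels=3)
decoded = tf.image.decode_jpeg(jpeg_bytes, channels=3)
separate = tf.image.crop_to_bounding_box(decoded, 10, 10, 50, 50)

with tf.Session() as sess:
  for name, op in [('fused', fused), ('separate', separate)]:
    sess.run(op)  # warm up
    start = time.time()
    for _ in range(100):
      sess.run(op)
    print(name, (time.time() - start) / 100)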
@@ -21,6 +21,7 @@ See the @{$python/image} guide.
 @@decode_bmp
 @@decode_gif
 @@decode_jpeg
+@@decode_and_crop_jpeg
 @@encode_jpeg
 @@extract_jpeg_shape
 @@decode_png
@@ -2391,6 +2391,46 @@ class JpegTest(test_util.TensorFlowTestCase):
       error = self.averageError(rgb, cmyk)
       self.assertLess(error, 4)

+  def testCropAndDecodeJpeg(self):
+    with self.test_session() as sess:
+      # Encode it, then decode it, then encode it
+      base = "tensorflow/core/lib/jpeg/testdata"
+      jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
+
+      h, w, _ = 256, 128, 3
+      crop_windows = [[0, 0, 5, 5], [0, 0, 5, w], [0, 0, h, 5],
+                      [h - 6, w - 5, 6, 5], [6, 5, 15, 10], [0, 0, h, w]]
+      for crop_window in crop_windows:
+        # Explicit two stages: decode + crop.
+        image1 = image_ops.decode_jpeg(jpeg0)
+        y, x, h, w = crop_window
+        image1_crop = image_ops.crop_to_bounding_box(image1, y, x, h, w)
+
+        # Combined crop+decode.
+        image2 = image_ops.decode_and_crop_jpeg(jpeg0, crop_window)
+
+        # CropAndDecode should be equal to DecodeJpeg+Crop.
+        image1_crop, image2 = sess.run([image1_crop, image2])
+        self.assertAllEqual(image1_crop, image2)
+
+  def testCropAndDecodeJpegWithInvalidCropWindow(self):
+    with self.test_session() as sess:
+      # Encode it, then decode it, then encode it
+      base = "tensorflow/core/lib/jpeg/testdata"
+      jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
+
+      h, w, _ = 256, 128, 3
+      # Invalid crop windows.
+      crop_windows = [[-1, 11, 11, 11], [11, -1, 11, 11], [11, 11, -1, 11],
+                      [11, 11, 11, -1], [11, 11, 0, 11], [11, 11, 11, 0],
+                      [0, 0, h + 1, w], [0, 0, h, w + 1]]
+      for crop_window in crop_windows:
+        result = image_ops.decode_and_crop_jpeg(jpeg0, crop_window)
+        with self.assertRaisesWithPredicateMatch(
+            errors.InvalidArgumentError,
+            lambda e: "Invalid JPEG data or crop window" in str(e)):
+          sess.run(result)
+
   def testSynthetic(self):
     with self.test_session(use_gpu=True) as sess:
       # Encode it, then decode it, then encode it
@@ -40,6 +40,10 @@ tf_module {
     name: "crop_to_bounding_box"
     argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "decode_and_crop_jpeg"
+    argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
+  }
   member_method {
     name: "decode_bmp"
     argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "