Add a wrapper around encode_png which converts to tensor before calling into

the gen ops code.

This way if there is uint16 np array, we'll use uint16 as the dtype correctly
instead of using uint8 (which would happen before due to EncodePng having a
default type of uint8).

PiperOrigin-RevId: 282633227
Change-Id: I46d5bda1f94fda3e0a4b15c57591ef86dd1649a4
This commit is contained in:
Akshay Modi 2019-11-26 13:52:58 -08:00 committed by TensorFlower Gardener
parent d9a5dad3be
commit 9e8a730f08
3 changed files with 38 additions and 0 deletions

View File

@ -3,4 +3,5 @@ op {
endpoint {
name: "image.encode_png"
}
visibility: HIDDEN
}

View File

@ -2232,6 +2232,35 @@ tf_export(
gen_image_ops.extract_jpeg_shape)
@tf_export('image.encode_png')
def encode_png(image, compression=-1, name=None):
r"""PNG-encode an image.
`image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`
where `channels` is:
* 1: for grayscale.
* 2: for grayscale + alpha.
* 3: for RGB.
* 4: for RGBA.
The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
default or a value from 0 to 9. 9 is the highest compression level,
generating the smallest output, but is slower.
Args:
image: A `Tensor`. Must be one of the following types: `uint8`, `uint16`.
3-D with shape `[height, width, channels]`.
compression: An optional `int`. Defaults to `-1`. Compression level.
name: A name for the operation (optional).
Returns:
A `Tensor` of type `string`.
"""
return gen_image_ops.encode_png(
ops.convert_to_tensor(image), compression, name)
@tf_export(
'io.decode_image',
'image.decode_image',

View File

@ -5119,6 +5119,7 @@ class SobelEdgesTest(test_util.TensorFlowTestCase):
self.assertAllClose(expected_batch, actual_sobel)
@test_util.run_all_in_graph_and_eager_modes
class DecodeImageTest(test_util.TensorFlowTestCase):
def testJpegUint16(self):
@ -5141,6 +5142,13 @@ class DecodeImageTest(test_util.TensorFlowTestCase):
image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
# NumPy conversions should happen before
x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16)
x_str = image_ops_impl.encode_png(x)
x_dec = image_ops_impl.decode_image(
x_str, channels=3, dtype=dtypes.uint16)
self.assertAllEqual(x, x_dec)
def testGifUint16(self):
with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"