Add a wrapper around encode_png which converts to tensor before calling into
the gen ops code. This way if there is uint16 np array, we'll use uint16 as the dtype correctly instead of using uint8 (which would happen before due to EncodePng having a default type of uint8). PiperOrigin-RevId: 282633227 Change-Id: I46d5bda1f94fda3e0a4b15c57591ef86dd1649a4
This commit is contained in:
parent
d9a5dad3be
commit
9e8a730f08
@ -3,4 +3,5 @@ op {
|
||||
endpoint {
|
||||
name: "image.encode_png"
|
||||
}
|
||||
visibility: HIDDEN
|
||||
}
|
||||
|
@ -2232,6 +2232,35 @@ tf_export(
|
||||
gen_image_ops.extract_jpeg_shape)
|
||||
|
||||
|
||||
@tf_export('image.encode_png')
|
||||
def encode_png(image, compression=-1, name=None):
|
||||
r"""PNG-encode an image.
|
||||
|
||||
`image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`
|
||||
where `channels` is:
|
||||
|
||||
* 1: for grayscale.
|
||||
* 2: for grayscale + alpha.
|
||||
* 3: for RGB.
|
||||
* 4: for RGBA.
|
||||
|
||||
The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
|
||||
default or a value from 0 to 9. 9 is the highest compression level,
|
||||
generating the smallest output, but is slower.
|
||||
|
||||
Args:
|
||||
image: A `Tensor`. Must be one of the following types: `uint8`, `uint16`.
|
||||
3-D with shape `[height, width, channels]`.
|
||||
compression: An optional `int`. Defaults to `-1`. Compression level.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` of type `string`.
|
||||
"""
|
||||
return gen_image_ops.encode_png(
|
||||
ops.convert_to_tensor(image), compression, name)
|
||||
|
||||
|
||||
@tf_export(
|
||||
'io.decode_image',
|
||||
'image.decode_image',
|
||||
|
@ -5119,6 +5119,7 @@ class SobelEdgesTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose(expected_batch, actual_sobel)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DecodeImageTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testJpegUint16(self):
|
||||
@ -5141,6 +5142,13 @@ class DecodeImageTest(test_util.TensorFlowTestCase):
|
||||
image0, image1 = self.evaluate([image0, image1])
|
||||
self.assertAllEqual(image0, image1)
|
||||
|
||||
# NumPy conversions should happen before
|
||||
x = np.random.randint(256, size=(4, 4, 3), dtype=np.uint16)
|
||||
x_str = image_ops_impl.encode_png(x)
|
||||
x_dec = image_ops_impl.decode_image(
|
||||
x_str, channels=3, dtype=dtypes.uint16)
|
||||
self.assertAllEqual(x, x_dec)
|
||||
|
||||
def testGifUint16(self):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
base = "tensorflow/core/lib/gif/testdata"
|
||||
|
Loading…
Reference in New Issue
Block a user