Add uint16 support for tf.decode_raw (#12719)

* Add uint16 support for tf.decode_raw

This fix tries to address the request raised in 10124 where
uint16 support for tf.decode_raw is needed. tf.decode_raw
already support half, float32, float64, int8, int16, int32, int64,
uint8. And uint16 was not supported.

This fix adds uint16 support for tf.decode_raw.

This fix fixes 10124.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix test failure caused by uint16 support of decode_raw and add unit tests.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2017-09-13 10:03:00 -07:00 committed by Rasmus Munk Larsen
parent 3bc73f5e2a
commit 2bc7a155a7
4 changed files with 22 additions and 2 deletions

View File

@ -228,7 +228,10 @@ class TFExampleDecoderTest(test.TestCase):
image_shape = (2, 3, 3)
unused_image, serialized_example = self.GenerateImage(
image_format='jpeg', image_shape=image_shape)
with self.assertRaises(TypeError):
# decode_raw support uint16 now so ValueError will be thrown instead.
with self.assertRaisesRegexp(
ValueError,
'true_fn and false_fn must have the same type: uint16, uint8'):
unused_decoded_image = self.RunDecodeExample(
serialized_example,
tfexample_decoder.Image(dtype=dtypes.uint16),

View File

@ -105,6 +105,7 @@ REGISTER(Eigen::half);
REGISTER(float);
REGISTER(double);
REGISTER(int32);
REGISTER(uint16);
REGISTER(uint8);
REGISTER(int16);
REGISTER(int8);

View File

@ -26,7 +26,7 @@ using shape_inference::ShapeHandle;
REGISTER_OP("DecodeRaw")
.Input("bytes: string")
.Output("output: out_type")
.Attr("out_type: {half,float,double,int32,uint8,int16,int8,int64}")
.Attr("out_type: {half,float,double,int32,uint16,uint8,int16,int8,int64}")
.Attr("little_endian: bool = true")
.SetShapeFn([](InferenceContext* c) {
// Note: last dimension is data dependent.

View File

@ -93,6 +93,22 @@ class DecodeRawOpTest(test.TestCase):
result = decode.eval(feed_dict={in_bytes: [""]})
self.assertEqual(len(result), 1)
def testToUInt16(self):
with self.test_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint16)
self.assertEqual([None, None], decode.get_shape().as_list())
# Use FF/EE/DD/CC so that decoded value is higher than 32768 for uint16
result = decode.eval(feed_dict={in_bytes: [b"\xFF\xEE\xDD\xCC"]})
self.assertAllEqual(
[[0xFF + 0xEE * 256, 0xDD + 0xCC * 256]], result)
with self.assertRaisesOpError(
"Input to DecodeRaw has length 3 that is not a multiple of 2, the "
"size of uint16"):
decode.eval(feed_dict={in_bytes: ["123", "456"]})
if __name__ == "__main__":
test.main()