Add uint16 support for tf.decode_raw (#12719)
* Add uint16 support for tf.decode_raw This fix tries to address the request raised in 10124 where uint16 support for tf.decode_raw is needed. tf.decode_raw already support half, float32, float64, int8, int16, int32, int64, uint8. And uint16 was not supported. This fix adds uint16 support for tf.decode_raw. This fix fixes 10124. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix test failure caused by uint16 support of decode_raw and add unit tests. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
3bc73f5e2a
commit
2bc7a155a7
@ -228,7 +228,10 @@ class TFExampleDecoderTest(test.TestCase):
|
||||
image_shape = (2, 3, 3)
|
||||
unused_image, serialized_example = self.GenerateImage(
|
||||
image_format='jpeg', image_shape=image_shape)
|
||||
with self.assertRaises(TypeError):
|
||||
# decode_raw support uint16 now so ValueError will be thrown instead.
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'true_fn and false_fn must have the same type: uint16, uint8'):
|
||||
unused_decoded_image = self.RunDecodeExample(
|
||||
serialized_example,
|
||||
tfexample_decoder.Image(dtype=dtypes.uint16),
|
||||
|
@ -105,6 +105,7 @@ REGISTER(Eigen::half);
|
||||
REGISTER(float);
|
||||
REGISTER(double);
|
||||
REGISTER(int32);
|
||||
REGISTER(uint16);
|
||||
REGISTER(uint8);
|
||||
REGISTER(int16);
|
||||
REGISTER(int8);
|
||||
|
@ -26,7 +26,7 @@ using shape_inference::ShapeHandle;
|
||||
REGISTER_OP("DecodeRaw")
|
||||
.Input("bytes: string")
|
||||
.Output("output: out_type")
|
||||
.Attr("out_type: {half,float,double,int32,uint8,int16,int8,int64}")
|
||||
.Attr("out_type: {half,float,double,int32,uint16,uint8,int16,int8,int64}")
|
||||
.Attr("little_endian: bool = true")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// Note: last dimension is data dependent.
|
||||
|
@ -93,6 +93,22 @@ class DecodeRawOpTest(test.TestCase):
|
||||
result = decode.eval(feed_dict={in_bytes: [""]})
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def testToUInt16(self):
|
||||
with self.test_session():
|
||||
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
|
||||
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint16)
|
||||
self.assertEqual([None, None], decode.get_shape().as_list())
|
||||
|
||||
# Use FF/EE/DD/CC so that decoded value is higher than 32768 for uint16
|
||||
result = decode.eval(feed_dict={in_bytes: [b"\xFF\xEE\xDD\xCC"]})
|
||||
self.assertAllEqual(
|
||||
[[0xFF + 0xEE * 256, 0xDD + 0xCC * 256]], result)
|
||||
|
||||
with self.assertRaisesOpError(
|
||||
"Input to DecodeRaw has length 3 that is not a multiple of 2, the "
|
||||
"size of uint16"):
|
||||
decode.eval(feed_dict={in_bytes: ["123", "456"]})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user