Add uint16 support for tf.decode_raw (#12719)
* Add uint16 support for tf.decode_raw This fix tries to address the request raised in 10124 where uint16 support for tf.decode_raw is needed. tf.decode_raw already support half, float32, float64, int8, int16, int32, int64, uint8. And uint16 was not supported. This fix adds uint16 support for tf.decode_raw. This fix fixes 10124. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix test failure caused by uint16 support of decode_raw and add unit tests. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
3bc73f5e2a
commit
2bc7a155a7
@ -228,7 +228,10 @@ class TFExampleDecoderTest(test.TestCase):
|
|||||||
image_shape = (2, 3, 3)
|
image_shape = (2, 3, 3)
|
||||||
unused_image, serialized_example = self.GenerateImage(
|
unused_image, serialized_example = self.GenerateImage(
|
||||||
image_format='jpeg', image_shape=image_shape)
|
image_format='jpeg', image_shape=image_shape)
|
||||||
with self.assertRaises(TypeError):
|
# decode_raw support uint16 now so ValueError will be thrown instead.
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError,
|
||||||
|
'true_fn and false_fn must have the same type: uint16, uint8'):
|
||||||
unused_decoded_image = self.RunDecodeExample(
|
unused_decoded_image = self.RunDecodeExample(
|
||||||
serialized_example,
|
serialized_example,
|
||||||
tfexample_decoder.Image(dtype=dtypes.uint16),
|
tfexample_decoder.Image(dtype=dtypes.uint16),
|
||||||
|
@ -105,6 +105,7 @@ REGISTER(Eigen::half);
|
|||||||
REGISTER(float);
|
REGISTER(float);
|
||||||
REGISTER(double);
|
REGISTER(double);
|
||||||
REGISTER(int32);
|
REGISTER(int32);
|
||||||
|
REGISTER(uint16);
|
||||||
REGISTER(uint8);
|
REGISTER(uint8);
|
||||||
REGISTER(int16);
|
REGISTER(int16);
|
||||||
REGISTER(int8);
|
REGISTER(int8);
|
||||||
|
@ -26,7 +26,7 @@ using shape_inference::ShapeHandle;
|
|||||||
REGISTER_OP("DecodeRaw")
|
REGISTER_OP("DecodeRaw")
|
||||||
.Input("bytes: string")
|
.Input("bytes: string")
|
||||||
.Output("output: out_type")
|
.Output("output: out_type")
|
||||||
.Attr("out_type: {half,float,double,int32,uint8,int16,int8,int64}")
|
.Attr("out_type: {half,float,double,int32,uint16,uint8,int16,int8,int64}")
|
||||||
.Attr("little_endian: bool = true")
|
.Attr("little_endian: bool = true")
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
// Note: last dimension is data dependent.
|
// Note: last dimension is data dependent.
|
||||||
|
@ -93,6 +93,22 @@ class DecodeRawOpTest(test.TestCase):
|
|||||||
result = decode.eval(feed_dict={in_bytes: [""]})
|
result = decode.eval(feed_dict={in_bytes: [""]})
|
||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
|
|
||||||
|
def testToUInt16(self):
|
||||||
|
with self.test_session():
|
||||||
|
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
|
||||||
|
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint16)
|
||||||
|
self.assertEqual([None, None], decode.get_shape().as_list())
|
||||||
|
|
||||||
|
# Use FF/EE/DD/CC so that decoded value is higher than 32768 for uint16
|
||||||
|
result = decode.eval(feed_dict={in_bytes: [b"\xFF\xEE\xDD\xCC"]})
|
||||||
|
self.assertAllEqual(
|
||||||
|
[[0xFF + 0xEE * 256, 0xDD + 0xCC * 256]], result)
|
||||||
|
|
||||||
|
with self.assertRaisesOpError(
|
||||||
|
"Input to DecodeRaw has length 3 that is not a multiple of 2, the "
|
||||||
|
"size of uint16"):
|
||||||
|
decode.eval(feed_dict={in_bytes: ["123", "456"]})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user