diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py index 506f4bd8777..96606b9c0e5 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py @@ -228,7 +228,10 @@ class TFExampleDecoderTest(test.TestCase): image_shape = (2, 3, 3) unused_image, serialized_example = self.GenerateImage( image_format='jpeg', image_shape=image_shape) - with self.assertRaises(TypeError): + # decode_raw support uint16 now so ValueError will be thrown instead. + with self.assertRaisesRegexp( + ValueError, + 'true_fn and false_fn must have the same type: uint16, uint8'): unused_decoded_image = self.RunDecodeExample( serialized_example, tfexample_decoder.Image(dtype=dtypes.uint16), diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc index 9492a4e26d4..1c0085cfeab 100644 --- a/tensorflow/core/kernels/decode_raw_op.cc +++ b/tensorflow/core/kernels/decode_raw_op.cc @@ -105,6 +105,7 @@ REGISTER(Eigen::half); REGISTER(float); REGISTER(double); REGISTER(int32); +REGISTER(uint16); REGISTER(uint8); REGISTER(int16); REGISTER(int8); diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 1f7ebe91cf0..f23ff083afe 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -26,7 +26,7 @@ using shape_inference::ShapeHandle; REGISTER_OP("DecodeRaw") .Input("bytes: string") .Output("output: out_type") - .Attr("out_type: {half,float,double,int32,uint8,int16,int8,int64}") + .Attr("out_type: {half,float,double,int32,uint16,uint8,int16,int8,int64}") .Attr("little_endian: bool = true") .SetShapeFn([](InferenceContext* c) { // Note: last dimension is data dependent. diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py index e986b7ff2b6..009f3ea4b31 100644 --- a/tensorflow/python/kernel_tests/decode_raw_op_test.py +++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py @@ -93,6 +93,22 @@ class DecodeRawOpTest(test.TestCase): result = decode.eval(feed_dict={in_bytes: [""]}) self.assertEqual(len(result), 1) + def testToUInt16(self): + with self.test_session(): + in_bytes = array_ops.placeholder(dtypes.string, shape=[None]) + decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint16) + self.assertEqual([None, None], decode.get_shape().as_list()) + + # Use FF/EE/DD/CC so that decoded value is higher than 32768 for uint16 + result = decode.eval(feed_dict={in_bytes: [b"\xFF\xEE\xDD\xCC"]}) + self.assertAllEqual( + [[0xFF + 0xEE * 256, 0xDD + 0xCC * 256]], result) + + with self.assertRaisesOpError( + "Input to DecodeRaw has length 3 that is not a multiple of 2, the " + "size of uint16"): + decode.eval(feed_dict={in_bytes: ["123", "456"]}) + if __name__ == "__main__": test.main()