Fix typo in convert_image_to_dtype
PiperOrigin-RevId: 169252296
This commit is contained in:
parent
095f6aa7d1
commit
5b94356280
@ -1061,7 +1061,7 @@ def convert_image_dtype(image, dtype, saturate=False, name=None):
|
|||||||
# Scaling up, cast first, then scale. The scale will not map in.max to
|
# Scaling up, cast first, then scale. The scale will not map in.max to
|
||||||
# out.max, but converting back and forth should result in no change.
|
# out.max, but converting back and forth should result in no change.
|
||||||
if saturate:
|
if saturate:
|
||||||
cast = math_ops.saturate_cast(scaled, dtype)
|
cast = math_ops.saturate_cast(image, dtype)
|
||||||
else:
|
else:
|
||||||
cast = math_ops.cast(image, dtype)
|
cast = math_ops.cast(image, dtype)
|
||||||
scale = (scale_out + 1) // (scale_in + 1)
|
scale = (scale_out + 1) // (scale_in + 1)
|
||||||
|
@ -2655,6 +2655,12 @@ class ConvertImageTest(test_util.TensorFlowTestCase):
|
|||||||
y = image_ops.convert_image_dtype(image, output_dtype)
|
y = image_ops.convert_image_dtype(image, output_dtype)
|
||||||
self.assertTrue(y.dtype == output_dtype)
|
self.assertTrue(y.dtype == output_dtype)
|
||||||
self.assertAllClose(y.eval(), y_np, atol=1e-5)
|
self.assertAllClose(y.eval(), y_np, atol=1e-5)
|
||||||
|
if output_dtype in [dtypes.float32, dtypes.float64,
|
||||||
|
dtypes.int32, dtypes.int64]:
|
||||||
|
y_saturate = image_ops.convert_image_dtype(
|
||||||
|
image, output_dtype, saturate=True)
|
||||||
|
self.assertTrue(y_saturate.dtype == output_dtype)
|
||||||
|
self.assertAllClose(y_saturate.eval(), y_np, atol=1e-5)
|
||||||
|
|
||||||
def testNoConvert(self):
|
def testNoConvert(self):
|
||||||
# Make sure converting to the same data type creates only an identity op
|
# Make sure converting to the same data type creates only an identity op
|
||||||
@ -2670,6 +2676,8 @@ class ConvertImageTest(test_util.TensorFlowTestCase):
|
|||||||
with self.test_session(use_gpu=True):
|
with self.test_session(use_gpu=True):
|
||||||
self._convert([0, 255], dtypes.uint8, dtypes.int16, [0, 255 * 128])
|
self._convert([0, 255], dtypes.uint8, dtypes.int16, [0, 255 * 128])
|
||||||
self._convert([0, 32767], dtypes.int16, dtypes.uint8, [0, 255])
|
self._convert([0, 32767], dtypes.int16, dtypes.uint8, [0, 255])
|
||||||
|
self._convert([0, 2 ** 32], dtypes.int64, dtypes.int32, [0, 1])
|
||||||
|
self._convert([0, 1], dtypes.int32, dtypes.int64, [0, 2 ** 32])
|
||||||
|
|
||||||
def testConvertBetweenFloat(self):
|
def testConvertBetweenFloat(self):
|
||||||
# Make sure converting to between float types does nothing interesting
|
# Make sure converting to between float types does nothing interesting
|
||||||
|
Loading…
Reference in New Issue
Block a user