Convert one_hot inputs to passed in dtype.

This allows one to write code like the following without one_hot raising a TypeError complaining that the float32 on/off values don't match the requested dtype of float16:

tf.one_hot(indices=[0, 1, 2], depth=3, on_value=1., off_value=0., dtype=tf.float16)

Also fix issue with several tests, where previously they were mostly run in float32 even if the test was intended to run in a different dtype.

PiperOrigin-RevId: 296548854
Change-Id: I44881c5a5a007e255671d86808b73015d56dfa94
This commit is contained in:
Reed Wanderman-Milne 2020-02-21 18:13:18 -08:00 committed by TensorFlower Gardener
parent ce9564d430
commit 1015e48633
2 changed files with 50 additions and 14 deletions
tensorflow/python

View File

@@ -33,16 +33,19 @@ class OneHotTest(test.TestCase):
use_gpu=False,
expected_err_re=None,
raises=None,
dtype=None,
**inputs):
with self.cached_session(use_gpu=use_gpu):
if raises is not None:
with self.assertRaises(raises):
array_ops.one_hot(**inputs)
array_ops.one_hot(dtype=dtype, **inputs)
else:
ans = array_ops.one_hot(**inputs)
ans = array_ops.one_hot(dtype=dtype, **inputs)
if expected_err_re is None:
tf_ans = self.evaluate(ans)
self.assertAllEqual(tf_ans, truth)
if dtype:
self.assertEqual(tf_ans.dtype, dtype)
self.assertEqual(tf_ans.shape, ans.get_shape())
else:
with self.assertRaisesOpError(expected_err_re):
@@ -91,13 +94,16 @@ class OneHotTest(test.TestCase):
dtype=dtype)
# axis == -1
self._testBothOneHot(indices=indices, depth=depth, truth=truth)
self._testBothOneHot(indices=indices, depth=depth, dtype=dtype, truth=truth)
# axis == 0
self._testBothOneHot(
indices=indices, depth=depth, axis=0,
indices=indices, depth=depth, axis=0, dtype=dtype,
truth=truth.T) # Output is transpose version in this case
def testDefaultNoDtype(self):
self._testDefaultBasic(None)
def testFloatBasic(self):
self._testBasic(np.float32)
self._testDefaultBasic(np.float32)
@@ -303,7 +309,6 @@ class OneHotTest(test.TestCase):
depth=depth,
on_value=on_value,
off_value=off_value,
dtype=dtypes.string,
truth=truth)
on_value = constant_op.constant(b"1.0")
@@ -313,7 +318,6 @@ class OneHotTest(test.TestCase):
depth=depth,
on_value=on_value,
off_value=off_value,
dtype=dtypes.string,
truth=truth)
on_value = b"1.0"
@@ -323,7 +327,6 @@ class OneHotTest(test.TestCase):
depth=depth,
on_value=on_value,
off_value=off_value,
dtype=dtypes.string,
truth=truth)
def testIndicesTypes(self):
@@ -400,8 +403,8 @@ class OneHotTest(test.TestCase):
def testDtypeMismatchTypeError(self):
indices = [0, 1, 2]
depth = 3
on_value = np.asarray(1.0, np.float32)
off_value = np.asarray(0.0, np.float32)
on_value = constant_op.constant(1.0, dtypes.float32)
off_value = constant_op.constant(0.0, dtypes.float32)
dtype = np.int32
self._testBothOneHot(
@@ -420,6 +423,37 @@ class OneHotTest(test.TestCase):
truth=None,
raises=TypeError)
def testConvertToTensorOfCorrectDtype(self):
indices = [0, 1, 2]
depth = 3
dtype = np.float16
truth = np.asarray([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
self._testBothOneHot(
truth=truth,
indices=indices,
depth=depth,
on_value=1.0,
off_value=constant_op.constant(0.0, dtype),
dtype=dtype)
self._testBothOneHot(
truth=truth,
indices=indices,
depth=depth,
on_value=constant_op.constant(1.0, dtype),
off_value=0.,
dtype=dtype)
self._testBothOneHot(
truth=truth,
indices=indices,
depth=depth,
on_value=1.0,
off_value=0.,
dtype=dtype)
def testOneHotUint8WithLargeArray(self):
with self.cached_session(use_gpu=False) as sess:
matrix = np.random.rand(256) * 10

View File

@@ -3949,11 +3949,13 @@ def one_hot(indices,
on_exists = on_value is not None
off_exists = off_value is not None
on_dtype = (
ops.convert_to_tensor(on_value).dtype.base_dtype if on_exists else None)
off_dtype = (
ops.convert_to_tensor(off_value).dtype.base_dtype
if off_exists else None)
if on_exists:
on_value = ops.convert_to_tensor(on_value, dtype_hint=dtype)
if off_exists:
off_value = ops.convert_to_tensor(off_value, dtype_hint=dtype)
on_dtype = on_value.dtype.base_dtype if on_exists else None
off_dtype = off_value.dtype.base_dtype if off_exists else None
if on_exists or off_exists:
if dtype is not None: