Convert one_hot inputs to passed in dtype.
This allows one to write code like this without one_hot raising a TypeError complaining that the float32 inputs don't match the dtype of float16: tf.one_hot(indices=[0, 1, 2], depth=3, on_value=1., off_value=0., dtype=tf.float16) Also fix issue with several tests, where previously they were mostly run in float32 even if the test was intended to run in a different dtype. PiperOrigin-RevId: 296548854 Change-Id: I44881c5a5a007e255671d86808b73015d56dfa94
This commit is contained in:
parent
ce9564d430
commit
1015e48633
tensorflow/python
@ -33,16 +33,19 @@ class OneHotTest(test.TestCase):
|
||||
use_gpu=False,
|
||||
expected_err_re=None,
|
||||
raises=None,
|
||||
dtype=None,
|
||||
**inputs):
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
if raises is not None:
|
||||
with self.assertRaises(raises):
|
||||
array_ops.one_hot(**inputs)
|
||||
array_ops.one_hot(dtype=dtype, **inputs)
|
||||
else:
|
||||
ans = array_ops.one_hot(**inputs)
|
||||
ans = array_ops.one_hot(dtype=dtype, **inputs)
|
||||
if expected_err_re is None:
|
||||
tf_ans = self.evaluate(ans)
|
||||
self.assertAllEqual(tf_ans, truth)
|
||||
if dtype:
|
||||
self.assertEqual(tf_ans.dtype, dtype)
|
||||
self.assertEqual(tf_ans.shape, ans.get_shape())
|
||||
else:
|
||||
with self.assertRaisesOpError(expected_err_re):
|
||||
@ -91,13 +94,16 @@ class OneHotTest(test.TestCase):
|
||||
dtype=dtype)
|
||||
|
||||
# axis == -1
|
||||
self._testBothOneHot(indices=indices, depth=depth, truth=truth)
|
||||
self._testBothOneHot(indices=indices, depth=depth, dtype=dtype, truth=truth)
|
||||
|
||||
# axis == 0
|
||||
self._testBothOneHot(
|
||||
indices=indices, depth=depth, axis=0,
|
||||
indices=indices, depth=depth, axis=0, dtype=dtype,
|
||||
truth=truth.T) # Output is transpose version in this case
|
||||
|
||||
def testDefaultNoDtype(self):
|
||||
self._testDefaultBasic(None)
|
||||
|
||||
def testFloatBasic(self):
|
||||
self._testBasic(np.float32)
|
||||
self._testDefaultBasic(np.float32)
|
||||
@ -303,7 +309,6 @@ class OneHotTest(test.TestCase):
|
||||
depth=depth,
|
||||
on_value=on_value,
|
||||
off_value=off_value,
|
||||
dtype=dtypes.string,
|
||||
truth=truth)
|
||||
|
||||
on_value = constant_op.constant(b"1.0")
|
||||
@ -313,7 +318,6 @@ class OneHotTest(test.TestCase):
|
||||
depth=depth,
|
||||
on_value=on_value,
|
||||
off_value=off_value,
|
||||
dtype=dtypes.string,
|
||||
truth=truth)
|
||||
|
||||
on_value = b"1.0"
|
||||
@ -323,7 +327,6 @@ class OneHotTest(test.TestCase):
|
||||
depth=depth,
|
||||
on_value=on_value,
|
||||
off_value=off_value,
|
||||
dtype=dtypes.string,
|
||||
truth=truth)
|
||||
|
||||
def testIndicesTypes(self):
|
||||
@ -400,8 +403,8 @@ class OneHotTest(test.TestCase):
|
||||
def testDtypeMismatchTypeError(self):
|
||||
indices = [0, 1, 2]
|
||||
depth = 3
|
||||
on_value = np.asarray(1.0, np.float32)
|
||||
off_value = np.asarray(0.0, np.float32)
|
||||
on_value = constant_op.constant(1.0, dtypes.float32)
|
||||
off_value = constant_op.constant(0.0, dtypes.float32)
|
||||
dtype = np.int32
|
||||
|
||||
self._testBothOneHot(
|
||||
@ -420,6 +423,37 @@ class OneHotTest(test.TestCase):
|
||||
truth=None,
|
||||
raises=TypeError)
|
||||
|
||||
def testConvertToTensorOfCorrectDtype(self):
|
||||
indices = [0, 1, 2]
|
||||
depth = 3
|
||||
dtype = np.float16
|
||||
truth = np.asarray([[1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1]])
|
||||
self._testBothOneHot(
|
||||
truth=truth,
|
||||
indices=indices,
|
||||
depth=depth,
|
||||
on_value=1.0,
|
||||
off_value=constant_op.constant(0.0, dtype),
|
||||
dtype=dtype)
|
||||
|
||||
self._testBothOneHot(
|
||||
truth=truth,
|
||||
indices=indices,
|
||||
depth=depth,
|
||||
on_value=constant_op.constant(1.0, dtype),
|
||||
off_value=0.,
|
||||
dtype=dtype)
|
||||
|
||||
self._testBothOneHot(
|
||||
truth=truth,
|
||||
indices=indices,
|
||||
depth=depth,
|
||||
on_value=1.0,
|
||||
off_value=0.,
|
||||
dtype=dtype)
|
||||
|
||||
def testOneHotUint8WithLargeArray(self):
|
||||
with self.cached_session(use_gpu=False) as sess:
|
||||
matrix = np.random.rand(256) * 10
|
||||
|
@ -3949,11 +3949,13 @@ def one_hot(indices,
|
||||
on_exists = on_value is not None
|
||||
off_exists = off_value is not None
|
||||
|
||||
on_dtype = (
|
||||
ops.convert_to_tensor(on_value).dtype.base_dtype if on_exists else None)
|
||||
off_dtype = (
|
||||
ops.convert_to_tensor(off_value).dtype.base_dtype
|
||||
if off_exists else None)
|
||||
if on_exists:
|
||||
on_value = ops.convert_to_tensor(on_value, dtype_hint=dtype)
|
||||
if off_exists:
|
||||
off_value = ops.convert_to_tensor(off_value, dtype_hint=dtype)
|
||||
|
||||
on_dtype = on_value.dtype.base_dtype if on_exists else None
|
||||
off_dtype = off_value.dtype.base_dtype if off_exists else None
|
||||
|
||||
if on_exists or off_exists:
|
||||
if dtype is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user