[XLA:TPU] Add a bunch of F64 Convert ops. F64 <-> C64/U64/U32.
PiperOrigin-RevId: 306344563 Change-Id: Ia24263b9d51be5e057a1d3d9314a6c5fa647038b
This commit is contained in:
parent
b0726c4787
commit
20d0fd0d81
@ -808,19 +808,46 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
def testCast(self):
|
def testCast(self):
|
||||||
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
||||||
types = (
|
types = {
|
||||||
set([dtypes.bool, dtypes.int32, dtypes.float32])
|
dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64,
|
||||||
| self.complex_tf_types)
|
dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64
|
||||||
for shape in shapes:
|
}
|
||||||
for src_type in types:
|
for src_type in types:
|
||||||
for dst_type in types:
|
for dst_type in types:
|
||||||
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
|
src_np_dtype = src_type.as_numpy_dtype
|
||||||
if src_type in self.complex_tf_types:
|
dst_np_dtype = dst_type.as_numpy_dtype
|
||||||
src += (np.arange(np.prod(shape)) * 2j).astype(
|
|
||||||
src_type.as_numpy_dtype)
|
|
||||||
src = src.reshape(shape)
|
|
||||||
|
|
||||||
dst = src.astype(dst_type.as_numpy_dtype)
|
for shape in shapes:
|
||||||
|
src = np.arange(np.prod(shape)).astype(src_np_dtype)
|
||||||
|
|
||||||
|
if src_type in self.complex_tf_types:
|
||||||
|
src += (np.arange(np.prod(shape)) * 2j).astype(src_np_dtype)
|
||||||
|
src = src.reshape(shape)
|
||||||
|
dst = src.astype(dst_np_dtype)
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
|
||||||
|
src,
|
||||||
|
expected=dst)
|
||||||
|
|
||||||
|
# Check special values.
|
||||||
|
if src_type.is_integer:
|
||||||
|
imin = np.iinfo(src_np_dtype).min
|
||||||
|
imax = np.iinfo(src_np_dtype).max
|
||||||
|
src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype)
|
||||||
|
elif src_type in self.float_tf_types:
|
||||||
|
if dst_type.is_integer:
|
||||||
|
imin = np.iinfo(dst_np_dtype).min
|
||||||
|
imax = np.iinfo(dst_np_dtype).max // 2
|
||||||
|
src = np.array([imin, imax, 0, 1], dtype=src_np_dtype)
|
||||||
|
elif dst_type in self.float_tf_types:
|
||||||
|
fmin = np.finfo(dst_np_dtype).min
|
||||||
|
fmax = np.finfo(dst_np_dtype).max
|
||||||
|
tiny = np.finfo(dst_np_dtype).tiny
|
||||||
|
eps = np.finfo(dst_np_dtype).eps
|
||||||
|
src = np.array(
|
||||||
|
[fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf],
|
||||||
|
dtype=src_np_dtype)
|
||||||
|
dst = src.astype(dst_np_dtype)
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
|
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
|
||||||
src,
|
src,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user