[XLA:TPU] Add a bunch of F64 Convert ops. F64 <-> C64/U64/U32.
PiperOrigin-RevId: 306344563 Change-Id: Ia24263b9d51be5e057a1d3d9314a6c5fa647038b
This commit is contained in:
parent
b0726c4787
commit
20d0fd0d81
@ -808,24 +808,51 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def testCast(self):
|
||||
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
||||
types = (
|
||||
set([dtypes.bool, dtypes.int32, dtypes.float32])
|
||||
| self.complex_tf_types)
|
||||
for shape in shapes:
|
||||
for src_type in types:
|
||||
for dst_type in types:
|
||||
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
|
||||
if src_type in self.complex_tf_types:
|
||||
src += (np.arange(np.prod(shape)) * 2j).astype(
|
||||
src_type.as_numpy_dtype)
|
||||
src = src.reshape(shape)
|
||||
types = {
|
||||
dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64,
|
||||
dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64
|
||||
}
|
||||
for src_type in types:
|
||||
for dst_type in types:
|
||||
src_np_dtype = src_type.as_numpy_dtype
|
||||
dst_np_dtype = dst_type.as_numpy_dtype
|
||||
|
||||
dst = src.astype(dst_type.as_numpy_dtype)
|
||||
for shape in shapes:
|
||||
src = np.arange(np.prod(shape)).astype(src_np_dtype)
|
||||
|
||||
if src_type in self.complex_tf_types:
|
||||
src += (np.arange(np.prod(shape)) * 2j).astype(src_np_dtype)
|
||||
src = src.reshape(shape)
|
||||
dst = src.astype(dst_np_dtype)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
|
||||
src,
|
||||
expected=dst)
|
||||
|
||||
# Check special values.
|
||||
if src_type.is_integer:
|
||||
imin = np.iinfo(src_np_dtype).min
|
||||
imax = np.iinfo(src_np_dtype).max
|
||||
src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype)
|
||||
elif src_type in self.float_tf_types:
|
||||
if dst_type.is_integer:
|
||||
imin = np.iinfo(dst_np_dtype).min
|
||||
imax = np.iinfo(dst_np_dtype).max // 2
|
||||
src = np.array([imin, imax, 0, 1], dtype=src_np_dtype)
|
||||
elif dst_type in self.float_tf_types:
|
||||
fmin = np.finfo(dst_np_dtype).min
|
||||
fmax = np.finfo(dst_np_dtype).max
|
||||
tiny = np.finfo(dst_np_dtype).tiny
|
||||
eps = np.finfo(dst_np_dtype).eps
|
||||
src = np.array(
|
||||
[fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf],
|
||||
dtype=src_np_dtype)
|
||||
dst = src.astype(dst_np_dtype)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
|
||||
src,
|
||||
expected=dst)
|
||||
|
||||
def testBitcast(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
lambda x: array_ops.bitcast(x, dtypes.int32),
|
||||
|
Loading…
Reference in New Issue
Block a user