[XLA:TPU] Add a bunch of F64 Convert ops. F64 <-> C64/U64/U32.

PiperOrigin-RevId: 306344563
Change-Id: Ia24263b9d51be5e057a1d3d9314a6c5fa647038b
This commit is contained in:
Anudhyan Boral 2020-04-13 17:30:38 -07:00 committed by TensorFlower Gardener
parent b0726c4787
commit 20d0fd0d81

View File

@ -808,24 +808,51 @@ class UnaryOpsTest(xla_test.XLATestCase):
def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]]
types = (
set([dtypes.bool, dtypes.int32, dtypes.float32])
| self.complex_tf_types)
for shape in shapes:
for src_type in types:
for dst_type in types:
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
if src_type in self.complex_tf_types:
src += (np.arange(np.prod(shape)) * 2j).astype(
src_type.as_numpy_dtype)
src = src.reshape(shape)
types = {
dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64,
dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64
}
for src_type in types:
for dst_type in types:
src_np_dtype = src_type.as_numpy_dtype
dst_np_dtype = dst_type.as_numpy_dtype
dst = src.astype(dst_type.as_numpy_dtype)
for shape in shapes:
src = np.arange(np.prod(shape)).astype(src_np_dtype)
if src_type in self.complex_tf_types:
src += (np.arange(np.prod(shape)) * 2j).astype(src_np_dtype)
src = src.reshape(shape)
dst = src.astype(dst_np_dtype)
self._assertOpOutputMatchesExpected(
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
src,
expected=dst)
# Check special values.
if src_type.is_integer:
imin = np.iinfo(src_np_dtype).min
imax = np.iinfo(src_np_dtype).max
src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype)
elif src_type in self.float_tf_types:
if dst_type.is_integer:
imin = np.iinfo(dst_np_dtype).min
imax = np.iinfo(dst_np_dtype).max // 2
src = np.array([imin, imax, 0, 1], dtype=src_np_dtype)
elif dst_type in self.float_tf_types:
fmin = np.finfo(dst_np_dtype).min
fmax = np.finfo(dst_np_dtype).max
tiny = np.finfo(dst_np_dtype).tiny
eps = np.finfo(dst_np_dtype).eps
src = np.array(
[fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf],
dtype=src_np_dtype)
dst = src.astype(dst_np_dtype)
self._assertOpOutputMatchesExpected(
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
src,
expected=dst)
def testBitcast(self):
self._assertOpOutputMatchesExpected(
lambda x: array_ops.bitcast(x, dtypes.int32),