Resubmit of PR 38848 but only support complex64 and complex128

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-04-23 20:24:06 +00:00
parent 57f6650fbb
commit 71a54bdbc1
2 changed files with 9 additions and 1 deletions

View File

@ -231,6 +231,8 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
IsVariableInitializedOp);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

View File

@ -33,10 +33,13 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
_NP_TO_TF = {
np.float16: dtypes.float16,
np.float32: dtypes.float32,
np.float64: dtypes.float64,
np.int32: dtypes.int32,
np.int64: dtypes.int64,
np.complex64: dtypes.complex64,
np.complex128: dtypes.complex128,
}
@ -50,7 +53,10 @@ class VariableOpTest(test.TestCase):
return self.evaluate(p)
def _testTypes(self, vals):
for dtype in [np.float32, np.float64, np.int32, np.int64]:
for dtype in [
np.float16, np.float32, np.float64,
np.complex64, np.complex128,
np.int32, np.int64]:
self.setUp()
x = vals.astype(dtype)
tftype = _NP_TO_TF[dtype]