From cfc9a6852efb32d6965c2fe69ac3751eb0c382f5 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Wed, 3 Jun 2020 15:55:39 -0700 Subject: [PATCH] Add GPU support for complex types in OneHot PiperOrigin-RevId: 314621230 Change-Id: I23bb95ff53cc2e8b16d73cbc1c22f42af6461170 --- tensorflow/core/kernels/one_hot_op.cc | 4 ++++ tensorflow/core/kernels/one_hot_op_gpu.cu.cc | 2 ++ 2 files changed, 6 insertions(+) diff --git a/tensorflow/core/kernels/one_hot_op.cc b/tensorflow/core/kernels/one_hot_op.cc index 0548e389b7a..3badbc294b7 100644 --- a/tensorflow/core/kernels/one_hot_op.cc +++ b/tensorflow/core/kernels/one_hot_op.cc @@ -164,6 +164,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_bool(DECLARE_GPU_SPEC); TF_CALL_int32(DECLARE_GPU_SPEC); TF_CALL_int64(DECLARE_GPU_SPEC); +TF_CALL_complex64(DECLARE_GPU_SPEC); +TF_CALL_complex128(DECLARE_GPU_SPEC); #undef DECLARE_GPU_SPEC_INDEX #undef DECLARE_GPU_SPEC @@ -188,6 +190,8 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU); TF_CALL_bool(REGISTER_ONE_HOT_GPU); TF_CALL_int32(REGISTER_ONE_HOT_GPU); TF_CALL_int64(REGISTER_ONE_HOT_GPU); +TF_CALL_complex64(REGISTER_ONE_HOT_GPU); +TF_CALL_complex128(REGISTER_ONE_HOT_GPU); #undef REGISTER_ONE_HOT_GPU_INDEX #undef REGISTER_ONE_HOT_GPU diff --git a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc index 83ba272433f..8df7284caed 100644 --- a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc +++ b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc @@ -41,6 +41,8 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC); TF_CALL_bool(DEFINE_GPU_SPEC); TF_CALL_int32(DEFINE_GPU_SPEC); TF_CALL_int64(DEFINE_GPU_SPEC); +TF_CALL_complex64(DEFINE_GPU_SPEC); +TF_CALL_complex128(DEFINE_GPU_SPEC); #undef DEFINE_GPU_SPEC_INDEX #undef DEFINE_GPU_SPEC