From 4c0052dc4b7c49a876166113b49960a57f7db939 Mon Sep 17 00:00:00 2001 From: Lakshay Garg Date: Fri, 9 Jun 2017 12:02:13 +0530 Subject: [PATCH] Added uint8 registration for addition operation (Fixes #10447) --- tensorflow/core/kernels/cwise_op_add_2.cc | 7 ++++--- tensorflow/core/kernels/cwise_op_gpu_add.cu.cc | 3 ++- tensorflow/python/kernel_tests/cwise_ops_test.py | 5 +++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_add_2.cc b/tensorflow/core/kernels/cwise_op_add_2.cc index 5d3385b0ed6..5dea00e95c7 100644 --- a/tensorflow/core/kernels/cwise_op_add_2.cc +++ b/tensorflow/core/kernels/cwise_op_add_2.cc @@ -22,10 +22,11 @@ namespace tensorflow { // sharded files, only make its register calls when not __ANDROID_TYPES_SLIM__. #if !defined(__ANDROID_TYPES_SLIM__) -REGISTER5(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64, - complex128, string); +REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64, + uint8, complex128, string); #if GOOGLE_CUDA -REGISTER3(BinaryOp, GPU, "Add", functor::add, int64, complex64, complex128); +REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64, + complex128); #endif // GOOGLE_CUDA #endif // !defined(__ANDROID_TYPES_SLIM__) diff --git a/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc index 5aaf2b5b4b8..61079ebab39 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc @@ -19,7 +19,8 @@ limitations under the License. namespace tensorflow { namespace functor { -DEFINE_BINARY6(add, Eigen::half, float, double, int64, complex64, complex128); +DEFINE_BINARY7(add, Eigen::half, float, double, uint8, int64, complex64, + complex128); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 54810cdc342..b47139e6b8b 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -707,6 +707,11 @@ class BinaryOpTest(test.TestCase): except ImportError as e: tf_logging.warn("Cannot test special functions: %s" % str(e)) + def testUint8Basic(self): + x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8) + y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8) + self._compareBoth(x, y, np.add, math_ops.add) + def testInt8Basic(self): x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8) y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)