Added uint8 registration for addition operation (Fixes #10447)

This commit is contained in:
Lakshay Garg 2017-06-09 12:02:13 +05:30 committed by Martin Wicke
parent 3c4cb087e6
commit 4c0052dc4b
3 changed files with 11 additions and 4 deletions

View File

@ -22,10 +22,11 @@ namespace tensorflow {
// sharded files, only make its register calls when not __ANDROID_TYPES_SLIM__.
#if !defined(__ANDROID_TYPES_SLIM__)
REGISTER5(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
complex128, string);
REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
uint8, complex128, string);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Add", functor::add, int64, complex64, complex128);
REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
complex128);
#endif // GOOGLE_CUDA
#endif // !defined(__ANDROID_TYPES_SLIM__)

View File

@ -19,7 +19,8 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_BINARY6(add, Eigen::half, float, double, int64, complex64, complex128);
DEFINE_BINARY7(add, Eigen::half, float, double, uint8, int64, complex64,
complex128);
} // namespace functor
} // namespace tensorflow

View File

@ -707,6 +707,11 @@ class BinaryOpTest(test.TestCase):
except ImportError as e:
tf_logging.warn("Cannot test special functions: %s" % str(e))
def testUint8Basic(self):
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
self._compareBoth(x, y, np.add, math_ops.add)
def testInt8Basic(self):
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)