Added uint8 registration for addition operation (Fixes #10447)
This commit is contained in:
parent
3c4cb087e6
commit
4c0052dc4b
@ -22,10 +22,11 @@ namespace tensorflow {
|
||||
// sharded files, only make its register calls when not __ANDROID_TYPES_SLIM__.
|
||||
#if !defined(__ANDROID_TYPES_SLIM__)
|
||||
|
||||
REGISTER5(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
|
||||
complex128, string);
|
||||
REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
|
||||
uint8, complex128, string);
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER3(BinaryOp, GPU, "Add", functor::add, int64, complex64, complex128);
|
||||
REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
|
||||
complex128);
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#endif // !defined(__ANDROID_TYPES_SLIM__)
|
||||
|
@ -19,7 +19,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
DEFINE_BINARY6(add, Eigen::half, float, double, int64, complex64, complex128);
|
||||
DEFINE_BINARY7(add, Eigen::half, float, double, uint8, int64, complex64,
|
||||
complex128);
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -707,6 +707,11 @@ class BinaryOpTest(test.TestCase):
|
||||
except ImportError as e:
|
||||
tf_logging.warn("Cannot test special functions: %s" % str(e))
|
||||
|
||||
def testUint8Basic(self):
|
||||
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
|
||||
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
|
||||
self._compareBoth(x, y, np.add, math_ops.add)
|
||||
|
||||
def testInt8Basic(self):
|
||||
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
|
||||
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
|
||||
|
Loading…
Reference in New Issue
Block a user