Added uint8 registration for addition operation (Fixes #10447)
This commit is contained in:
parent
3c4cb087e6
commit
4c0052dc4b
@ -22,10 +22,11 @@ namespace tensorflow {
|
|||||||
// sharded files, only make its register calls when not __ANDROID_TYPES_SLIM__.
|
// sharded files, only make its register calls when not __ANDROID_TYPES_SLIM__.
|
||||||
#if !defined(__ANDROID_TYPES_SLIM__)
|
#if !defined(__ANDROID_TYPES_SLIM__)
|
||||||
|
|
||||||
REGISTER5(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
|
REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
|
||||||
complex128, string);
|
uint8, complex128, string);
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
REGISTER3(BinaryOp, GPU, "Add", functor::add, int64, complex64, complex128);
|
REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
|
||||||
|
complex128);
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
#endif // !defined(__ANDROID_TYPES_SLIM__)
|
#endif // !defined(__ANDROID_TYPES_SLIM__)
|
||||||
|
@ -19,7 +19,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace functor {
|
namespace functor {
|
||||||
DEFINE_BINARY6(add, Eigen::half, float, double, int64, complex64, complex128);
|
DEFINE_BINARY7(add, Eigen::half, float, double, uint8, int64, complex64,
|
||||||
|
complex128);
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -707,6 +707,11 @@ class BinaryOpTest(test.TestCase):
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
tf_logging.warn("Cannot test special functions: %s" % str(e))
|
tf_logging.warn("Cannot test special functions: %s" % str(e))
|
||||||
|
|
||||||
|
def testUint8Basic(self):
|
||||||
|
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
|
||||||
|
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
|
||||||
|
self._compareBoth(x, y, np.add, math_ops.add)
|
||||||
|
|
||||||
def testInt8Basic(self):
|
def testInt8Basic(self):
|
||||||
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
|
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
|
||||||
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
|
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
|
||||||
|
Loading…
Reference in New Issue
Block a user