Support uint32 in tf.add.
PiperOrigin-RevId: 305382371 Change-Id: Iab3ad721f1cbffdabdff3d659ce75ba50512dd20
This commit is contained in:
parent
6f8771238a
commit
6e92134e1c
tensorflow
@ -26,12 +26,12 @@ REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64, uint8,
|
||||
complex128, tstring);
|
||||
// Notice: String is excluded to allow marking AddV2 is_commutative and
|
||||
// is_aggregate.
|
||||
REGISTER5(BinaryOp, CPU, "AddV2", functor::add, int8, int16, complex64, uint8,
|
||||
complex128);
|
||||
REGISTER6(BinaryOp, CPU, "AddV2", functor::add, int8, int16, uint32, complex64,
|
||||
uint8, complex128);
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
|
||||
complex128);
|
||||
REGISTER4(BinaryOp, GPU, "AddV2", functor::add, uint8, int64, complex64,
|
||||
REGISTER5(BinaryOp, GPU, "AddV2", functor::add, uint8, uint32, int64, complex64,
|
||||
complex128);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
DEFINE_BINARY7(add, Eigen::half, float, double, uint8, int64, complex64,
|
||||
DEFINE_BINARY8(add, Eigen::half, float, double, uint8, uint32, int64, complex64,
|
||||
complex128);
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
@ -342,6 +342,11 @@ class BinaryOpTest(test.TestCase):
|
||||
# _MOD for int32 on GPU by calling _compareGpu
|
||||
self._compareGpu(x, y, np.mod, _MOD)
|
||||
|
||||
def testUint32Basic(self):
|
||||
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
|
||||
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
|
||||
self._compareBoth(x, y, np.add, math_ops.add)
|
||||
|
||||
def testInt64Basic(self):
|
||||
x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
|
||||
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
|
||||
|
Loading…
Reference in New Issue
Block a user