Support uint32 in tf.add.

PiperOrigin-RevId: 305382371
Change-Id: Iab3ad721f1cbffdabdff3d659ce75ba50512dd20
This commit is contained in:
A. Unique TensorFlower 2020-04-07 18:27:05 -07:00 committed by TensorFlower Gardener
parent 6f8771238a
commit 6e92134e1c
3 changed files with 9 additions and 4 deletions

View File

@ -26,12 +26,12 @@ REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64, uint8,
complex128, tstring);
// Notice: String is excluded to allow marking AddV2 is_commutative and
// is_aggregate.
REGISTER5(BinaryOp, CPU, "AddV2", functor::add, int8, int16, complex64, uint8,
complex128);
REGISTER6(BinaryOp, CPU, "AddV2", functor::add, int8, int16, uint32, complex64,
uint8, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
complex128);
REGISTER4(BinaryOp, GPU, "AddV2", functor::add, uint8, int64, complex64,
REGISTER5(BinaryOp, GPU, "AddV2", functor::add, uint8, uint32, int64, complex64,
complex128);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_BINARY7(add, Eigen::half, float, double, uint8, int64, complex64,
DEFINE_BINARY8(add, Eigen::half, float, double, uint8, uint32, int64, complex64,
complex128);
} // namespace functor
} // namespace tensorflow

View File

@ -342,6 +342,11 @@ class BinaryOpTest(test.TestCase):
# _MOD for int32 on GPU by calling _compareGpu
self._compareGpu(x, y, np.mod, _MOD)
def testUint32Basic(self):
x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
self._compareBoth(x, y, np.add, math_ops.add)
def testInt64Basic(self):
x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)