Add registration for op AddV2, which is identical to Add, except that it does not implement string concatenation. This allows us to mark AddV2 as is_commutative and is_aggregate, which will allow optimizers more freedom.
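For context: is_commutative tells graph optimizers that an op's operands may be swapped, which string concatenation violates ("ab" != "ba"), and is_aggregate tells them the op combines any number of operands of the same type, so chains of adds may be flattened (e.g. into AddN). Below is a minimal standalone C++ sketch, not TensorFlow optimizer code, of the kind of rewrite these bits license; all names in it are invented for illustration.

// Standalone illustration (not TensorFlow source): a toy canonicalization
// pass that a graph rewriter could run once an op is known to be
// commutative. Sorting the inputs of AddV2(b, a) and AddV2(a, b) into one
// canonical order makes the two nodes structurally identical, so common
// subexpression elimination can merge them. A string-concatenating Add
// could never be rewritten this way.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string op;
  std::vector<std::string> inputs;
  bool is_commutative = false;
  bool is_aggregate = false;  // eligible for e.g. Add-chain -> AddN rewrites
};

void Canonicalize(ToyNode* n) {
  // Safe only because the op declared itself commutative.
  if (n->is_commutative) std::sort(n->inputs.begin(), n->inputs.end());
}

int main() {
  ToyNode m{"AddV2", {"b", "a"}, /*is_commutative=*/true, /*is_aggregate=*/true};
  ToyNode n{"AddV2", {"a", "b"}, /*is_commutative=*/true, /*is_aggregate=*/true};
  Canonicalize(&m);
  Canonicalize(&n);
  std::cout << (m.inputs == n.inputs) << "\n";  // prints 1: CSE can merge them
}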

PiperOrigin-RevId: 173931848
A. Unique TensorFlower 2017-10-30 12:27:53 -07:00 committed by TensorFlower Gardener
parent 629e6d0c10
commit 1b6b7e208f
4 changed files with 47 additions and 2 deletions


@@ -18,9 +18,12 @@ limitations under the License.
namespace tensorflow {
REGISTER5(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32,
          int64);
REGISTER5(BinaryOp, CPU, "AddV2", functor::add, float, Eigen::half, double,
          int32, int64);
#if GOOGLE_CUDA
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
REGISTER3(BinaryOp, GPU, "AddV2", functor::add, float, Eigen::half, double);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,11 +35,21 @@ REGISTER_KERNEL_BUILDER(Name("Add")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::add<int32>>);
REGISTER_KERNEL_BUILDER(Name("AddV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("x")
                            .HostMemory("y")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::add<int32>>);
#endif
#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(type) REGISTER(BinaryOp, SYCL, "Add", functor::add, type);
#define REGISTER_KERNEL(type)                            \
  REGISTER(BinaryOp, SYCL, "Add", functor::add, type);   \
  REGISTER(BinaryOp, SYCL, "AddV2", functor::add, type);
TF_CALL_SYCL_NUMBER_TYPES(REGISTER_KERNEL);
REGISTER_KERNEL_BUILDER(Name("Add")
@@ -46,5 +59,12 @@ REGISTER_KERNEL_BUILDER(Name("Add")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::add<int32>>);
REGISTER_KERNEL_BUILDER(Name("AddV2")
                            .Device(DEVICE_SYCL)
                            .HostMemory("x")
                            .HostMemory("y")
                            .HostMemory("z")
                            .TypeConstraint<int32>("T"),
                        BinaryOp<CPUDevice, functor::add<int32>>);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow


@@ -24,9 +24,15 @@ namespace tensorflow {
REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
          uint8, complex128, string);
// Notice: String is excluded to allow marking AddV2 is_commutative and
// is_aggregate.
REGISTER5(BinaryOp, CPU, "AddV2", functor::add, int8, int16, complex64, uint8,
          complex128);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
          complex128);
REGISTER4(BinaryOp, GPU, "AddV2", functor::add, uint8, int64, complex64,
          complex128);
#endif // GOOGLE_CUDA
#endif // !defined(__ANDROID_TYPES_SLIM__)


@@ -514,7 +514,6 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
      Input("x: T").Input("y: T").Output("z: T").Attr(                  \
          "T: {half, float, double, int32, int64, complex64, complex128}")
// TODO(mrry): Restore `SetIsCommutative()` for non-string types.
REGISTER_OP("Add")
    .Input("x: T")
    .Input("y: T")
@@ -530,6 +529,25 @@ Returns x + y element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
// TODO(rmlarsen): Add a Python wrapper that switches non-string instances to
// use AddV2 (b/68646025).
REGISTER_OP("AddV2")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr(
"T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
"complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.SetIsAggregate()
.SetIsCommutative()
.Doc(R"doc(
Returns x + y element-wise.
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
REGISTER_OP("_MklAdd")
.Input("x: T")
.Input("y: T")


@@ -244,6 +244,7 @@ TensorSummaryV2
Abs
AccumulateNV2
AddN
AddV2
All
Any
BatchMatMul