Add registration for op AddV2, which is identical to Add, except that it does does not implement string concatenation. This allows us to mark AddV2 is_commutative and is_aggregate, which will allow optimizers more freedom.
PiperOrigin-RevId: 173931848
This commit is contained in:
parent
629e6d0c10
commit
1b6b7e208f
@ -18,9 +18,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
REGISTER5(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32,
|
||||
int64);
|
||||
REGISTER5(BinaryOp, CPU, "AddV2", functor::add, float, Eigen::half, double,
|
||||
int32, int64);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
|
||||
REGISTER3(BinaryOp, GPU, "AddV2", functor::add, float, Eigen::half, double);
|
||||
|
||||
// A special GPU kernel for int32.
|
||||
// TODO(b/25387198): Also enable int32 in device memory. This kernel
|
||||
@ -32,11 +35,21 @@ REGISTER_KERNEL_BUILDER(Name("Add")
|
||||
.HostMemory("z")
|
||||
.TypeConstraint<int32>("T"),
|
||||
BinaryOp<CPUDevice, functor::add<int32>>);
|
||||
REGISTER_KERNEL_BUILDER(Name("AddV2")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("x")
|
||||
.HostMemory("y")
|
||||
.HostMemory("z")
|
||||
.TypeConstraint<int32>("T"),
|
||||
BinaryOp<CPUDevice, functor::add<int32>>);
|
||||
#endif
|
||||
|
||||
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define REGISTER_KERNEL(type) REGISTER(BinaryOp, SYCL, "Add", functor::add, type);
|
||||
#define REGISTER_KERNEL(type) \
|
||||
REGISTER(BinaryOp, SYCL, "Add", functor::add, type); \
|
||||
REEGISTER(BinaryOp, SYCL, "AddV2", functor::add, type);
|
||||
|
||||
TF_CALL_SYCL_NUMBER_TYPES(REGISTER_KERNEL);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Add")
|
||||
@ -46,5 +59,12 @@ REGISTER_KERNEL_BUILDER(Name("Add")
|
||||
.HostMemory("z")
|
||||
.TypeConstraint<int32>("T"),
|
||||
BinaryOp<CPUDevice, functor::add<int32>>);
|
||||
REGISTER_KERNEL_BUILDER(Name("AddV2")
|
||||
.Device(DEVICE_SYCL)
|
||||
.HostMemory("x")
|
||||
.HostMemory("y")
|
||||
.HostMemory("z")
|
||||
.TypeConstraint<int32>("T"),
|
||||
BinaryOp<CPUDevice, functor::add<int32>>);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
} // namespace tensorflow
|
||||
|
@ -24,9 +24,15 @@ namespace tensorflow {
|
||||
|
||||
REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
|
||||
uint8, complex128, string);
|
||||
// Notice: String is excluded to allow marking AddV2 is_commutative and
|
||||
// is_aggregate.
|
||||
REGISTER5(BinaryOp, CPU, "AddV2", functor::add, int8, int16, complex64, uint8,
|
||||
complex128);
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
|
||||
complex128);
|
||||
REGISTER4(BinaryOp, GPU, "AddV2", functor::add, uint8, int64, complex64,
|
||||
complex128);
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#endif // !defined(__ANDROID_TYPES_SLIM__)
|
||||
|
@ -514,7 +514,6 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
|
||||
Input("x: T").Input("y: T").Output("z: T").Attr( \
|
||||
"T: {half, float, double, int32, int64, complex64, complex128}")
|
||||
|
||||
// TODO(mrry): Restore `SetIsCommutative()` for non-string types.
|
||||
REGISTER_OP("Add")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
@ -530,6 +529,25 @@ Returns x + y element-wise.
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
)doc");
|
||||
|
||||
// TODO(rmlarsen): Add a Python wrapper that swiches non-string instances to
|
||||
// use AddV2 (b/68646025).
|
||||
REGISTER_OP("AddV2")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
.Output("z: T")
|
||||
.Attr(
|
||||
"T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
|
||||
"complex128}")
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
||||
.SetIsAggregate()
|
||||
.SetIsCommutative()
|
||||
.Doc(R"doc(
|
||||
Returns x + y element-wise.
|
||||
|
||||
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
|
||||
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_MklAdd")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
|
@ -244,6 +244,7 @@ TensorSummaryV2
|
||||
Abs
|
||||
AccumulateNV2
|
||||
AddN
|
||||
AddV2
|
||||
All
|
||||
Any
|
||||
BatchMatMul
|
||||
|
Loading…
x
Reference in New Issue
Block a user