Add registration for op AddV2, which is identical to Add except that it does not implement string concatenation. This allows us to mark AddV2 as is_commutative and is_aggregate, which will give optimizers more freedom.
PiperOrigin-RevId: 173931848
parent 629e6d0c10
commit 1b6b7e208f
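For context, here is a minimal sketch (not part of this change) of how a tool or optimizer pass could observe the flags this commit sets on AddV2, by looking the op up in the global registry. It assumes a binary linked against the TensorFlow core framework plus the registered math ops, so that "AddV2" is present in the registry; the standalone main() wrapper is illustrative only.

// Illustrative only: query the OpDef flags that this commit sets on AddV2.
#include <iostream>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"

int main() {
  const tensorflow::OpDef* op_def = nullptr;
  tensorflow::Status s =
      tensorflow::OpRegistry::Global()->LookUpOpDef("AddV2", &op_def);
  if (!s.ok()) {
    std::cerr << s.ToString() << std::endl;
    return 1;
  }
  // Because AddV2 has no string kernel, it can advertise both properties,
  // letting graph optimizers reorder operands and treat chains of adds as an
  // aggregation, which plain Add (with string concatenation) cannot.
  std::cout << "is_commutative: " << op_def->is_commutative() << "\n"
            << "is_aggregate:   " << op_def->is_aggregate() << std::endl;
  return 0;
}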
@@ -18,9 +18,12 @@ limitations under the License.
 namespace tensorflow {
 REGISTER5(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32,
           int64);
+REGISTER5(BinaryOp, CPU, "AddV2", functor::add, float, Eigen::half, double,
+          int32, int64);
 
 #if GOOGLE_CUDA
 REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
+REGISTER3(BinaryOp, GPU, "AddV2", functor::add, float, Eigen::half, double);
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -32,11 +35,21 @@ REGISTER_KERNEL_BUILDER(Name("Add")
                             .HostMemory("z")
                             .TypeConstraint<int32>("T"),
                         BinaryOp<CPUDevice, functor::add<int32>>);
+REGISTER_KERNEL_BUILDER(Name("AddV2")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("x")
+                            .HostMemory("y")
+                            .HostMemory("z")
+                            .TypeConstraint<int32>("T"),
+                        BinaryOp<CPUDevice, functor::add<int32>>);
 #endif
 
 
 #if TENSORFLOW_USE_SYCL
-#define REGISTER_KERNEL(type) REGISTER(BinaryOp, SYCL, "Add", functor::add, type);
+#define REGISTER_KERNEL(type)                            \
+  REGISTER(BinaryOp, SYCL, "Add", functor::add, type);   \
+  REGISTER(BinaryOp, SYCL, "AddV2", functor::add, type);
+
 TF_CALL_SYCL_NUMBER_TYPES(REGISTER_KERNEL);
 
 REGISTER_KERNEL_BUILDER(Name("Add")
@@ -46,5 +59,12 @@ REGISTER_KERNEL_BUILDER(Name("Add")
                             .HostMemory("z")
                             .TypeConstraint<int32>("T"),
                         BinaryOp<CPUDevice, functor::add<int32>>);
+REGISTER_KERNEL_BUILDER(Name("AddV2")
+                            .Device(DEVICE_SYCL)
+                            .HostMemory("x")
+                            .HostMemory("y")
+                            .HostMemory("z")
+                            .TypeConstraint<int32>("T"),
+                        BinaryOp<CPUDevice, functor::add<int32>>);
 #endif  // TENSORFLOW_USE_SYCL
 }  // namespace tensorflow
@@ -24,9 +24,15 @@ namespace tensorflow {
 
 REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64,
           uint8, complex128, string);
+// Notice: String is excluded to allow marking AddV2 is_commutative and
+// is_aggregate.
+REGISTER5(BinaryOp, CPU, "AddV2", functor::add, int8, int16, complex64, uint8,
+          complex128);
 #if GOOGLE_CUDA
 REGISTER4(BinaryOp, GPU, "Add", functor::add, uint8, int64, complex64,
           complex128);
+REGISTER4(BinaryOp, GPU, "AddV2", functor::add, uint8, int64, complex64,
+          complex128);
 #endif  // GOOGLE_CUDA
 
 #endif  // !defined(__ANDROID_TYPES_SLIM__)
@@ -514,7 +514,6 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
   Input("x: T").Input("y: T").Output("z: T").Attr( \
       "T: {half, float, double, int32, int64, complex64, complex128}")
 
-// TODO(mrry): Restore `SetIsCommutative()` for non-string types.
 REGISTER_OP("Add")
     .Input("x: T")
     .Input("y: T")
@@ -530,6 +529,25 @@ Returns x + y element-wise.
 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
 )doc");
 
+// TODO(rmlarsen): Add a Python wrapper that switches non-string instances to
+// use AddV2 (b/68646025).
+REGISTER_OP("AddV2")
+    .Input("x: T")
+    .Input("y: T")
+    .Output("z: T")
+    .Attr(
+        "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
+        "complex128}")
+    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+    .SetIsAggregate()
+    .SetIsCommutative()
+    .Doc(R"doc(
+Returns x + y element-wise.
+
+*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
 REGISTER_OP("_MklAdd")
     .Input("x: T")
     .Input("y: T")
@@ -244,6 +244,7 @@ TensorSummaryV2
 Abs
 AccumulateNV2
 AddN
+AddV2
 All
 Any
 BatchMatMul