Alias the Add and AddV2 operations in kernel generator generated kernels.
They have the same semantics except for strings. PiperOrigin-RevId: 356322598 Change-Id: Ibb0aafb9faabaf0bc8ca8c375b6a183065e37493
This commit is contained in:
parent
d7992c051e
commit
2a76e6eeb2
@ -27,10 +27,10 @@ REGISTER(BinaryOp, CPU, "AddV2", functor::add, bfloat16);
|
||||
#endif
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
|
||||
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
|
||||
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
|
||||
REGISTER3(BinaryOp, GPU, "AddV2", functor::add, float, Eigen::half, double);
|
||||
#endif
|
||||
|
||||
|
@ -29,15 +29,18 @@ REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64, uint8,
|
||||
REGISTER8(BinaryOp, CPU, "AddV2", functor::add, int8, int16, complex64, uint8,
|
||||
uint16, uint32, uint64, complex128);
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
|
||||
REGISTER6(BinaryOp, GPU, "Add", functor::add, uint8, uint16, uint64, int64,
|
||||
complex64, complex128);
|
||||
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
|
||||
REGISTER7(BinaryOp, GPU, "AddV2", functor::add, uint8, uint16, uint32, uint64,
|
||||
int64, complex64, complex128);
|
||||
#else
|
||||
// There is an MLIR generated kernel for int64
|
||||
REGISTER5(BinaryOp, GPU, "Add", functor::add, uint8, uint16, uint64, complex64,
|
||||
complex128);
|
||||
|
||||
REGISTER6(BinaryOp, GPU, "AddV2", functor::add, uint8, uint16, uint32, uint64,
|
||||
complex64, complex128);
|
||||
#endif
|
||||
|
@ -19,8 +19,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
|
||||
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
|
||||
DEFINE_BINARY10(add, Eigen::half, float, double, uint8, uint16, uint32, uint64,
|
||||
int64, complex64, complex128);
|
||||
#else
|
||||
DEFINE_BINARY6(add, uint8, uint16, uint32, uint64, complex64, complex128);
|
||||
#endif
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -33,13 +33,21 @@ class BinaryOpsTest : public BinaryOpsTestBase {
|
||||
}
|
||||
};
|
||||
|
||||
/// Test `tf.AddV2`.
|
||||
/// Test `tf.Add`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_add(T lhs, T rhs) {
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Float, float, float, baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Double, double, double, baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Int64, int64, int64, baseline_add)
|
||||
|
||||
/// Test `tf.AddV2`.
|
||||
|
||||
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_add)
|
||||
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Float, float, float, baseline_add)
|
||||
|
@ -22,4 +22,10 @@ GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f32, DT_FLOAT, float);
|
||||
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f64, DT_DOUBLE, double);
|
||||
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, i64, DT_INT64, int64);
|
||||
|
||||
// Add is the same as AddV2 except for strings, which we do not support on gpu.
|
||||
REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f16, f16, Eigen::half);
|
||||
REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f32, f32, float);
|
||||
REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f64, f64, double);
|
||||
REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, i64, i64, int64);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user