Alias the Add and AddV2 operations in kernels generated by the kernel generator.

The two operations have the same semantics except for strings.

PiperOrigin-RevId: 356322598
Change-Id: Ibb0aafb9faabaf0bc8ca8c375b6a183065e37493
This commit is contained in:
Stephan Herhut 2021-02-08 12:11:32 -08:00 committed by TensorFlower Gardener
parent d7992c051e
commit 2a76e6eeb2
5 changed files with 26 additions and 4 deletions

View File

@ -27,10 +27,10 @@ REGISTER(BinaryOp, CPU, "AddV2", functor::add, bfloat16);
#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double);
REGISTER3(BinaryOp, GPU, "AddV2", functor::add, float, Eigen::half, double);
#endif

View File

@ -29,15 +29,18 @@ REGISTER6(BinaryOp, CPU, "Add", functor::add, int8, int16, complex64, uint8,
REGISTER8(BinaryOp, CPU, "AddV2", functor::add, int8, int16, complex64, uint8,
uint16, uint32, uint64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
REGISTER6(BinaryOp, GPU, "Add", functor::add, uint8, uint16, uint64, int64,
complex64, complex128);
#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
!defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
REGISTER7(BinaryOp, GPU, "AddV2", functor::add, uint8, uint16, uint32, uint64,
int64, complex64, complex128);
#else
// There is an MLIR generated kernel for int64
REGISTER5(BinaryOp, GPU, "Add", functor::add, uint8, uint16, uint64, complex64,
complex128);
REGISTER6(BinaryOp, GPU, "AddV2", functor::add, uint8, uint16, uint32, uint64,
complex64, complex128);
#endif

View File

@ -19,8 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace functor {
// Selects which element types `add` is passed to DEFINE_BINARY<N> for.
#if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) && \
    defined(MLIR_GENERATED_EXPERIMENTAL_KERNELS_ENABLED)
// Both MLIR kernel-generator switches are defined: Eigen::half, float, double
// and int64 are left out of the list (presumably because generated kernels
// cover those types -- confirm against the kernel generator's registrations).
DEFINE_BINARY6(add, uint8, uint16, uint32, uint64, complex64, complex128);
#else
// At least one of the two switches is undefined: list all ten element types.
DEFINE_BINARY10(add, Eigen::half, float, double, uint8, uint16, uint32, uint64,
                int64, complex64, complex128);
#endif
}  // namespace functor
}  // namespace tensorflow

View File

@ -33,13 +33,21 @@ class BinaryOpsTest : public BinaryOpsTestBase {
}
};
/// Test `tf.AddV2`.
/// Test `tf.Add`.
/// Baseline handed to GENERATE_DEFAULT_TESTS: the plain C++ sum of its operands.
template <typename T>
T baseline_add(T lhs, T rhs) {
  // For narrow integer `T` the addition happens on promoted operands; the
  // result is converted back to `T` here, exactly as a direct `return` would.
  const T sum = lhs + rhs;
  return sum;
}
// Default tests for `Add`, one invocation per element type: half, float,
// double and int64. `baseline_add` is passed as the final argument of each.
// NOTE(review): the expansion of GENERATE_DEFAULT_TESTS is not visible here;
// the two type arguments are presumably the input and output types -- confirm.
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_add)
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Float, float, float, baseline_add)
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Double, double, double, baseline_add)
GENERATE_DEFAULT_TESTS(Add, /*test_name=*/Int64, int64, int64, baseline_add)
/// Test `tf.AddV2`.
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Half, Eigen::half, Eigen::half,
baseline_add)
GENERATE_DEFAULT_TESTS(AddV2, /*test_name=*/Float, float, float, baseline_add)

View File

@ -22,4 +22,10 @@ GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f32, DT_FLOAT, float);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, f64, DT_DOUBLE, double);
GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(AddV2, i64, DT_INT64, int64);
// Add is the same as AddV2 except for strings, which we do not support on GPU.
// Each line below therefore aliases `Add` to `AddV2` for one element type
// (f16, f32, f64, i64).
// NOTE(review): the definition of REGISTER_ALIASED_GPU_KERNEL is not visible
// here; presumably the two type tags are the input and output kernel types and
// the last argument is the matching C++ type -- confirm against the macro.
REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f16, f16, Eigen::half);
REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f32, f32, float);
REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, f64, f64, double);
REGISTER_ALIASED_GPU_KERNEL(Add, AddV2, i64, i64, int64);
} // namespace tensorflow