[KERNEL_GEN] Add tf.TruncateDiv aliased kernel (tf.Div).

PiperOrigin-RevId: 352075718
Change-Id: Iec2209dbd468b68c36a3a68be5909af441546276
This commit is contained in:
A. Unique TensorFlower 2021-01-15 13:43:18 -08:00 committed by TensorFlower Gardener
parent 4db28856db
commit 1c729468d6
3 changed files with 23 additions and 33 deletions

View File

@ -35,14 +35,13 @@ REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
uint16, int16, int64, complex64, complex128);
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
complex64, complex128);
REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
int64);
#else
REGISTER4(BinaryOp, GPU, "Div", functor::div, uint8, uint16, complex64,
complex128);
REGISTER2(BinaryOp, GPU, "RealDiv", functor::div, complex64, complex128);
REGISTER2(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16);
#endif
REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
int64);
REGISTER5(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
double, complex64, complex128);

View File

@ -560,6 +560,27 @@ GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
test::DefaultInputNonZero<double>(), baseline_floor_div);
/// Test `tf.RealDiv`.
template <typename T>
T baseline_real_div(T lhs, T rhs) {
return lhs / rhs;
}
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RealDiv,
/*test_name=*/Half, Eigen::half, Eigen::half,
test::DefaultInput<Eigen::half>(), test::DefaultInputNonZero<Eigen::half>(),
baseline_real_div);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RealDiv,
/*test_name=*/Float, float, float, test::DefaultInput<float>(),
test::DefaultInputNonZero<float>(), baseline_real_div);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RealDiv,
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
test::DefaultInputNonZero<double>(), baseline_real_div);
/// Test `tf.Greater`.
template <typename T>
@ -712,22 +733,6 @@ GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
baseline_not_equal)
/// Test `tf.RealDiv`.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RealDiv,
/*test_name=*/Half, Eigen::half, Eigen::half,
test::DefaultInput<Eigen::half>(), test::DefaultInputNonZero<Eigen::half>(),
baseline_div);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RealDiv,
/*test_name=*/Float, float, float, test::DefaultInput<float>(),
test::DefaultInputNonZero<float>(), baseline_div);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
RealDiv,
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
test::DefaultInputNonZero<double>(), baseline_div);
/// Test `tf.RightShift`.
template <typename T>
@ -769,16 +774,5 @@ GENERATE_DEFAULT_TESTS(Sub,
GENERATE_DEFAULT_TESTS(Sub,
/*test_name=*/Int64, int64, int64, baseline_sub)
/// Test `tf.TruncateDiv`.
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
TruncateDiv,
/*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
test::DefaultInputNonZero<int16>(), baseline_div);
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
TruncateDiv,
/*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
test::DefaultInputNonZero<int64>(), baseline_div);
} // namespace
} // end namespace tensorflow

View File

@ -28,7 +28,4 @@ REGISTER_ALIASED_KERNEL(RealDiv, Div, f16, Eigen::half)
REGISTER_ALIASED_KERNEL(RealDiv, Div, f32, float)
REGISTER_ALIASED_KERNEL(RealDiv, Div, f64, double)
REGISTER_ALIASED_KERNEL(TruncatedDiv, Div, i16, int16)
REGISTER_ALIASED_KERNEL(TruncatedDiv, Div, i64, int64)
} // namespace tensorflow