[KERNEL_GEN] Add tf.TruncateDiv aliased kernel (tf.Div).
PiperOrigin-RevId: 352075718 Change-Id: Iec2209dbd468b68c36a3a68be5909af441546276
This commit is contained in:
parent
4db28856db
commit
1c729468d6
@ -35,14 +35,13 @@ REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
|
||||
uint16, int16, int64, complex64, complex128);
|
||||
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
|
||||
complex64, complex128);
|
||||
REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
|
||||
int64);
|
||||
#else
|
||||
REGISTER4(BinaryOp, GPU, "Div", functor::div, uint8, uint16, complex64,
|
||||
complex128);
|
||||
REGISTER2(BinaryOp, GPU, "RealDiv", functor::div, complex64, complex128);
|
||||
REGISTER2(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16);
|
||||
#endif
|
||||
REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
|
||||
int64);
|
||||
REGISTER5(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, Eigen::half, float,
|
||||
double, complex64, complex128);
|
||||
|
||||
|
@ -560,6 +560,27 @@ GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
|
||||
test::DefaultInputNonZero<double>(), baseline_floor_div);
|
||||
|
||||
/// Test `tf.RealDiv`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_real_div(T lhs, T rhs) {
|
||||
return lhs / rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RealDiv,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
test::DefaultInput<Eigen::half>(), test::DefaultInputNonZero<Eigen::half>(),
|
||||
baseline_real_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RealDiv,
|
||||
/*test_name=*/Float, float, float, test::DefaultInput<float>(),
|
||||
test::DefaultInputNonZero<float>(), baseline_real_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RealDiv,
|
||||
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
|
||||
test::DefaultInputNonZero<double>(), baseline_real_div);
|
||||
|
||||
/// Test `tf.Greater`.
|
||||
|
||||
template <typename T>
|
||||
@ -712,22 +733,6 @@ GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
|
||||
baseline_not_equal)
|
||||
|
||||
/// Test `tf.RealDiv`.
|
||||
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RealDiv,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
test::DefaultInput<Eigen::half>(), test::DefaultInputNonZero<Eigen::half>(),
|
||||
baseline_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RealDiv,
|
||||
/*test_name=*/Float, float, float, test::DefaultInput<float>(),
|
||||
test::DefaultInputNonZero<float>(), baseline_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RealDiv,
|
||||
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
|
||||
test::DefaultInputNonZero<double>(), baseline_div);
|
||||
|
||||
/// Test `tf.RightShift`.
|
||||
|
||||
template <typename T>
|
||||
@ -769,16 +774,5 @@ GENERATE_DEFAULT_TESTS(Sub,
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Int64, int64, int64, baseline_sub)
|
||||
|
||||
/// Test `tf.TruncateDiv`.
|
||||
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
TruncateDiv,
|
||||
/*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
|
||||
test::DefaultInputNonZero<int16>(), baseline_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
TruncateDiv,
|
||||
/*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
|
||||
test::DefaultInputNonZero<int64>(), baseline_div);
|
||||
|
||||
} // namespace
|
||||
} // end namespace tensorflow
|
||||
|
@ -28,7 +28,4 @@ REGISTER_ALIASED_KERNEL(RealDiv, Div, f16, Eigen::half)
|
||||
REGISTER_ALIASED_KERNEL(RealDiv, Div, f32, float)
|
||||
REGISTER_ALIASED_KERNEL(RealDiv, Div, f64, double)
|
||||
|
||||
REGISTER_ALIASED_KERNEL(TruncatedDiv, Div, i16, int16)
|
||||
REGISTER_ALIASED_KERNEL(TruncatedDiv, Div, i64, int64)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user