[KERNEL_GEN] Add a missing test for Mul kernel and sort the tests.
PiperOrigin-RevId: 350172525 Change-Id: If54750054cb46a0a781fde9785493d24e6ce355f
This commit is contained in:
parent
0b83955574
commit
dff50dd3ad
@ -391,23 +391,6 @@ GENERATE_DEFAULT_TESTS(AddV2,
|
||||
GENERATE_DEFAULT_TESTS(AddV2,
|
||||
/*test_name=*/Int64, int64, int64, baseline_add)
|
||||
|
||||
/// Test `tf.Sub`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_sub(T lhs, T rhs) {
|
||||
return lhs - rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Float, float, float, baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Double, double, double, baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Int64, int64, int64, baseline_sub)
|
||||
|
||||
/// Test `tf.BitwiseAnd`.
|
||||
|
||||
template <typename T>
|
||||
@ -456,37 +439,23 @@ GENERATE_DEFAULT_TESTS(BitwiseXor,
|
||||
GENERATE_DEFAULT_TESTS(BitwiseXor,
|
||||
/*test_name=*/Int64, int64, int64, baseline_bitwise_xor)
|
||||
|
||||
/// Test `tf.LeftShift`.
|
||||
|
||||
/// Test `tf.Div`.
|
||||
template <typename T>
|
||||
T baseline_left_shift(T lhs, T rhs) {
|
||||
return lhs << rhs;
|
||||
T baseline_div(T lhs, T rhs) {
|
||||
return lhs / rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
|
||||
baseline_left_shift)
|
||||
|
||||
/// Test `tf.RightShift`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_right_shift(T lhs, T rhs) {
|
||||
return lhs >> rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int8, int8, int8, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int16, int16, int16, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int32, int32, int32, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int64, int64, int64, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Float, float, float, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Double, double, double, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Int16, int16, int16, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Int64, int64, int64, baseline_div);
|
||||
|
||||
/// Test `tf.Equal`.
|
||||
|
||||
@ -505,27 +474,25 @@ GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int8, int8, bool, baseline_equal)
|
||||
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int16, int16, bool, baseline_equal)
|
||||
GENERATE_DEFAULT_TESTS(Equal, /*test_name=*/Int64, int64, bool, baseline_equal)
|
||||
|
||||
/// Test `tf.NotEqual`.
|
||||
/// Test `tf.FloorDiv`.
|
||||
|
||||
template <typename T>
|
||||
bool baseline_not_equal(T lhs, T rhs) {
|
||||
return lhs != rhs;
|
||||
T baseline_floor_div(T lhs, T rhs) {
|
||||
return std::floor(lhs / rhs);
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
|
||||
baseline_not_equal)
|
||||
template <>
|
||||
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
|
||||
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_floor_div)
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Float, float, float, baseline_floor_div)
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Double, double, double, baseline_floor_div)
|
||||
|
||||
/// Test `tf.Greater`.
|
||||
|
||||
@ -567,6 +534,22 @@ GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int16, int16, bool,
|
||||
GENERATE_DEFAULT_TESTS(GreaterEqual, /*test_name=*/Int64, int64, bool,
|
||||
baseline_greater_equal)
|
||||
|
||||
/// Test `tf.LeftShift`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_left_shift(T lhs, T rhs) {
|
||||
return lhs << rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
|
||||
baseline_left_shift)
|
||||
|
||||
/// Test `tf.Less`.
|
||||
|
||||
template <typename T>
|
||||
@ -620,43 +603,75 @@ GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
|
||||
/*BaselineOutT=*/bool, baseline_logical_or,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
|
||||
/// Test `tf.FloorDiv`.
|
||||
/// Test `tf.Mul`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_floor_div(T lhs, T rhs) {
|
||||
return std::floor(lhs / rhs);
|
||||
T baseline_mul(T lhs, T rhs) {
|
||||
return lhs * rhs;
|
||||
}
|
||||
|
||||
template <>
|
||||
Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
|
||||
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
|
||||
}
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Float, float, float, baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Double, double, double, baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int8, int8, int8, baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int16, int16, int16, baseline_mul)
|
||||
GENERATE_DEFAULT_TESTS(Mul, /*test_name=*/Int64, int64, int64, baseline_mul)
|
||||
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_floor_div);
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Float, float, float, baseline_floor_div);
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Double, double, double,
|
||||
baseline_floor_div);
|
||||
/// Test `tf.NotEqual`.
|
||||
|
||||
/// Test `tf.Div`.
|
||||
template <typename T>
|
||||
T baseline_div(T lhs, T rhs) {
|
||||
return lhs / rhs;
|
||||
bool baseline_not_equal(T lhs, T rhs) {
|
||||
return lhs != rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Half, Eigen::half, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Float, float, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Double, double, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Bool, bool, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int8, int8, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int16, int16, bool,
|
||||
baseline_not_equal)
|
||||
GENERATE_DEFAULT_TESTS(NotEqual, /*test_name=*/Int64, int64, bool,
|
||||
baseline_not_equal)
|
||||
|
||||
/// Test `tf.RightShift`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_right_shift(T lhs, T rhs) {
|
||||
return lhs >> rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int8, int8, int8, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int16, int16, int16, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int32, int32, int32, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int64, int64, int64, baseline_right_shift)
|
||||
|
||||
/// Test `tf.Sub`.
|
||||
|
||||
template <typename T>
|
||||
T baseline_sub(T lhs, T rhs) {
|
||||
return lhs - rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Float, float, float, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Double, double, double, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Int16, int16, int16, baseline_div);
|
||||
GENERATE_DEFAULT_TESTS(Div,
|
||||
/*test_name=*/Int64, int64, int64, baseline_div);
|
||||
baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Float, float, float, baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Double, double, double, baseline_sub)
|
||||
GENERATE_DEFAULT_TESTS(Sub,
|
||||
/*test_name=*/Int64, int64, int64, baseline_sub)
|
||||
|
||||
} // namespace
|
||||
} // end namespace tensorflow
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user