[MLIR][KernelGen] Explicitly specify special inputs in all binary tests.
PiperOrigin-RevId: 350725301 Change-Id: I516dde4a6496af7a7419c3f877e63ce4a5c8f086
This commit is contained in:
parent
d185bebfe7
commit
1f4aaeedb0
@ -489,13 +489,19 @@ Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
|
||||
return static_cast<Eigen::half>(std::floor(static_cast<float>(lhs / rhs)));
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
baseline_floor_div)
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Float, float, float, baseline_floor_div)
|
||||
GENERATE_DEFAULT_TESTS(FloorDiv,
|
||||
/*test_name=*/Double, double, double, baseline_floor_div)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
FloorDiv,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
test::DefaultInput<Eigen::half>("Div"),
|
||||
test::DefaultInputNonZero<Eigen::half>(), baseline_floor_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
FloorDiv,
|
||||
/*test_name=*/Float, float, float, test::DefaultInput<float>("Div"),
|
||||
test::DefaultInputNonZero<float>(), baseline_floor_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
FloorDiv,
|
||||
/*test_name=*/Double, double, double, test::DefaultInput<double>("Div"),
|
||||
test::DefaultInputNonZero<double>(), baseline_floor_div);
|
||||
|
||||
/// Test `tf.Greater`.
|
||||
|
||||
@ -544,14 +550,22 @@ T baseline_left_shift(T lhs, T rhs) {
|
||||
return lhs << rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int8, int8, int8,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int16, int16, int16,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int32, int32, int32,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS(LeftShift, /*test_name=*/Int64, int64, int64,
|
||||
baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
LeftShift, /*test_name=*/Int8, int8, int8,
|
||||
test::DefaultInput<int8>("LeftShift"),
|
||||
test::DefaultInputLessThanBitwidth<int8>(), baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
LeftShift, /*test_name=*/Int16, int16, int16,
|
||||
test::DefaultInput<int16>("LeftShift"),
|
||||
test::DefaultInputLessThanBitwidth<int16>(), baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
LeftShift, /*test_name=*/Int32, int32, int32,
|
||||
test::DefaultInput<int32>("LeftShift"),
|
||||
test::DefaultInputLessThanBitwidth<int32>(), baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
LeftShift, /*test_name=*/Int64, int64, int64,
|
||||
test::DefaultInput<int64>("LeftShift"),
|
||||
test::DefaultInputLessThanBitwidth<int64>(), baseline_left_shift)
|
||||
|
||||
/// Test `tf.Less`.
|
||||
|
||||
@ -656,14 +670,22 @@ T baseline_right_shift(T lhs, T rhs) {
|
||||
return lhs >> rhs;
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int8, int8, int8, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int16, int16, int16, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int32, int32, int32, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS(RightShift,
|
||||
/*test_name=*/Int64, int64, int64, baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RightShift,
|
||||
/*test_name=*/Int8, int8, int8, test::DefaultInput<int8>("RightShift"),
|
||||
test::DefaultInputLessThanBitwidth<int8>(), baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RightShift,
|
||||
/*test_name=*/Int16, int16, int16, test::DefaultInput<int16>("RightShift"),
|
||||
test::DefaultInputLessThanBitwidth<int16>(), baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RightShift,
|
||||
/*test_name=*/Int32, int32, int32, test::DefaultInput<int32>("RightShift"),
|
||||
test::DefaultInputLessThanBitwidth<int32>(), baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RightShift,
|
||||
/*test_name=*/Int64, int64, int64, test::DefaultInput<int64>("RightShift"),
|
||||
test::DefaultInputLessThanBitwidth<int64>(), baseline_right_shift)
|
||||
|
||||
/// Test `tf.Sub`.
|
||||
|
||||
|
@ -136,17 +136,20 @@ absl::InlinedVector<T, 10> DefaultInputNonZero() {
|
||||
|
||||
/// Helper functions to get default input data.
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInputLessThanBitwidth() {
|
||||
auto max_shift = sizeof(T) * 8 - 1;
|
||||
absl::InlinedVector<T, 10> v(max_shift);
|
||||
for (auto i = 0; i < max_shift; ++i) v.push_back(i);
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
// Only generate values less than the bitwidth of the data type.
|
||||
if (op_name == "LeftShift" || op_name == "RightShift") {
|
||||
auto max_shift = sizeof(T) * 8 - 1;
|
||||
absl::InlinedVector<T, 10> v(max_shift);
|
||||
for (auto i = 0; i < max_shift; ++i) v.push_back(i);
|
||||
return v;
|
||||
}
|
||||
return InputAsVector<T, int>({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18});
|
||||
}
|
||||
|
||||
@ -154,10 +157,6 @@ template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
if (op_name == "FloorDiv") {
|
||||
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.1, 0.1, 1e-6, 0.1,
|
||||
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1,
|
||||
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user