[MLIR][KernelGen] Remove op name as argument to default input values
PiperOrigin-RevId: 350728617 Change-Id: Ice753bda3655c04c322ab6ee9a50404c2a994242
This commit is contained in:
parent
1f4aaeedb0
commit
cf603aa4f9
@ -356,8 +356,8 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
|
||||
#define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback) \
|
||||
GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, \
|
||||
test::DefaultInput<T>(#op_name), \
|
||||
test::DefaultInput<T>(#op_name), baseline_callback, \
|
||||
test::DefaultInput<T>(), test::DefaultInput<T>(), \
|
||||
baseline_callback, \
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
#define GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( \
|
||||
@ -441,23 +441,23 @@ T baseline_div(T lhs, T rhs) {
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Div,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
test::DefaultInput<Eigen::half>("Div"),
|
||||
test::DefaultInputNonZero<Eigen::half>(), baseline_div);
|
||||
test::DefaultInput<Eigen::half>(), test::DefaultInputNonZero<Eigen::half>(),
|
||||
baseline_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Div,
|
||||
/*test_name=*/Float, float, float, test::DefaultInput<float>("Div"),
|
||||
/*test_name=*/Float, float, float, test::DefaultInput<float>(),
|
||||
test::DefaultInputNonZero<float>(), baseline_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Div,
|
||||
/*test_name=*/Double, double, double, test::DefaultInput<double>("Div"),
|
||||
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
|
||||
test::DefaultInputNonZero<double>(), baseline_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Div,
|
||||
/*test_name=*/Int16, int16, int16, test::DefaultInput<int16>("Div"),
|
||||
/*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
|
||||
test::DefaultInputNonZero<int16>(), baseline_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Div,
|
||||
/*test_name=*/Int64, int64, int64, test::DefaultInput<int64>("Div"),
|
||||
/*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
|
||||
test::DefaultInputNonZero<int64>(), baseline_div);
|
||||
|
||||
/// Test `tf.Equal`.
|
||||
@ -492,15 +492,15 @@ Eigen::half baseline_floor_div(Eigen::half lhs, Eigen::half rhs) {
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
FloorDiv,
|
||||
/*test_name=*/Half, Eigen::half, Eigen::half,
|
||||
test::DefaultInput<Eigen::half>("Div"),
|
||||
test::DefaultInputNonZero<Eigen::half>(), baseline_floor_div);
|
||||
test::DefaultInput<Eigen::half>(), test::DefaultInputNonZero<Eigen::half>(),
|
||||
baseline_floor_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
FloorDiv,
|
||||
/*test_name=*/Float, float, float, test::DefaultInput<float>("Div"),
|
||||
/*test_name=*/Float, float, float, test::DefaultInput<float>(),
|
||||
test::DefaultInputNonZero<float>(), baseline_floor_div);
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
FloorDiv,
|
||||
/*test_name=*/Double, double, double, test::DefaultInput<double>("Div"),
|
||||
/*test_name=*/Double, double, double, test::DefaultInput<double>(),
|
||||
test::DefaultInputNonZero<double>(), baseline_floor_div);
|
||||
|
||||
/// Test `tf.Greater`.
|
||||
@ -551,20 +551,16 @@ T baseline_left_shift(T lhs, T rhs) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
LeftShift, /*test_name=*/Int8, int8, int8,
|
||||
test::DefaultInput<int8>("LeftShift"),
|
||||
LeftShift, /*test_name=*/Int8, int8, int8, test::DefaultInput<int8>(),
|
||||
test::DefaultInputLessThanBitwidth<int8>(), baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
LeftShift, /*test_name=*/Int16, int16, int16,
|
||||
test::DefaultInput<int16>("LeftShift"),
|
||||
LeftShift, /*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
|
||||
test::DefaultInputLessThanBitwidth<int16>(), baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
LeftShift, /*test_name=*/Int32, int32, int32,
|
||||
test::DefaultInput<int32>("LeftShift"),
|
||||
LeftShift, /*test_name=*/Int32, int32, int32, test::DefaultInput<int32>(),
|
||||
test::DefaultInputLessThanBitwidth<int32>(), baseline_left_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
LeftShift, /*test_name=*/Int64, int64, int64,
|
||||
test::DefaultInput<int64>("LeftShift"),
|
||||
LeftShift, /*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
|
||||
test::DefaultInputLessThanBitwidth<int64>(), baseline_left_shift)
|
||||
|
||||
/// Test `tf.Less`.
|
||||
@ -608,10 +604,8 @@ bool baseline_logical_and(bool lhs, bool rhs) { return lhs && rhs; }
|
||||
|
||||
GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool,
|
||||
/*BaselineT=*/bool, /*OutT=*/bool,
|
||||
/*BaselineOutT=*/bool,
|
||||
test::DefaultInput<bool>("LogicalAnd"),
|
||||
test::DefaultInput<bool>("LogicalAnd"),
|
||||
baseline_logical_and,
|
||||
/*BaselineOutT=*/bool, test::DefaultInput<bool>(),
|
||||
test::DefaultInput<bool>(), baseline_logical_and,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
|
||||
/// Test `tf.LogicalOr`.
|
||||
@ -620,10 +614,8 @@ bool baseline_logical_or(bool lhs, bool rhs) { return lhs || rhs; }
|
||||
|
||||
GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
|
||||
/*BaselineT=*/bool, /*OutT=*/bool,
|
||||
/*BaselineOutT=*/bool,
|
||||
test::DefaultInput<bool>("LogicalOr"),
|
||||
test::DefaultInput<bool>("LogicalOr"),
|
||||
baseline_logical_or,
|
||||
/*BaselineOutT=*/bool, test::DefaultInput<bool>(),
|
||||
test::DefaultInput<bool>(), baseline_logical_or,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
|
||||
/// Test `tf.Mul`.
|
||||
@ -672,19 +664,19 @@ T baseline_right_shift(T lhs, T rhs) {
|
||||
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RightShift,
|
||||
/*test_name=*/Int8, int8, int8, test::DefaultInput<int8>("RightShift"),
|
||||
/*test_name=*/Int8, int8, int8, test::DefaultInput<int8>(),
|
||||
test::DefaultInputLessThanBitwidth<int8>(), baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RightShift,
|
||||
/*test_name=*/Int16, int16, int16, test::DefaultInput<int16>("RightShift"),
|
||||
/*test_name=*/Int16, int16, int16, test::DefaultInput<int16>(),
|
||||
test::DefaultInputLessThanBitwidth<int16>(), baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RightShift,
|
||||
/*test_name=*/Int32, int32, int32, test::DefaultInput<int32>("RightShift"),
|
||||
/*test_name=*/Int32, int32, int32, test::DefaultInput<int32>(),
|
||||
test::DefaultInputLessThanBitwidth<int32>(), baseline_right_shift)
|
||||
GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES(
|
||||
RightShift,
|
||||
/*test_name=*/Int64, int64, int64, test::DefaultInput<int64>("RightShift"),
|
||||
/*test_name=*/Int64, int64, int64, test::DefaultInput<int64>(),
|
||||
test::DefaultInputLessThanBitwidth<int64>(), baseline_right_shift)
|
||||
|
||||
/// Test `tf.Sub`.
|
||||
|
@ -149,14 +149,14 @@ absl::InlinedVector<T, 10> DefaultInputLessThanBitwidth() {
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
absl::InlinedVector<T, 10> DefaultInput() {
|
||||
return InputAsVector<T, int>({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18});
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
absl::InlinedVector<T, 10> DefaultInput() {
|
||||
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1,
|
||||
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
@ -165,9 +165,9 @@ template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, std::complex<float>,
|
||||
std::complex<double>>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
absl::InlinedVector<T, 10> DefaultInput() {
|
||||
using ElementType = typename T::value_type;
|
||||
auto input = test::DefaultInput<ElementType>(op_name);
|
||||
auto input = test::DefaultInput<ElementType>();
|
||||
absl::InlinedVector<T, 10> complex_input;
|
||||
for (ElementType value : input) {
|
||||
complex_input.emplace_back(value, -value);
|
||||
@ -177,7 +177,7 @@ absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view /*op_name*/) {
|
||||
absl::InlinedVector<T, 10> DefaultInput() {
|
||||
return InputAsVector<T, bool>({true, false, true, true, false});
|
||||
}
|
||||
|
||||
|
@ -132,7 +132,7 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, BaselineT, OutT, BaselineOutT, \
|
||||
test::DefaultInput<NativeT>(#op_name), baseline_callback, config)
|
||||
test::DefaultInput<NativeT>(), baseline_callback, config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( \
|
||||
op_name, InT, OutT, input_values, baseline_callback, config) \
|
||||
@ -252,7 +252,7 @@ GENERATE_DEFAULT_TEST(Imag, DT_COMPLEX128, DT_DOUBLE, baseline_imag,
|
||||
TEST_F(GpuUnaryOpTest, DISABLED_IsInfFloat) {
|
||||
Test<float, float, bool, bool>(
|
||||
/*op_name=*/"IsInf", test::DefaultInputShape(),
|
||||
test::DefaultInput<float>("IsInf"),
|
||||
test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/std::isinf,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual());
|
||||
}
|
||||
@ -263,7 +263,7 @@ TEST_F(GpuUnaryOpTest, DISABLED_IsInfDouble) {
|
||||
// comparing expected values.
|
||||
Test<double, float, bool, bool>(
|
||||
/*op_name=*/"IsInf", test::DefaultInputShape(),
|
||||
test::DefaultInput<double>("IsInf"),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/std::isinf,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual());
|
||||
}
|
||||
@ -271,7 +271,7 @@ TEST_F(GpuUnaryOpTest, DISABLED_IsInfDouble) {
|
||||
TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) {
|
||||
Test<Eigen::half, float, bool, bool>(
|
||||
/*op_name=*/"IsInf", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>("IsInf"),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::isinf,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user