[MLIR][KernelGen] Test buffer reuse for unary ops
PiperOrigin-RevId: 346078047 Change-Id: Idec052f21c961ee7679ae91a7fd9c03a6954c5a1
This commit is contained in:
parent
9217aed95c
commit
3980781f7f
@ -51,7 +51,8 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT>
|
||||
void Run(std::vector<int64> input_shape, std::vector<T> input,
|
||||
const std::string op_name, ROutT (*expected_callback)(RT),
|
||||
bool expect_equal = true, bool add_tout = false) {
|
||||
bool expect_equal = true, bool add_tout = false,
|
||||
bool expect_buffer_reuse = true) {
|
||||
assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
|
||||
std::multiplies<int64>()) == input.size() &&
|
||||
"Expected input length to equal to shape's number of elements.");
|
||||
@ -69,6 +70,14 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
AddInputFromArray<T>(shape, input);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Assert buffer reuse if expected.
|
||||
if (expect_buffer_reuse) {
|
||||
void* arg_ptr_on_device = context_->input(0).data();
|
||||
void* result_ptr_on_device = context_->mutable_output(0)->data();
|
||||
ASSERT_EQ(arg_ptr_on_device, result_ptr_on_device);
|
||||
}
|
||||
|
||||
// Assert expected results.
|
||||
Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape);
|
||||
absl::InlinedVector<OutT, 14> expected;
|
||||
expected.reserve(input.size());
|
||||
@ -217,7 +226,9 @@ TEST_F(GpuUnaryOpTest, ConjFloat) {
|
||||
std::complex<float>>(DefaultInputShape(), DefaultComplexInput<float>(),
|
||||
/*op_name=*/"Conj",
|
||||
/*expected_callback=*/std::conj,
|
||||
/*expect_equal=*/false);
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/false,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ConjDouble) {
|
||||
@ -225,7 +236,9 @@ TEST_F(GpuUnaryOpTest, ConjDouble) {
|
||||
std::complex<double>>(DefaultInputShape(), DefaultComplexInput<double>(),
|
||||
/*op_name=*/"Conj",
|
||||
/*expected_callback=*/std::conj,
|
||||
/*expect_equal=*/false);
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/false,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Cos`.
|
||||
@ -305,7 +318,8 @@ TEST_F(GpuUnaryOpTest, ImagFloat) {
|
||||
/*op_name=*/"Imag",
|
||||
/*expected_callback=*/std::imag,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/true);
|
||||
/*add_tout=*/true,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ImagDouble) {
|
||||
@ -314,7 +328,8 @@ TEST_F(GpuUnaryOpTest, ImagDouble) {
|
||||
/*op_name=*/"Imag",
|
||||
/*expected_callback=*/std::imag,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/true);
|
||||
/*add_tout=*/true,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.IsInf`.
|
||||
@ -408,7 +423,8 @@ TEST_F(GpuUnaryOpTest, RealFloat) {
|
||||
/*op_name=*/"Real",
|
||||
/*expected_callback=*/std::real,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/true);
|
||||
/*add_tout=*/true,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, RealDouble) {
|
||||
@ -417,7 +433,8 @@ TEST_F(GpuUnaryOpTest, RealDouble) {
|
||||
/*op_name=*/"Real",
|
||||
/*expected_callback=*/std::real,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/true);
|
||||
/*add_tout=*/true,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Rsqrt`.
|
||||
|
Loading…
Reference in New Issue
Block a user