[MLIR][KernelGen] Test buffer reuse for unary ops

PiperOrigin-RevId: 346078047
Change-Id: Idec052f21c961ee7679ae91a7fd9c03a6954c5a1
This commit is contained in:
A. Unique TensorFlower 2020-12-07 06:19:56 -08:00 committed by TensorFlower Gardener
parent 9217aed95c
commit 3980781f7f

View File

@ -51,7 +51,8 @@ class GpuUnaryOpTest : public OpsTestBase {
template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT>
void Run(std::vector<int64> input_shape, std::vector<T> input,
const std::string op_name, ROutT (*expected_callback)(RT),
bool expect_equal = true, bool add_tout = false) {
bool expect_equal = true, bool add_tout = false,
bool expect_buffer_reuse = true) {
assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int64>()) == input.size() &&
"Expected input length to equal to shape's number of elements.");
@ -69,6 +70,14 @@ class GpuUnaryOpTest : public OpsTestBase {
AddInputFromArray<T>(shape, input);
TF_ASSERT_OK(RunOpKernel());
// Assert buffer reuse if expected.
if (expect_buffer_reuse) {
void* arg_ptr_on_device = context_->input(0).data();
void* result_ptr_on_device = context_->mutable_output(0)->data();
ASSERT_EQ(arg_ptr_on_device, result_ptr_on_device);
}
// Assert expected results.
Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape);
absl::InlinedVector<OutT, 14> expected;
expected.reserve(input.size());
@ -217,7 +226,9 @@ TEST_F(GpuUnaryOpTest, ConjFloat) {
std::complex<float>>(DefaultInputShape(), DefaultComplexInput<float>(),
/*op_name=*/"Conj",
/*expected_callback=*/std::conj,
/*expect_equal=*/false);
/*expect_equal=*/false,
/*add_tout=*/false,
/*expect_buffer_reuse=*/false);
}
TEST_F(GpuUnaryOpTest, ConjDouble) {
@ -225,7 +236,9 @@ TEST_F(GpuUnaryOpTest, ConjDouble) {
std::complex<double>>(DefaultInputShape(), DefaultComplexInput<double>(),
/*op_name=*/"Conj",
/*expected_callback=*/std::conj,
/*expect_equal=*/false);
/*expect_equal=*/false,
/*add_tout=*/false,
/*expect_buffer_reuse=*/false);
}
/// Test `tf.Cos`.
@ -305,7 +318,8 @@ TEST_F(GpuUnaryOpTest, ImagFloat) {
/*op_name=*/"Imag",
/*expected_callback=*/std::imag,
/*expect_equal=*/false,
/*add_tout=*/true);
/*add_tout=*/true,
/*expect_buffer_reuse=*/false);
}
TEST_F(GpuUnaryOpTest, ImagDouble) {
@ -314,7 +328,8 @@ TEST_F(GpuUnaryOpTest, ImagDouble) {
/*op_name=*/"Imag",
/*expected_callback=*/std::imag,
/*expect_equal=*/false,
/*add_tout=*/true);
/*add_tout=*/true,
/*expect_buffer_reuse=*/false);
}
/// Test `tf.IsInf`.
@ -408,7 +423,8 @@ TEST_F(GpuUnaryOpTest, RealFloat) {
/*op_name=*/"Real",
/*expected_callback=*/std::real,
/*expect_equal=*/false,
/*add_tout=*/true);
/*add_tout=*/true,
/*expect_buffer_reuse=*/false);
}
TEST_F(GpuUnaryOpTest, RealDouble) {
@ -417,7 +433,8 @@ TEST_F(GpuUnaryOpTest, RealDouble) {
/*op_name=*/"Real",
/*expected_callback=*/std::real,
/*expect_equal=*/false,
/*add_tout=*/true);
/*add_tout=*/true,
/*expect_buffer_reuse=*/false);
}
/// Test `tf.Rsqrt`.