[MLIR][KernelGen] Unify unary kernel tests

This refactoring allows to easily add more tests for other unranked kernels
without duplicating code.

PiperOrigin-RevId: 342626049
Change-Id: I974b4fcf8f887fb1512103f53194c216ee100e0b
This commit is contained in:
A. Unique TensorFlower 2020-11-16 06:52:51 -08:00 committed by TensorFlower Gardener
parent 516f4a121c
commit 229aeadda8
4 changed files with 174 additions and 203 deletions

View File

@ -147,30 +147,9 @@ tf_kernel_library(
)
tf_cuda_cc_test(
name = "gpu_tanh_test",
name = "gpu_unary_ops_test",
size = "small",
srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_tanh_test.cc"]),
tags = tf_cuda_tests_tags() + [
"no_cuda_asan", # TODO(b/171341759): re-enable.
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/common_runtime:device",
"//tensorflow/core/common_runtime:device_factory",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:ops_testutil",
],
)
tf_cuda_cc_test(
name = "gpu_abs_test",
size = "small",
srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_abs_test.cc"]),
srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_unary_ops_test.cc"]),
tags = tf_cuda_tests_tags() + [
"no_cuda_asan", # TODO(b/171341759): re-enable.
],

View File

@ -1,95 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <limits>
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class GpuAbsTest : public OpsTestBase {
protected:
void SetUp() override {
std::unique_ptr<tensorflow::Device> device_gpu(
tensorflow::DeviceFactory::NewDevice("GPU", {},
"/job:a/replica:0/task:0"));
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
}
template <typename T, typename RT = T>
void RunAbsOp(std::initializer_list<T> input) {
TensorShape shape({2, 3});
TF_ASSERT_OK(NodeDefBuilder("abs_op", "Abs")
.Input(FakeInput(DataTypeToEnum<T>::v()))
.Attr("T", DataTypeToEnum<T>::v())
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<T>(shape, input);
TF_ASSERT_OK(RunOpKernel());
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
std::vector<T> expected;
expected.reserve(input.size());
for (const T& inp : input) {
expected.push_back(static_cast<T>(std::abs(static_cast<RT>(inp))));
}
test::FillValues<T>(&expected_tensor, expected);
test::ExpectEqual(expected_tensor, *GetOutput(0));
}
};
TEST_F(GpuAbsTest, AbsFloat) {
RunAbsOp<float>({-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f,
0.1f, std::numeric_limits<float>::infinity()});
}
TEST_F(GpuAbsTest, AbsDouble) {
RunAbsOp<double>({-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0,
0.1, std::numeric_limits<double>::infinity()});
}
TEST_F(GpuAbsTest, AbsHalf) {
RunAbsOp<Eigen::half, float>(
{static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()),
static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
static_cast<Eigen::half>(std::numeric_limits<double>::infinity())});
}
TEST_F(GpuAbsTest, AbsInt32) {
RunAbsOp<int32>({std::numeric_limits<int32>::min(),
std::numeric_limits<int32>::min() + 1, -1, 0, 1,
std::numeric_limits<int32>::max()});
}
TEST_F(GpuAbsTest, AbsInt64) {
RunAbsOp<int64>({std::numeric_limits<int64>::min(),
std::numeric_limits<int64>::min() + 1, -1, 0, 1,
std::numeric_limits<int64>::max()});
}
} // namespace
} // end namespace tensorflow

View File

@ -1,85 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class GpuTanhTest : public OpsTestBase {
protected:
void SetUp() override {
std::unique_ptr<tensorflow::Device> device_gpu(
tensorflow::DeviceFactory::NewDevice("GPU", {},
"/job:a/replica:0/task:0"));
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
}
template <typename T, typename RT = T>
void RunTanhOp(std::initializer_list<T> input) {
TensorShape shape({2, 7});
TF_ASSERT_OK(NodeDefBuilder("tanh_op", "Tanh")
.Input(FakeInput(DataTypeToEnum<T>::v()))
.Attr("T", DataTypeToEnum<T>::v())
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<T>(shape, input);
TF_ASSERT_OK(RunOpKernel());
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
std::vector<T> expected;
expected.reserve(input.size());
for (const T& inp : input) {
expected.push_back(static_cast<T>(std::tanh(static_cast<RT>(inp))));
}
test::FillValues<T>(&expected_tensor, expected);
test::ExpectClose(expected_tensor, *GetOutput(0));
}
};
TEST_F(GpuTanhTest, TanhFloat) {
RunTanhOp<float>({-18.0f, -9.0f, -1e-6f, -0.0f, 0.0f, 1e-6, 0.1f, 0.2f, 0.3f,
0.5f, 0.7f, 0.9f, 9.0f, 18.0f});
}
TEST_F(GpuTanhTest, TanhDouble) {
RunTanhOp<double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5,
0.7, 0.9, 9.0, 18.0});
}
TEST_F(GpuTanhTest, TanhHalf) {
RunTanhOp<Eigen::half, float>(
{static_cast<Eigen::half>(-18.0), static_cast<Eigen::half>(-9.0),
static_cast<Eigen::half>(-1e-6), static_cast<Eigen::half>(-0.0),
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(1e-6),
static_cast<Eigen::half>(0.1), static_cast<Eigen::half>(0.2),
static_cast<Eigen::half>(0.3), static_cast<Eigen::half>(0.5),
static_cast<Eigen::half>(0.7), static_cast<Eigen::half>(0.9),
static_cast<Eigen::half>(9.0), static_cast<Eigen::half>(18.0)});
}
} // namespace
} // end namespace tensorflow

View File

@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class GpuUnaryOpTest : public OpsTestBase {
protected:
void SetUp() override {
std::unique_ptr<tensorflow::Device> device_gpu(
tensorflow::DeviceFactory::NewDevice("GPU", {},
"/job:a/replica:0/task:0"));
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
}
template <typename T, typename RT = T>
void Run(std::initializer_list<int64> input_shape,
std::initializer_list<T> input, const std::string op_name,
RT (*expected_callback)(RT), bool expect_equal = true) {
assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int64>()) == input.size() &&
"Expected input length to equal to shape's number of elements.");
TensorShape shape(input_shape);
TF_ASSERT_OK(NodeDefBuilder("some_name", op_name)
.Input(FakeInput(DataTypeToEnum<T>::v()))
.Attr("T", DataTypeToEnum<T>::v())
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<T>(shape, input);
TF_ASSERT_OK(RunOpKernel());
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
std::vector<T> expected;
expected.reserve(input.size());
for (const T& inp : input) {
expected.push_back(
static_cast<T>(expected_callback(static_cast<RT>(inp))));
}
test::FillValues<T>(&expected_tensor, expected);
if (expect_equal) {
test::ExpectEqual(expected_tensor, *GetOutput(0));
} else {
test::ExpectClose(expected_tensor, *GetOutput(0));
}
}
};
TEST_F(GpuUnaryOpTest, TanhFloat) {
Run<float>(/*input_shape=*/{2, 7},
/*input=*/
{-18.0f, -9.0f, -1e-6f, -0.0f, 0.0f, 1e-6, 0.1f, 0.2f, 0.3f, 0.5f,
0.7f, 0.9f, 9.0f, 18.0f},
/*op_name=*/"Tanh",
/*expected_callback=*/std::tanh,
/*expect_equal=*/false);
}
TEST_F(GpuUnaryOpTest, TanhDouble) {
Run<double>(/*input_shape=*/{2, 7},
/*input=*/
{-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5, 0.7,
0.9, 9.0, 18.0},
/*op_name=*/"Tanh",
/*expected_callback=*/std::tanh,
/*expect_equal=*/false);
}
TEST_F(GpuUnaryOpTest, TanhHalf) {
Run<Eigen::half, float>(
/*input_shape=*/{2, 7},
/*input=*/
{static_cast<Eigen::half>(-18.0), static_cast<Eigen::half>(-9.0),
static_cast<Eigen::half>(-1e-6), static_cast<Eigen::half>(-0.0),
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(1e-6),
static_cast<Eigen::half>(0.1), static_cast<Eigen::half>(0.2),
static_cast<Eigen::half>(0.3), static_cast<Eigen::half>(0.5),
static_cast<Eigen::half>(0.7), static_cast<Eigen::half>(0.9),
static_cast<Eigen::half>(9.0), static_cast<Eigen::half>(18.0)},
/*op_name=*/"Tanh",
/*expected_callback=*/std::tanh,
/*expect_equal=*/false);
}
TEST_F(GpuUnaryOpTest, AbsFloat) {
Run<float>(
/*input_shape=*/{2, 3},
/*input=*/
{-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f, 0.1f,
std::numeric_limits<float>::infinity()},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/false);
}
TEST_F(GpuUnaryOpTest, AbsDouble) {
Run<double>(
/*input_shape=*/{2, 3},
/*input=*/
{-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0, 0.1,
std::numeric_limits<double>::infinity()},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/false);
}
TEST_F(GpuUnaryOpTest, AbsHalf) {
Run<Eigen::half, float>(
/*input_shape=*/{2, 3},
/*input=*/
{static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()),
static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
static_cast<Eigen::half>(std::numeric_limits<double>::infinity())},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/false);
}
TEST_F(GpuUnaryOpTest, AbsInt32) {
Run<int32>(
/*input_shape=*/{2, 3},
/*input=*/
{std::numeric_limits<int32>::min(), std::numeric_limits<int32>::min() + 1,
-1, 0, 1, std::numeric_limits<int32>::max()},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/true);
}
TEST_F(GpuUnaryOpTest, AbsInt64) {
Run<int64>(
/*input_shape=*/{2, 3},
/*input=*/
{std::numeric_limits<int64>::min(), std::numeric_limits<int64>::min() + 1,
-1, 0, 1, std::numeric_limits<int64>::max()},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/true);
}
} // namespace
} // end namespace tensorflow