[MLIR][KernelGen] Unify unary kernel tests
This refactoring allows to easily add more tests for other unranked kernels without duplicating code. PiperOrigin-RevId: 342626049 Change-Id: I974b4fcf8f887fb1512103f53194c216ee100e0b
This commit is contained in:
parent
516f4a121c
commit
229aeadda8
@ -147,30 +147,9 @@ tf_kernel_library(
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "gpu_tanh_test",
|
||||
name = "gpu_unary_ops_test",
|
||||
size = "small",
|
||||
srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_tanh_test.cc"]),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"no_cuda_asan", # TODO(b/171341759): re-enable.
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/common_runtime:device",
|
||||
"//tensorflow/core/common_runtime:device_factory",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "gpu_abs_test",
|
||||
size = "small",
|
||||
srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_abs_test.cc"]),
|
||||
srcs = if_mlir_generated_gpu_kernels_enabled(["gpu_unary_ops_test.cc"]),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"no_cuda_asan", # TODO(b/171341759): re-enable.
|
||||
],
|
||||
|
@ -1,95 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class GpuAbsTest : public OpsTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::unique_ptr<tensorflow::Device> device_gpu(
|
||||
tensorflow::DeviceFactory::NewDevice("GPU", {},
|
||||
"/job:a/replica:0/task:0"));
|
||||
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
|
||||
}
|
||||
template <typename T, typename RT = T>
|
||||
void RunAbsOp(std::initializer_list<T> input) {
|
||||
TensorShape shape({2, 3});
|
||||
TF_ASSERT_OK(NodeDefBuilder("abs_op", "Abs")
|
||||
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
||||
.Attr("T", DataTypeToEnum<T>::v())
|
||||
.Finalize(node_def()));
|
||||
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<T>(shape, input);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
|
||||
std::vector<T> expected;
|
||||
expected.reserve(input.size());
|
||||
for (const T& inp : input) {
|
||||
expected.push_back(static_cast<T>(std::abs(static_cast<RT>(inp))));
|
||||
}
|
||||
test::FillValues<T>(&expected_tensor, expected);
|
||||
test::ExpectEqual(expected_tensor, *GetOutput(0));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(GpuAbsTest, AbsFloat) {
|
||||
RunAbsOp<float>({-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f,
|
||||
0.1f, std::numeric_limits<float>::infinity()});
|
||||
}
|
||||
|
||||
TEST_F(GpuAbsTest, AbsDouble) {
|
||||
RunAbsOp<double>({-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0,
|
||||
0.1, std::numeric_limits<double>::infinity()});
|
||||
}
|
||||
|
||||
TEST_F(GpuAbsTest, AbsHalf) {
|
||||
RunAbsOp<Eigen::half, float>(
|
||||
{static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()),
|
||||
static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
|
||||
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
|
||||
static_cast<Eigen::half>(std::numeric_limits<double>::infinity())});
|
||||
}
|
||||
|
||||
TEST_F(GpuAbsTest, AbsInt32) {
|
||||
RunAbsOp<int32>({std::numeric_limits<int32>::min(),
|
||||
std::numeric_limits<int32>::min() + 1, -1, 0, 1,
|
||||
std::numeric_limits<int32>::max()});
|
||||
}
|
||||
|
||||
TEST_F(GpuAbsTest, AbsInt64) {
|
||||
RunAbsOp<int64>({std::numeric_limits<int64>::min(),
|
||||
std::numeric_limits<int64>::min() + 1, -1, 0, 1,
|
||||
std::numeric_limits<int64>::max()});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // end namespace tensorflow
|
@ -1,85 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class GpuTanhTest : public OpsTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::unique_ptr<tensorflow::Device> device_gpu(
|
||||
tensorflow::DeviceFactory::NewDevice("GPU", {},
|
||||
"/job:a/replica:0/task:0"));
|
||||
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
|
||||
}
|
||||
template <typename T, typename RT = T>
|
||||
void RunTanhOp(std::initializer_list<T> input) {
|
||||
TensorShape shape({2, 7});
|
||||
TF_ASSERT_OK(NodeDefBuilder("tanh_op", "Tanh")
|
||||
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
||||
.Attr("T", DataTypeToEnum<T>::v())
|
||||
.Finalize(node_def()));
|
||||
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<T>(shape, input);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
|
||||
std::vector<T> expected;
|
||||
expected.reserve(input.size());
|
||||
for (const T& inp : input) {
|
||||
expected.push_back(static_cast<T>(std::tanh(static_cast<RT>(inp))));
|
||||
}
|
||||
test::FillValues<T>(&expected_tensor, expected);
|
||||
test::ExpectClose(expected_tensor, *GetOutput(0));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(GpuTanhTest, TanhFloat) {
|
||||
RunTanhOp<float>({-18.0f, -9.0f, -1e-6f, -0.0f, 0.0f, 1e-6, 0.1f, 0.2f, 0.3f,
|
||||
0.5f, 0.7f, 0.9f, 9.0f, 18.0f});
|
||||
}
|
||||
|
||||
TEST_F(GpuTanhTest, TanhDouble) {
|
||||
RunTanhOp<double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5,
|
||||
0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
TEST_F(GpuTanhTest, TanhHalf) {
|
||||
RunTanhOp<Eigen::half, float>(
|
||||
{static_cast<Eigen::half>(-18.0), static_cast<Eigen::half>(-9.0),
|
||||
static_cast<Eigen::half>(-1e-6), static_cast<Eigen::half>(-0.0),
|
||||
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(1e-6),
|
||||
static_cast<Eigen::half>(0.1), static_cast<Eigen::half>(0.2),
|
||||
static_cast<Eigen::half>(0.3), static_cast<Eigen::half>(0.5),
|
||||
static_cast<Eigen::half>(0.7), static_cast<Eigen::half>(0.9),
|
||||
static_cast<Eigen::half>(9.0), static_cast<Eigen::half>(18.0)});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // end namespace tensorflow
|
172
tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
Normal file
172
tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
Normal file
@ -0,0 +1,172 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class GpuUnaryOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::unique_ptr<tensorflow::Device> device_gpu(
|
||||
tensorflow::DeviceFactory::NewDevice("GPU", {},
|
||||
"/job:a/replica:0/task:0"));
|
||||
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
|
||||
}
|
||||
|
||||
template <typename T, typename RT = T>
|
||||
void Run(std::initializer_list<int64> input_shape,
|
||||
std::initializer_list<T> input, const std::string op_name,
|
||||
RT (*expected_callback)(RT), bool expect_equal = true) {
|
||||
assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
|
||||
std::multiplies<int64>()) == input.size() &&
|
||||
"Expected input length to equal to shape's number of elements.");
|
||||
|
||||
TensorShape shape(input_shape);
|
||||
TF_ASSERT_OK(NodeDefBuilder("some_name", op_name)
|
||||
.Input(FakeInput(DataTypeToEnum<T>::v()))
|
||||
.Attr("T", DataTypeToEnum<T>::v())
|
||||
.Finalize(node_def()));
|
||||
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<T>(shape, input);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
|
||||
std::vector<T> expected;
|
||||
expected.reserve(input.size());
|
||||
for (const T& inp : input) {
|
||||
expected.push_back(
|
||||
static_cast<T>(expected_callback(static_cast<RT>(inp))));
|
||||
}
|
||||
test::FillValues<T>(&expected_tensor, expected);
|
||||
if (expect_equal) {
|
||||
test::ExpectEqual(expected_tensor, *GetOutput(0));
|
||||
} else {
|
||||
test::ExpectClose(expected_tensor, *GetOutput(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(GpuUnaryOpTest, TanhFloat) {
|
||||
Run<float>(/*input_shape=*/{2, 7},
|
||||
/*input=*/
|
||||
{-18.0f, -9.0f, -1e-6f, -0.0f, 0.0f, 1e-6, 0.1f, 0.2f, 0.3f, 0.5f,
|
||||
0.7f, 0.9f, 9.0f, 18.0f},
|
||||
/*op_name=*/"Tanh",
|
||||
/*expected_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, TanhDouble) {
|
||||
Run<double>(/*input_shape=*/{2, 7},
|
||||
/*input=*/
|
||||
{-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5, 0.7,
|
||||
0.9, 9.0, 18.0},
|
||||
/*op_name=*/"Tanh",
|
||||
/*expected_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, TanhHalf) {
|
||||
Run<Eigen::half, float>(
|
||||
/*input_shape=*/{2, 7},
|
||||
/*input=*/
|
||||
{static_cast<Eigen::half>(-18.0), static_cast<Eigen::half>(-9.0),
|
||||
static_cast<Eigen::half>(-1e-6), static_cast<Eigen::half>(-0.0),
|
||||
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(1e-6),
|
||||
static_cast<Eigen::half>(0.1), static_cast<Eigen::half>(0.2),
|
||||
static_cast<Eigen::half>(0.3), static_cast<Eigen::half>(0.5),
|
||||
static_cast<Eigen::half>(0.7), static_cast<Eigen::half>(0.9),
|
||||
static_cast<Eigen::half>(9.0), static_cast<Eigen::half>(18.0)},
|
||||
/*op_name=*/"Tanh",
|
||||
/*expected_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsFloat) {
|
||||
Run<float>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f, 0.1f,
|
||||
std::numeric_limits<float>::infinity()},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsDouble) {
|
||||
Run<double>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0, 0.1,
|
||||
std::numeric_limits<double>::infinity()},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsHalf) {
|
||||
Run<Eigen::half, float>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()),
|
||||
static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
|
||||
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
|
||||
static_cast<Eigen::half>(std::numeric_limits<double>::infinity())},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsInt32) {
|
||||
Run<int32>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{std::numeric_limits<int32>::min(), std::numeric_limits<int32>::min() + 1,
|
||||
-1, 0, 1, std::numeric_limits<int32>::max()},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsInt64) {
|
||||
Run<int64>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{std::numeric_limits<int64>::min(), std::numeric_limits<int64>::min() + 1,
|
||||
-1, 0, 1, std::numeric_limits<int64>::max()},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // end namespace tensorflow
|
Loading…
Reference in New Issue
Block a user