[MLIR][KernelGen] Refactor unary kernel tests

Take advantage of the new kernel test util.

PiperOrigin-RevId: 348081858
Change-Id: Ib5972122bf5b9bb51c04c67228bf811370ae015c
This commit is contained in:
A. Unique TensorFlower 2020-12-17 13:24:17 -08:00 committed by TensorFlower Gardener
parent 5e6c86fdf2
commit 50c75992c7
3 changed files with 350 additions and 312 deletions

View File

@ -197,6 +197,7 @@ tf_cuda_cc_test(
"no_cuda_asan", # TODO(b/171341759): re-enable. "no_cuda_asan", # TODO(b/171341759): re-enable.
], ],
deps = [ deps = [
":gpu_ops_test_util",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:tensorflow", "//tensorflow/core:tensorflow",

View File

@ -48,10 +48,12 @@ absl::InlinedVector<T, 10> RepeatInputToMatchShape(
return result; return result;
} }
/// Helper functions to get default input values. /// Helper functions to get default input shapes.
TensorShape DefaultInputShape(); TensorShape DefaultInputShape();
/// Helper functions to get default input data.
template <typename T, template <typename T,
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value, std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
bool> = true> bool> = true>
@ -72,17 +74,10 @@ T DefaultScalarInput() {
return static_cast<T>(true); return static_cast<T>(true);
} }
template <typename T>
absl::InlinedVector<T, 10> InfZeroInput() {
return InputAsVector<T, double>({-std::numeric_limits<double>::infinity(),
-0.1, -0.0, 0.0, 0.1,
std::numeric_limits<float>::infinity()});
}
template <typename T, template <typename T,
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value, std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
bool> = true> bool> = true>
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) { absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
// Only generate values less than the bitwidth of the data type. // Only generate values less than the bitwidth of the data type.
if (op_name == "LeftShift" || op_name == "RightShift") { if (op_name == "LeftShift" || op_name == "RightShift") {
auto max_shift = sizeof(T) * 8 - 1; auto max_shift = sizeof(T) * 8 - 1;
@ -96,16 +91,65 @@ absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
template <typename T, std::enable_if_t< template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value, llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true> bool> = true>
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) { absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1,
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0}); 0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
} }
template <typename T, template <typename T,
std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true> std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) { absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
return InputAsVector<T, bool>({true, false, true, true, false}); return InputAsVector<T, bool>({true, false, true, true, false});
} }
/// Helper functions to get more specific input data.
template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
absl::InlinedVector<std::complex<T>, 10> DefaultComplexInput() {
auto input = test::DefaultInput<T>();
absl::InlinedVector<std::complex<T>, 10> complex_input;
for (T value : input) {
complex_input.emplace_back(value, -value);
}
return complex_input;
}
template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
absl::InlinedVector<T, 10> NearZeroAndExtremeInput() {
return InputAsVector<T, double>({-std::numeric_limits<double>::infinity(),
-0.1, -0.0, 0.0, 0.1,
std::numeric_limits<float>::infinity()});
}
template <typename T,
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
bool> = true>
absl::InlinedVector<T, 10> NearZeroAndExtremeInput() {
return InputAsVector<T, T>({std::numeric_limits<T>::min(),
std::numeric_limits<T>::min() + 1, -1, 0, 1,
std::numeric_limits<T>::max()});
}
template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
absl::InlinedVector<T, 10> DefaultInputGreaterThanZero() {
return test::InputAsVector<T, double>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1,
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
}
template <typename T, std::enable_if_t<
llvm::is_one_of<T, Eigen::half, float, double>::value,
bool> = true>
absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
return test::InputAsVector<T, double>({18.0, 9.0, 1e-6, 0.0, 0.1, 1e-6, 0.1,
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
}
} // namespace test } // namespace test
} // namespace tensorflow } // namespace tensorflow

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h"
#include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
@ -44,20 +45,10 @@ class GpuUnaryOpTest : public OpsTestBase {
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu)); SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
} }
// 'T' is the input type, 'RT' is the input type for the callback function, template <typename T, typename OutT>
// 'OutT' is the output type, 'ROutT' is the output type for the callback void SetOpKernel(const std::string& op_name, const TensorShape& shape,
// function. In most cases it is enough to just provide the input type, const absl::InlinedVector<T, 10>& input, bool add_t,
// because all the types are the same. bool add_tout) {
template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT>
void Run(std::vector<int64> input_shape, absl::InlinedVector<T, 10> input,
const std::string op_name, ROutT (*expected_callback)(RT),
bool expect_equal = true, bool add_tout = false,
bool expect_buffer_reuse = true, bool add_t = true) {
assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int64>()) == input.size() &&
"Expected input length to equal to shape's number of elements.");
TensorShape shape(input_shape);
NodeDefBuilder builder("some_name", op_name); NodeDefBuilder builder("some_name", op_name);
builder.Input(FakeInput(DataTypeToEnum<T>::v())); builder.Input(FakeInput(DataTypeToEnum<T>::v()));
if (add_t) { if (add_t) {
@ -70,6 +61,15 @@ class GpuUnaryOpTest : public OpsTestBase {
TF_ASSERT_OK(InitOp()); TF_ASSERT_OK(InitOp());
AddInputFromArray<T>(shape, input); AddInputFromArray<T>(shape, input);
}
template <typename T, typename OutT>
void RunAndExpectResult(const std::string& op_name, const TensorShape& shape,
const absl::InlinedVector<T, 10>& input,
const absl::InlinedVector<OutT, 10>& expected_output,
bool add_t, bool add_tout, bool expect_buffer_reuse,
bool expect_equal) {
SetOpKernel<T, OutT>(op_name, shape, input, add_t, add_tout);
TF_ASSERT_OK(RunOpKernel()); TF_ASSERT_OK(RunOpKernel());
// Assert buffer reuse if expected. // Assert buffer reuse if expected.
@ -81,13 +81,7 @@ class GpuUnaryOpTest : public OpsTestBase {
// Assert expected results. // Assert expected results.
Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape); Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape);
absl::InlinedVector<OutT, 14> expected; test::FillValues<OutT>(&expected_tensor, expected_output);
expected.reserve(input.size());
for (const T& inp : input) {
expected.push_back(
static_cast<OutT>(expected_callback(static_cast<RT>(inp))));
}
test::FillValues<OutT>(&expected_tensor, expected);
if (expect_equal) { if (expect_equal) {
test::ExpectEqual(expected_tensor, *GetOutput(0)); test::ExpectEqual(expected_tensor, *GetOutput(0));
} else { } else {
@ -95,241 +89,225 @@ class GpuUnaryOpTest : public OpsTestBase {
} }
} }
// Some helper functions to get default input values. template <typename T, typename BaselineT, typename OutT,
typename BaselineOutT>
void Test(const std::string op_name, const TensorShape& shape,
absl::InlinedVector<T, 10> input,
BaselineOutT (*baseline_callback)(BaselineT),
bool expect_equal = true, bool add_tout = false,
bool expect_buffer_reuse = true, bool add_t = true) {
// Prepare inputs and compute expected results.
auto repeated_input =
test::RepeatInputToMatchShape(input, shape.num_elements());
absl::InlinedVector<OutT, 10> expected_output =
ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
repeated_input, baseline_callback);
std::vector<int64> DefaultInputShape() { return std::vector<int64>{2, 7}; } RunAndExpectResult<T, OutT>(op_name, shape, repeated_input, expected_output,
add_t, add_tout, expect_buffer_reuse,
template <typename T> expect_equal);
absl::InlinedVector<T, 10> DefaultInput() {
return InputAsVector<T>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3,
0.5, 0.7, 0.9, 9.0, 18.0});
}
template <typename T>
absl::InlinedVector<std::complex<T>, 10> DefaultComplexInput() {
auto input = DefaultInput<T>();
absl::InlinedVector<std::complex<T>, 10> complex_input;
for (T value : input) {
complex_input.emplace_back(value, -value);
}
return complex_input;
}
template <typename T>
absl::InlinedVector<T, 10> DefaultInputGreaterThanZero() {
return InputAsVector<T>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
0.5, 0.7, 0.9, 9.0, 18.0});
}
template <typename T>
absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
return InputAsVector<T>({18.0, 9.0, 1e-6, 0.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
0.5, 0.7, 0.9, 9.0, 18.0});
} }
private: private:
template <typename T> template <typename T, typename BaselineT, typename OutT,
absl::InlinedVector<T, 10> InputAsVector( typename BaselineOutT>
std::initializer_list<double> input) { absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
absl::InlinedVector<T, 10> result; absl::InlinedVector<T, 10> input,
result.reserve(input.size()); BaselineOutT (*baseline_callback)(BaselineT)) {
for (const auto& value : input) { absl::InlinedVector<OutT, 10> expected_output;
result.push_back(static_cast<T>(value)); for (int i = 0; i < input.size(); i++) {
auto arg = static_cast<BaselineT>(input[i]);
auto result = static_cast<OutT>(baseline_callback(arg));
expected_output.push_back(result);
} }
return result; return expected_output;
} }
}; };
/// Test `tf.Abs`. /// Test `tf.Abs`.
TEST_F(GpuUnaryOpTest, AbsFloat) { TEST_F(GpuUnaryOpTest, AbsFloat) {
Run<float>( Test<float, float, float, float>(
/*input_shape=*/{2, 3}, /*op_name=*/"Abs", test::DefaultInputShape(),
/*input=*/ test::NearZeroAndExtremeInput<float>(),
{-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f, 0.1f, /*baseline_callback=*/std::abs,
std::numeric_limits<float>::infinity()},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, AbsDouble) { TEST_F(GpuUnaryOpTest, AbsDouble) {
Run<double>( Test<double, double, double, double>(
/*input_shape=*/{2, 3}, /*op_name=*/"Abs", test::DefaultInputShape(),
/*input=*/ test::NearZeroAndExtremeInput<double>(),
{-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0, 0.1, /*baseline_callback=*/std::abs,
std::numeric_limits<double>::infinity()},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, AbsHalf) { TEST_F(GpuUnaryOpTest, AbsHalf) {
Run<Eigen::half, float>( Test<Eigen::half, float, Eigen::half, float>(
/*input_shape=*/{2, 3}, /*op_name=*/"Abs", test::DefaultInputShape(),
/*input=*/ test::NearZeroAndExtremeInput<Eigen::half>(),
{static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()), /*baseline_callback=*/std::abs,
static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
static_cast<Eigen::half>(std::numeric_limits<double>::infinity())},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, AbsInt32) { TEST_F(GpuUnaryOpTest, AbsInt32) {
Run<int32>( Test<int32, int32, int32, int32>(
/*input_shape=*/{2, 3}, /*op_name=*/"Abs", test::DefaultInputShape(),
/*input=*/ test::NearZeroAndExtremeInput<int32>(),
{std::numeric_limits<int32>::min(), std::numeric_limits<int32>::min() + 1, /*baseline_callback=*/std::abs,
-1, 0, 1, std::numeric_limits<int32>::max()},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, AbsInt64) { TEST_F(GpuUnaryOpTest, AbsInt64) {
Run<int64>( Test<int64, int64, int64, int64>(
/*input_shape=*/{2, 3}, /*op_name=*/"Abs", test::DefaultInputShape(),
/*input=*/ test::NearZeroAndExtremeInput<int64>(),
{std::numeric_limits<int64>::min(), std::numeric_limits<int64>::min() + 1, /*baseline_callback=*/std::abs,
-1, 0, 1, std::numeric_limits<int64>::max()},
/*op_name=*/"Abs",
/*expected_callback=*/std::abs,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
/// Test `tf.Ceil`. /// Test `tf.Ceil`.
TEST_F(GpuUnaryOpTest, CeilFloat) { TEST_F(GpuUnaryOpTest, CeilFloat) {
Run<float>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, float, float>(
/*op_name=*/"Ceil", /*op_name=*/"Ceil", test::DefaultInputShape(),
/*expected_callback=*/std::ceil, test::DefaultInput<float>("Ceil"),
/*expect_equal=*/true); /*baseline_callback=*/std::ceil,
/*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, CeilDouble) { TEST_F(GpuUnaryOpTest, CeilDouble) {
Run<double>(DefaultInputShape(), DefaultInput<double>(), Test<double, double, double, double>(
/*op_name=*/"Ceil", /*op_name=*/"Ceil", test::DefaultInputShape(),
/*expected_callback=*/std::ceil, test::DefaultInput<double>(),
/*expect_equal=*/true); /*baseline_callback=*/std::ceil,
/*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, CeilHalf) { TEST_F(GpuUnaryOpTest, CeilHalf) {
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), Test<Eigen::half, float, Eigen::half, float>(
/*op_name=*/"Ceil", /*op_name=*/"Ceil", test::DefaultInputShape(),
/*expected_callback=*/std::ceil, test::DefaultInput<Eigen::half>(),
/*expect_equal=*/true); /*baseline_callback=*/std::ceil,
/*expect_equal=*/true);
} }
/// Test `tf.Conj`. /// Test `tf.Conj`.
TEST_F(GpuUnaryOpTest, ConjFloat) { TEST_F(GpuUnaryOpTest, ConjFloat) {
Run<std::complex<float>, const std::complex<float>&, std::complex<float>, Test<std::complex<float>, const std::complex<float>&, std::complex<float>,
std::complex<float>>(DefaultInputShape(), DefaultComplexInput<float>(), std::complex<float>>(/*op_name=*/"Conj", test::DefaultInputShape(),
/*op_name=*/"Conj", test::DefaultComplexInput<float>(),
/*expected_callback=*/std::conj, /*baseline_callback=*/std::conj,
/*expect_equal=*/false,
/*add_tout=*/false,
/*expect_buffer_reuse=*/false);
}
TEST_F(GpuUnaryOpTest, ConjDouble) {
Run<std::complex<double>, const std::complex<double>&, std::complex<double>,
std::complex<double>>(DefaultInputShape(), DefaultComplexInput<double>(),
/*op_name=*/"Conj",
/*expected_callback=*/std::conj,
/*expect_equal=*/false, /*expect_equal=*/false,
/*add_tout=*/false, /*add_tout=*/false,
/*expect_buffer_reuse=*/false); /*expect_buffer_reuse=*/false);
} }
TEST_F(GpuUnaryOpTest, ConjDouble) {
Test<std::complex<double>, const std::complex<double>&, std::complex<double>,
std::complex<double>>(
/*op_name=*/"Conj", test::DefaultInputShape(),
test::DefaultComplexInput<double>(),
/*baseline_callback=*/std::conj,
/*expect_equal=*/false,
/*add_tout=*/false,
/*expect_buffer_reuse=*/false);
}
/// Test `tf.Cos`. /// Test `tf.Cos`.
TEST_F(GpuUnaryOpTest, CosFloat) { TEST_F(GpuUnaryOpTest, CosFloat) {
Run<float>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, float, float>(
/*op_name=*/"Cos", /*op_name=*/"Cos", test::DefaultInputShape(), test::DefaultInput<float>(),
/*expected_callback=*/std::cos, /*baseline_callback=*/std::cos,
/*expect_equal=*/false); /*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, CosDouble) { TEST_F(GpuUnaryOpTest, CosDouble) {
Run<double>(DefaultInputShape(), DefaultInput<double>(), Test<double, double, double, double>(/*op_name=*/"Cos",
/*op_name=*/"Cos", test::DefaultInputShape(),
/*expected_callback=*/std::cos, test::DefaultInput<double>(),
/*expect_equal=*/false); /*baseline_callback=*/std::cos,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, CosHalf) { TEST_F(GpuUnaryOpTest, CosHalf) {
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), Test<Eigen::half, float, Eigen::half, float>(
/*op_name=*/"Cos", /*op_name=*/"Cos", test::DefaultInputShape(),
/*expected_callback=*/std::cos, test::DefaultInput<Eigen::half>(),
/*expect_equal=*/false); /*baseline_callback=*/std::cos,
/*expect_equal=*/false);
} }
/// Test `tf.Exp`. /// Test `tf.Exp`.
TEST_F(GpuUnaryOpTest, ExpFloat) { TEST_F(GpuUnaryOpTest, ExpFloat) {
Run<float>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, float, float>(/*op_name=*/"Exp", test::DefaultInputShape(),
/*op_name=*/"Exp", test::DefaultInput<float>(),
/*expected_callback=*/std::exp, /*baseline_callback=*/std::exp,
/*expect_equal=*/false); /*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, ExpDouble) { TEST_F(GpuUnaryOpTest, ExpDouble) {
Run<double>(DefaultInputShape(), DefaultInput<double>(), Test<double, double, double, double>(/*op_name=*/"Exp",
/*op_name=*/"Exp", test::DefaultInputShape(),
/*expected_callback=*/std::exp, test::DefaultInput<double>(),
/*expect_equal=*/false); /*baseline_callback=*/std::exp,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, ExpHalf) { TEST_F(GpuUnaryOpTest, ExpHalf) {
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), Test<Eigen::half, float, Eigen::half, float>(
/*op_name=*/"Exp", /*op_name=*/"Exp", test::DefaultInputShape(),
/*expected_callback=*/std::exp, test::DefaultInput<Eigen::half>(),
/*expect_equal=*/false); /*baseline_callback=*/std::exp,
/*expect_equal=*/false);
} }
/// Test `tf.Floor`. /// Test `tf.Floor`.
TEST_F(GpuUnaryOpTest, FloorFloat) { TEST_F(GpuUnaryOpTest, FloorFloat) {
Run<float>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, float, float>(/*op_name=*/"Floor",
/*op_name=*/"Floor", test::DefaultInputShape(),
/*expected_callback=*/std::floor, test::DefaultInput<float>(),
/*expect_equal=*/true); /*baseline_callback=*/std::floor,
/*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, FloorDouble) { TEST_F(GpuUnaryOpTest, FloorDouble) {
Run<double>(DefaultInputShape(), DefaultInput<double>(), Test<double, double, double, double>(/*op_name=*/"Floor",
/*op_name=*/"Floor", test::DefaultInputShape(),
/*expected_callback=*/std::floor, test::DefaultInput<double>(),
/*expect_equal=*/true); /*baseline_callback=*/std::floor,
/*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, FloorHalf) { TEST_F(GpuUnaryOpTest, FloorHalf) {
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), Test<Eigen::half, float, Eigen::half, float>(
/*op_name=*/"Floor", /*op_name=*/"Floor", test::DefaultInputShape(),
/*expected_callback=*/std::floor, test::DefaultInput<Eigen::half>(),
/*expect_equal=*/true); /*baseline_callback=*/std::floor,
/*expect_equal=*/true);
} }
/// Test `tf.Imag`. /// Test `tf.Imag`.
TEST_F(GpuUnaryOpTest, ImagFloat) { TEST_F(GpuUnaryOpTest, ImagFloat) {
Run<std::complex<float>, const std::complex<float>&, float, float>( Test<std::complex<float>, const std::complex<float>&, float, float>(
DefaultInputShape(), DefaultComplexInput<float>(), /*op_name=*/"Imag", test::DefaultInputShape(),
/*op_name=*/"Imag", test::DefaultComplexInput<float>(),
/*expected_callback=*/std::imag, /*baseline_callback=*/std::imag,
/*expect_equal=*/false, /*expect_equal=*/false,
/*add_tout=*/true, /*add_tout=*/true,
/*expect_buffer_reuse=*/false); /*expect_buffer_reuse=*/false);
} }
TEST_F(GpuUnaryOpTest, ImagDouble) { TEST_F(GpuUnaryOpTest, ImagDouble) {
Run<std::complex<double>, const std::complex<double>&, double, double>( Test<std::complex<double>, const std::complex<double>&, double, double>(
DefaultInputShape(), DefaultComplexInput<double>(), /*op_name=*/"Imag", test::DefaultInputShape(),
/*op_name=*/"Imag", test::DefaultComplexInput<double>(),
/*expected_callback=*/std::imag, /*baseline_callback=*/std::imag,
/*expect_equal=*/false, /*expect_equal=*/false,
/*add_tout=*/true, /*add_tout=*/true,
/*expect_buffer_reuse=*/false); /*expect_buffer_reuse=*/false);
@ -338,64 +316,65 @@ TEST_F(GpuUnaryOpTest, ImagDouble) {
/// Test `tf.IsInf`. /// Test `tf.IsInf`.
// TODO(b/162575339): The tests currently still fails with CUDA_ILLEGAL_ADDRESS // TODO(b/162575339): The tests currently still fails with CUDA_ILLEGAL_ADDRESS
// when run with unranked kernels. // when Test with unranked kernels.
TEST_F(GpuUnaryOpTest, DISABLED_IsInfFloat) { TEST_F(GpuUnaryOpTest, DISABLED_IsInfFloat) {
Run<float, float, bool, bool>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, bool, bool>(/*op_name=*/"IsInf", test::DefaultInputShape(),
/*op_name=*/"IsInf", test::DefaultInput<float>(),
/*expected_callback=*/std::isinf, /*baseline_callback=*/std::isinf,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, DISABLED_IsInfDouble) { TEST_F(GpuUnaryOpTest, DISABLED_IsInfDouble) {
// Workaround for gcc bug, it would fail with "unresolved overloaded function // Workaround for gcc bug, it would fail with "unresolved overloaded function
// type" if passing std::isinf with type double. So we use type float for // type" if passing std::isinf with type double. So we use type float for
// comparing expected values. // comparing expected values.
Run<double, float, bool, bool>(DefaultInputShape(), DefaultInput<double>(), Test<double, float, bool, bool>(/*op_name=*/"IsInf",
/*op_name=*/"IsInf", test::DefaultInputShape(),
/*expected_callback=*/std::isinf, test::DefaultInput<double>(),
/*expect_equal=*/true); /*baseline_callback=*/std::isinf,
/*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) { TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) {
Run<Eigen::half, float, bool, bool>(DefaultInputShape(), Test<Eigen::half, float, bool, bool>(/*op_name=*/"IsInf",
DefaultInput<Eigen::half>(), test::DefaultInputShape(),
/*op_name=*/"IsInf", test::DefaultInput<Eigen::half>(),
/*expected_callback=*/std::isinf, /*baseline_callback=*/std::isinf,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
/// Test `tf.Log`. /// Test `tf.Log`.
TEST_F(GpuUnaryOpTest, LogFloat) { TEST_F(GpuUnaryOpTest, LogFloat) {
Run<float>(DefaultInputShape(), DefaultInputGreaterThanZero<float>(), Test<float, float, float, float>(/*op_name=*/"Log", test::DefaultInputShape(),
/*op_name=*/"Log", test::DefaultInputGreaterThanZero<float>(),
/*expected_callback=*/std::log, /*baseline_callback=*/std::log,
/*expect_equal=*/false); /*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, LogDouble) { TEST_F(GpuUnaryOpTest, LogDouble) {
Run<double>(DefaultInputShape(), DefaultInputGreaterThanZero<double>(), Test<double, double, double, double>(
/*op_name=*/"Log", /*op_name=*/"Log", test::DefaultInputShape(),
/*expected_callback=*/std::log, test::DefaultInputGreaterThanZero<double>(),
/*expect_equal=*/false); /*baseline_callback=*/std::log,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, LogHalf) { TEST_F(GpuUnaryOpTest, LogHalf) {
Run<Eigen::half, float>(DefaultInputShape(), Test<Eigen::half, float, Eigen::half, float>(
/*input=*/ /*op_name=*/"Log", test::DefaultInputShape(),
DefaultInputGreaterThanZero<Eigen::half>(), test::DefaultInputGreaterThanZero<Eigen::half>(),
/*op_name=*/"Log", /*baseline_callback=*/std::log,
/*expected_callback=*/std::log, /*expect_equal=*/false);
/*expect_equal=*/false);
} }
/// Test `tf.LogicalNot` /// Test `tf.LogicalNot`
TEST_F(GpuUnaryOpTest, LogicalNot) { TEST_F(GpuUnaryOpTest, LogicalNot) {
Run<bool, bool, bool, bool>( Test<bool, bool, bool, bool>(
DefaultInputShape(), DefaultInput<bool>(), /*op_name=*/"LogicalNot", test::DefaultInputShape(),
/*op_name=*/"LogicalNot", test::DefaultInput<bool>(),
/*expected_callback=*/[](bool v) { return !v; }, /*baseline_callback=*/[](bool v) { return !v; },
/*expect_equal=*/true, /*expect_equal=*/true,
/*add_tout=*/false, /*add_tout=*/false,
/*expect_buffer_reuse=*/true, /*expect_buffer_reuse=*/true,
@ -406,69 +385,71 @@ TEST_F(GpuUnaryOpTest, LogicalNot) {
/// Reference implementation. /// Reference implementation.
template <typename T> template <typename T>
T expected_neg(T x) { T baseline_neg(T x) {
return -x; return -x;
} }
TEST_F(GpuUnaryOpTest, NegFloat) { TEST_F(GpuUnaryOpTest, NegFloat) {
Run<float>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, float, float>(
/*op_name=*/"Neg", /*op_name=*/"Neg", test::DefaultInputShape(), test::DefaultInput<float>(),
/*expected_callback=*/expected_neg, /*baseline_callback=*/baseline_neg,
/*expect_equal=*/false); /*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, NegDouble) { TEST_F(GpuUnaryOpTest, NegDouble) {
Run<double>(DefaultInputShape(), DefaultInput<double>(), Test<double, double, double, double>(
/*op_name=*/"Neg", /*op_name=*/"Neg", test::DefaultInputShape(),
/*expected_callback=*/expected_neg, test::DefaultInput<double>(),
/*expect_equal=*/false); /*baseline_callback=*/baseline_neg,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, NegHalf) { TEST_F(GpuUnaryOpTest, NegHalf) {
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), Test<Eigen::half, float, Eigen::half, float>(
/*op_name=*/"Neg", /*op_name=*/"Neg", test::DefaultInputShape(),
/*expected_callback=*/expected_neg, test::DefaultInput<Eigen::half>(),
/*expect_equal=*/false); /*baseline_callback=*/baseline_neg,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, NegInt8) { TEST_F(GpuUnaryOpTest, NegInt8) {
Run<int8>(DefaultInputShape(), DefaultInput<int8>(), Test<int8, int8, int8, int8>(
/*op_name=*/"Neg", /*op_name=*/"Neg", test::DefaultInputShape(), test::DefaultInput<int8>(),
/*expected_callback=*/expected_neg, /*baseline_callback=*/baseline_neg,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, NegInt16) { TEST_F(GpuUnaryOpTest, NegInt16) {
Run<int16>(DefaultInputShape(), DefaultInput<int16>(), Test<int16, int16, int16, int16>(/*op_name=*/"Neg", test::DefaultInputShape(),
/*op_name=*/"Neg", test::DefaultInput<int16>(),
/*expected_callback=*/expected_neg, /*baseline_callback=*/baseline_neg,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, NegInt64) { TEST_F(GpuUnaryOpTest, NegInt64) {
Run<int64>(DefaultInputShape(), DefaultInput<int64>(), Test<int64, int64, int64, int64>(/*op_name=*/"Neg", test::DefaultInputShape(),
/*op_name=*/"Neg", test::DefaultInput<int64>(),
/*expected_callback=*/expected_neg, /*baseline_callback=*/baseline_neg,
/*expect_equal=*/true); /*expect_equal=*/true);
} }
/// Test `tf.Real`. /// Test `tf.Real`.
TEST_F(GpuUnaryOpTest, RealFloat) { TEST_F(GpuUnaryOpTest, RealFloat) {
Run<std::complex<float>, const std::complex<float>&, float, float>( Test<std::complex<float>, const std::complex<float>&, float, float>(
DefaultInputShape(), DefaultComplexInput<float>(), /*op_name=*/"Real", test::DefaultInputShape(),
/*op_name=*/"Real", test::DefaultComplexInput<float>(),
/*expected_callback=*/std::real, /*baseline_callback=*/std::real,
/*expect_equal=*/false, /*expect_equal=*/false,
/*add_tout=*/true, /*add_tout=*/true,
/*expect_buffer_reuse=*/false); /*expect_buffer_reuse=*/false);
} }
TEST_F(GpuUnaryOpTest, RealDouble) { TEST_F(GpuUnaryOpTest, RealDouble) {
Run<std::complex<double>, const std::complex<double>&, double, double>( Test<std::complex<double>, const std::complex<double>&, double, double>(
DefaultInputShape(), DefaultComplexInput<double>(), /*op_name=*/"Real", test::DefaultInputShape(),
/*op_name=*/"Real", test::DefaultComplexInput<double>(),
/*expected_callback=*/std::real, /*baseline_callback=*/std::real,
/*expect_equal=*/false, /*expect_equal=*/false,
/*add_tout=*/true, /*add_tout=*/true,
/*expect_buffer_reuse=*/false); /*expect_buffer_reuse=*/false);
@ -478,141 +459,153 @@ TEST_F(GpuUnaryOpTest, RealDouble) {
/// Reference implementation. /// Reference implementation.
template <typename T> template <typename T>
T expected_rsqrt(T x) { T baseline_rsqrt(T x) {
return 1.0 / std::sqrt(x); return 1.0 / std::sqrt(x);
} }
TEST_F(GpuUnaryOpTest, RsqrtFloat) { TEST_F(GpuUnaryOpTest, RsqrtFloat) {
Run<float>(DefaultInputShape(), DefaultInputGreaterThanZero<float>(), Test<float, float, float, float>(/*op_name=*/"Rsqrt",
/*op_name=*/"Rsqrt", test::DefaultInputShape(),
/*expected_callback=*/expected_rsqrt, test::DefaultInputGreaterThanZero<float>(),
/*expect_equal=*/false); /*baseline_callback=*/baseline_rsqrt,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, RsqrtDouble) { TEST_F(GpuUnaryOpTest, RsqrtDouble) {
Run<double>(DefaultInputShape(), DefaultInputGreaterThanZero<double>(), Test<double, double, double, double>(
/*op_name=*/"Rsqrt", /*op_name=*/"Rsqrt", test::DefaultInputShape(),
/*expected_callback=*/expected_rsqrt, test::DefaultInputGreaterThanZero<double>(),
/*expect_equal=*/false); /*baseline_callback=*/baseline_rsqrt,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, RsqrtHalf) { TEST_F(GpuUnaryOpTest, RsqrtHalf) {
Run<Eigen::half, float>(DefaultInputShape(), Test<Eigen::half, float, Eigen::half, float>(
/*input=*/ /*op_name=*/"Rsqrt", test::DefaultInputShape(),
DefaultInputGreaterThanZero<Eigen::half>(), test::DefaultInputGreaterThanZero<Eigen::half>(),
/*op_name=*/"Rsqrt", /*baseline_callback=*/baseline_rsqrt,
/*expected_callback=*/expected_rsqrt, /*expect_equal=*/false);
/*expect_equal=*/false);
} }
/// Test `tf.Sign`. /// Test `tf.Sign`.
// Reference implementation // Reference implementation
template <typename T> template <typename T>
T expected_sign(T x) { T baseline_sign(T x) {
if (x == 0) return 0; if (x == 0) return 0;
if (x < 0) return -1; if (x < 0) return -1;
return 1; return 1;
} }
TEST_F(GpuUnaryOpTest, SignFloat) { TEST_F(GpuUnaryOpTest, SignFloat) {
Run<float>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, float, float>(/*op_name=*/"Sign",
/*op_name=*/"Sign", test::DefaultInputShape(),
/*expected_callback=*/expected_sign, test::DefaultInput<float>(),
/*expect_equal=*/true); /*baseline_callback=*/baseline_sign,
/*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, SignDouble) { TEST_F(GpuUnaryOpTest, SignDouble) {
Run<double>(DefaultInputShape(), DefaultInput<double>(), Test<double, double, double, double>(/*op_name=*/"Sign",
/*op_name=*/"Sign", test::DefaultInputShape(),
/*expected_callback=*/expected_sign, test::DefaultInput<double>(),
/*expect_equal=*/true); /*baseline_callback=*/baseline_sign,
/*expect_equal=*/true);
} }
TEST_F(GpuUnaryOpTest, SignHalf) { TEST_F(GpuUnaryOpTest, SignHalf) {
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), Test<Eigen::half, float, Eigen::half, float>(
/*op_name=*/"Sign", /*op_name=*/"Sign", test::DefaultInputShape(),
/*expected_callback=*/expected_sign, test::DefaultInput<Eigen::half>(),
// TODO(b/162577610): We should actually use true /*expected_callback=*/baseline_sign,
// here. This requires returning 0.0 for input -0.0. // TODO(b/162577610): We should actually use true
/*expect_equal=*/false); // here. This requires returning 0.0 for input -0.0.
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, SignInt64) { TEST_F(GpuUnaryOpTest, SignInt64) {
Run<int64>(DefaultInputShape(), DefaultInput<int64>(), Test<int64, int64, int64, int64>(
/*op_name=*/"Sign", /*op_name=*/"Sign", test::DefaultInputShape(),
/*expected_callback=*/expected_sign, test::DefaultInput<int64>(),
/*expect_equal=*/true); /*expected_callback=*/baseline_sign,
/*expect_equal=*/true);
} }
/// Test `tf.Sin`. /// Test `tf.Sin`.
TEST_F(GpuUnaryOpTest, SinFloat) { TEST_F(GpuUnaryOpTest, SinFloat) {
Run<float>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, float, float>(/*op_name=*/"Sin", test::DefaultInputShape(),
/*op_name=*/"Sin", test::DefaultInput<float>(),
/*expected_callback=*/std::sin, /*baseline_callback=*/std::sin,
/*expect_equal=*/false); /*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, SinDouble) { TEST_F(GpuUnaryOpTest, SinDouble) {
Run<double>(DefaultInputShape(), DefaultInput<double>(), Test<double, double, double, double>(/*op_name=*/"Sin",
/*op_name=*/"Sin", test::DefaultInputShape(),
/*expected_callback=*/std::sin, test::DefaultInput<double>(),
/*expect_equal=*/false); /*baseline_callback=*/std::sin,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, SinHalf) { TEST_F(GpuUnaryOpTest, SinHalf) {
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), Test<Eigen::half, float, Eigen::half, float>(
/*op_name=*/"Sin", /*op_name=*/"Sin", test::DefaultInputShape(),
/*expected_callback=*/std::sin, test::DefaultInput<Eigen::half>(),
/*expect_equal=*/false); /*baseline_callback=*/std::sin,
/*expect_equal=*/false);
} }
/// Test `tf.Sqrt`. /// Test `tf.Sqrt`.
TEST_F(GpuUnaryOpTest, SqrtFloat) { TEST_F(GpuUnaryOpTest, SqrtFloat) {
Run<float>(DefaultInputShape(), DefaultInputGreaterOrEqualToZero<float>(), Test<float, float, float, float>(
/*op_name=*/"Sqrt", /*op_name=*/"Sqrt", test::DefaultInputShape(),
/*expected_callback=*/std::sqrt, test::DefaultInputGreaterOrEqualToZero<float>(),
/*expect_equal=*/false); /*baseline_callback=*/std::sqrt,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, SqrtDouble) { TEST_F(GpuUnaryOpTest, SqrtDouble) {
Run<double>(DefaultInputShape(), DefaultInputGreaterOrEqualToZero<double>(), Test<double, double, double, double>(
/*op_name=*/"Sqrt", /*op_name=*/"Sqrt", test::DefaultInputShape(),
/*expected_callback=*/std::sqrt, test::DefaultInputGreaterOrEqualToZero<double>(),
/*expect_equal=*/false); /*baseline_callback=*/std::sqrt,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, SqrtHalf) { TEST_F(GpuUnaryOpTest, SqrtHalf) {
Run<Eigen::half, float>(DefaultInputShape(), Test<Eigen::half, float, Eigen::half, float>(
DefaultInputGreaterOrEqualToZero<Eigen::half>(), /*op_name=*/"Sqrt", test::DefaultInputShape(),
/*op_name=*/"Sqrt", test::DefaultInputGreaterOrEqualToZero<Eigen::half>(),
/*expected_callback=*/std::sqrt, /*baseline_callback=*/std::sqrt,
/*expect_equal=*/false); /*expect_equal=*/false);
} }
/// Test `tf.Tanh`. /// Test `tf.Tanh`.
TEST_F(GpuUnaryOpTest, TanhFloat) { TEST_F(GpuUnaryOpTest, TanhFloat) {
Run<float>(DefaultInputShape(), DefaultInput<float>(), Test<float, float, float, float>(/*op_name=*/"Tanh",
/*op_name=*/"Tanh", test::DefaultInputShape(),
/*expected_callback=*/std::tanh, test::DefaultInput<float>(),
/*expect_equal=*/false); /*baseline_callback=*/std::tanh,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, TanhDouble) { TEST_F(GpuUnaryOpTest, TanhDouble) {
Run<double>(DefaultInputShape(), DefaultInput<double>(), Test<double, double, double, double>(/*op_name=*/"Tanh",
/*op_name=*/"Tanh", test::DefaultInputShape(),
/*expected_callback=*/std::tanh, test::DefaultInput<double>(),
/*expect_equal=*/false); /*baseline_callback=*/std::tanh,
/*expect_equal=*/false);
} }
TEST_F(GpuUnaryOpTest, TanhHalf) { TEST_F(GpuUnaryOpTest, TanhHalf) {
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(), Test<Eigen::half, float, Eigen::half, float>(
/*op_name=*/"Tanh", /*op_name=*/"Tanh", test::DefaultInputShape(),
/*expected_callback=*/std::tanh, test::DefaultInput<Eigen::half>(),
/*expect_equal=*/false); /*baseline_callback=*/std::tanh,
/*expect_equal=*/false);
} }
} // namespace } // namespace