[KERNEL_GEN][CPU] Generalize unary tests to support CPU platform.
PiperOrigin-RevId: 355356652 Change-Id: I5414bdaadc0634211b1f8199db902c264deb3cce
This commit is contained in:
parent
b137fe2ad5
commit
0ee7258c0d
@ -254,6 +254,42 @@ tf_kernel_library(
|
||||
),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "base_ops_test",
|
||||
testonly = 1,
|
||||
srcs = ["base_ops_test.cc"],
|
||||
hdrs = ["base_ops_test.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "base_unary_ops_test",
|
||||
testonly = 1,
|
||||
hdrs = ["base_unary_ops_test.h"],
|
||||
deps = [
|
||||
":base_ops_test",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/framework:types_proto_cc",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "gpu_unary_ops_test",
|
||||
size = "small",
|
||||
@ -262,18 +298,10 @@ tf_cuda_cc_test(
|
||||
"no_cuda_asan", # TODO(b/171341759): re-enable.
|
||||
],
|
||||
deps = [
|
||||
":gpu_ops_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
":base_ops_test",
|
||||
":base_unary_ops_test",
|
||||
"//tensorflow/core/common_runtime:device",
|
||||
"//tensorflow/core/common_runtime:device_factory",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
],
|
||||
)
|
||||
|
||||
@ -285,7 +313,7 @@ tf_cuda_cc_test(
|
||||
"no_cuda_asan", # b/173033461
|
||||
],
|
||||
deps = [
|
||||
":gpu_ops_test_util",
|
||||
":base_ops_test",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:tensorflow",
|
||||
@ -304,25 +332,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_ops_test_util",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"gpu_ops_test_util.cc",
|
||||
"gpu_ops_test_util.h",
|
||||
],
|
||||
hdrs = [
|
||||
"gpu_ops_test_util.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(b/160731748): Re-enable when it works again.
|
||||
# gen_kernel_library(
|
||||
# name = "bias_add",
|
||||
|
||||
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/base_ops_test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace test {
|
||||
@ -13,10 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
|
||||
|
||||
#include <iostream>
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OPS_TEST_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OPS_TEST_H_
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
@ -56,29 +54,29 @@ TensorShape DefaultInputShape();
|
||||
|
||||
/// Helper functions to configure tests.
|
||||
|
||||
struct GpuOpsTestConfig {
|
||||
struct OpsTestConfig {
|
||||
bool add_t = true;
|
||||
bool add_tout = false;
|
||||
// Only used for gpu_unary_ops_test.
|
||||
bool expect_buffer_reuse = true;
|
||||
bool expect_strictly_equal = false;
|
||||
GpuOpsTestConfig ExpectStrictlyEqual() {
|
||||
GpuOpsTestConfig config = *this;
|
||||
OpsTestConfig ExpectStrictlyEqual() {
|
||||
OpsTestConfig config = *this;
|
||||
config.expect_strictly_equal = true;
|
||||
return config;
|
||||
}
|
||||
GpuOpsTestConfig NoBufferReuse() {
|
||||
GpuOpsTestConfig config = *this;
|
||||
OpsTestConfig NoBufferReuse() {
|
||||
OpsTestConfig config = *this;
|
||||
config.expect_buffer_reuse = false;
|
||||
return config;
|
||||
}
|
||||
GpuOpsTestConfig AddTout() {
|
||||
GpuOpsTestConfig config = *this;
|
||||
OpsTestConfig AddTout() {
|
||||
OpsTestConfig config = *this;
|
||||
config.add_tout = true;
|
||||
return config;
|
||||
}
|
||||
GpuOpsTestConfig NoT() {
|
||||
GpuOpsTestConfig config = *this;
|
||||
OpsTestConfig NoT() {
|
||||
OpsTestConfig config = *this;
|
||||
config.add_t = false;
|
||||
return config;
|
||||
}
|
||||
@ -231,4 +229,4 @@ absl::InlinedVector<T, 10> DefaultInput() {
|
||||
} // namespace test
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_GPU_OPS_TEST_UTIL_H_
|
||||
#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OPS_TEST_H_
|
||||
167
tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h
Normal file
167
tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h
Normal file
@ -0,0 +1,167 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_UNARY_OPS_TEST_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_UNARY_OPS_TEST_H_
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/base_ops_test.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Base class for `UnaryOpsTest` fixture that has to be defined with a custom TF
|
||||
// device if you want to use the test macros in this file.
|
||||
class UnaryOpsTestBase : public OpsTestBase {
|
||||
protected:
|
||||
// This method should set the TF device, e.g. DEVICE_CPU, DEVICE_GPU.
|
||||
void SetUp() override = 0;
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void SetOpKernel(const std::string& op_name, const TensorShape& shape,
|
||||
const absl::InlinedVector<T, 10>& input, bool add_t,
|
||||
bool add_tout) {
|
||||
NodeDefBuilder builder("some_name", op_name);
|
||||
builder.Input(FakeInput(DataTypeToEnum<T>::v()));
|
||||
if (add_t) {
|
||||
builder.Attr("T", DataTypeToEnum<T>::v());
|
||||
}
|
||||
if (add_tout) {
|
||||
builder.Attr("Tout", DataTypeToEnum<OutT>::v());
|
||||
}
|
||||
TF_ASSERT_OK(builder.Finalize(node_def()));
|
||||
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<T>(shape, input);
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void RunAndExpectResult(const std::string& op_name, const TensorShape& shape,
|
||||
const absl::InlinedVector<T, 10>& input,
|
||||
const absl::InlinedVector<OutT, 10>& expected_output,
|
||||
const test::OpsTestConfig& config) {
|
||||
SetOpKernel<T, OutT>(op_name, shape, input, config.add_t, config.add_tout);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Assert buffer reuse if expected.
|
||||
if (config.expect_buffer_reuse) {
|
||||
void* arg_ptr_on_device = context_->input(0).data();
|
||||
void* result_ptr_on_device = context_->mutable_output(0)->data();
|
||||
ASSERT_EQ(arg_ptr_on_device, result_ptr_on_device);
|
||||
}
|
||||
|
||||
// Assert expected results.
|
||||
Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape);
|
||||
test::FillValues<OutT>(&expected_tensor, expected_output);
|
||||
if (config.expect_strictly_equal) {
|
||||
test::ExpectEqual(expected_tensor, *GetOutput(0));
|
||||
} else {
|
||||
test::ExpectClose(expected_tensor, *GetOutput(0), kAbsoluteTolerance,
|
||||
kRelativeTolerance);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
typename BaselineOutT>
|
||||
void Test(const std::string& op_name, const TensorShape& shape,
|
||||
const absl::InlinedVector<T, 10>& input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT),
|
||||
const test::OpsTestConfig& config) {
|
||||
// Prepare inputs and compute expected results.
|
||||
CHECK(input.size() <= shape.num_elements());
|
||||
auto repeated_input =
|
||||
test::RepeatInputToMatchShape(input, shape.num_elements());
|
||||
absl::InlinedVector<OutT, 10> expected_output =
|
||||
ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
|
||||
repeated_input, baseline_callback);
|
||||
|
||||
RunAndExpectResult<T, OutT>(op_name, shape, repeated_input, expected_output,
|
||||
config);
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void TestEmptyShape(const std::string& op_name,
|
||||
const test::OpsTestConfig& config) {
|
||||
TensorShape shape{0, 1, 2};
|
||||
absl::InlinedVector<T, 10> empty_input = {};
|
||||
absl::InlinedVector<OutT, 10> expected_output = {};
|
||||
RunAndExpectResult<T, OutT>(op_name, shape, empty_input, expected_output,
|
||||
config);
|
||||
}
|
||||
|
||||
private:
|
||||
constexpr static double kAbsoluteTolerance = 0.001;
|
||||
constexpr static double kRelativeTolerance = 0.001;
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
typename BaselineOutT>
|
||||
absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
|
||||
absl::InlinedVector<T, 10> input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT)) {
|
||||
absl::InlinedVector<OutT, 10> expected_output;
|
||||
for (int i = 0; i < input.size(); i++) {
|
||||
auto arg = static_cast<BaselineT>(input[i]);
|
||||
auto result = static_cast<OutT>(baseline_callback(arg));
|
||||
expected_output.push_back(result);
|
||||
}
|
||||
return expected_output;
|
||||
}
|
||||
};
|
||||
|
||||
// Macros to easily generate common test cases. The macros use `UnaryOpsTest`
|
||||
// fixture in order to share implementation across GPU and CPU platform tests.
|
||||
// For specific inputs, please define your own test fixtures.
|
||||
#define GENERATE_DEFAULT_TEST(op_name, InT, OutT, baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_2(op_name, InT, InT, OutT, OutT, baseline_callback, \
|
||||
config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_2(op_name, InT, BaselineT, OutT, BaselineOutT, \
|
||||
baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, BaselineT, OutT, BaselineOutT, \
|
||||
test::DefaultInput<NativeT>(), baseline_callback, config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( \
|
||||
op_name, InT, OutT, input_values, baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, InT, OutT, OutT, input_values, baseline_callback, config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, BaselineT, OutT, BaselineOutT, input_values, \
|
||||
baseline_callback, config) \
|
||||
TEST_F(UnaryOpsTest, op_name##InT) { \
|
||||
using NativeT = EnumToDataType<InT>::Type; \
|
||||
using NativeBaselineT = EnumToDataType<BaselineT>::Type; \
|
||||
using NativeOutT = EnumToDataType<OutT>::Type; \
|
||||
using NativeBaselineOutT = EnumToDataType<BaselineOutT>::Type; \
|
||||
Test<NativeT, NativeBaselineT, NativeOutT, NativeBaselineOutT>( \
|
||||
#op_name, test::DefaultInputShape(), input_values, baseline_callback, \
|
||||
config); \
|
||||
} \
|
||||
TEST_F(UnaryOpsTest, op_name##InT##EmptyShape) { \
|
||||
using NativeT = EnumToDataType<InT>::Type; \
|
||||
using NativeOutT = EnumToDataType<OutT>::Type; \
|
||||
TestEmptyShape<NativeT, NativeOutT>(#op_name, config); \
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_UNARY_OPS_TEST_H_
|
||||
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/base_ops_test.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -76,7 +76,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
const TensorShape& expected_shape,
|
||||
const absl::InlinedVector<OutT, 10>& expected_output,
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
|
||||
config.add_t, config.add_tout);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
@ -98,7 +98,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const TensorShape& rhs_shape,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
SetOpKernel<T, OutT>(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input,
|
||||
config.add_t, config.add_tout);
|
||||
auto status = RunOpKernel();
|
||||
@ -112,7 +112,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
void TestIncompatibleShapes(const std::string& op_name,
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
// Prepare incompatibly shaped inputs.
|
||||
TensorShape lhs_shape{3};
|
||||
TensorShape rhs_shape{2};
|
||||
@ -131,7 +131,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
int input_size = shape.num_elements();
|
||||
CHECK(lhs_input.size() <= input_size && rhs_input.size() <= input_size &&
|
||||
@ -164,7 +164,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const TensorShape& other_shape,
|
||||
const absl::InlinedVector<T, 10>& other_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape scalar_shape{};
|
||||
CHECK(other_input.size() <= other_shape.num_elements() &&
|
||||
@ -197,7 +197,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT,
|
||||
BaselineT),
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape lhs_shape{1};
|
||||
TensorShape rhs_shape{6};
|
||||
@ -226,7 +226,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT,
|
||||
BaselineT),
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape lhs_shape{3};
|
||||
TensorShape rhs_shape{2, 3};
|
||||
@ -254,7 +254,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT, BaselineT),
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape lhs_shape{2, 1};
|
||||
TensorShape rhs_shape{3};
|
||||
@ -282,7 +282,7 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
void TestEmptyShapeBroadcasting(const std::string& op_name,
|
||||
const absl::InlinedVector<T, 10>& lhs_input,
|
||||
const absl::InlinedVector<T, 10>& rhs_input,
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
const test::OpsTestConfig& config) {
|
||||
// Prepare inputs.
|
||||
TensorShape lhs_shape{2, 0, 1};
|
||||
TensorShape rhs_shape{2, 0, 5};
|
||||
@ -362,13 +362,13 @@ class GpuBinaryOpTest : public OpsTestBase {
|
||||
GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, \
|
||||
test::DefaultInput<T>(), test::DefaultInput<T>(), \
|
||||
baseline_callback, \
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
#define GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( \
|
||||
op_name, test_name, T, OutT, lhs_input, rhs_input, baseline_callback) \
|
||||
GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, lhs_input, \
|
||||
rhs_input, baseline_callback, \
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.AddV2`.
|
||||
|
||||
@ -415,14 +415,14 @@ TEST_F(GpuBinaryOpTest, Atan2FloatSpecialCases) {
|
||||
"Atan2", /*shape=*/{20},
|
||||
test::InputAsVector<float>({1, 1, 1, 0, -1, -1, -1, 0}),
|
||||
test::InputAsVector<float>({1, 0, -1, -1, -1, 0, 1, 1}), std::atan2,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual());
|
||||
}
|
||||
TEST_F(GpuBinaryOpTest, Atan2DoubleSpecialCases) {
|
||||
TestEqualShapes<double, double, double, double>(
|
||||
"Atan2", /*shape=*/{20},
|
||||
test::InputAsVector<double>({1, 1, 1, 0, -1, -1, -1, 0}),
|
||||
test::InputAsVector<double>({1, 0, -1, -1, -1, 0, 1, 1}), std::atan2,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual());
|
||||
}
|
||||
|
||||
/// Test `tf.BitwiseAnd`.
|
||||
@ -480,17 +480,17 @@ std::complex<T> baseline_complex(T lhs, T rhs) {
|
||||
return std::complex<T>(lhs, rhs);
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TESTS_2(
|
||||
Complex,
|
||||
/*test_name=*/C64, float, float, std::complex<float>, std::complex<float>,
|
||||
test::DefaultInput<float>(), test::DefaultInput<float>(), baseline_complex,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().AddTout())
|
||||
GENERATE_DEFAULT_TESTS_2(
|
||||
Complex,
|
||||
/*test_name=*/C128, double, double, std::complex<double>,
|
||||
std::complex<double>, test::DefaultInput<double>(),
|
||||
test::DefaultInput<double>(), baseline_complex,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().AddTout())
|
||||
GENERATE_DEFAULT_TESTS_2(Complex,
|
||||
/*test_name=*/C64, float, float, std::complex<float>,
|
||||
std::complex<float>, test::DefaultInput<float>(),
|
||||
test::DefaultInput<float>(), baseline_complex,
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().AddTout())
|
||||
GENERATE_DEFAULT_TESTS_2(Complex,
|
||||
/*test_name=*/C128, double, double,
|
||||
std::complex<double>, std::complex<double>,
|
||||
test::DefaultInput<double>(),
|
||||
test::DefaultInput<double>(), baseline_complex,
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().AddTout())
|
||||
|
||||
/// Test `tf.Div`.
|
||||
|
||||
@ -667,7 +667,7 @@ GENERATE_DEFAULT_TESTS_2(LogicalAnd, /*test_name=*/Bool, /*T=*/bool,
|
||||
/*BaselineT=*/bool, /*OutT=*/bool,
|
||||
/*BaselineOutT=*/bool, test::DefaultInput<bool>(),
|
||||
test::DefaultInput<bool>(), baseline_logical_and,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
|
||||
/// Test `tf.LogicalOr`.
|
||||
|
||||
@ -677,7 +677,7 @@ GENERATE_DEFAULT_TESTS_2(LogicalOr, /*test_name=*/Bool, /*T=*/bool,
|
||||
/*BaselineT=*/bool, /*OutT=*/bool,
|
||||
/*BaselineOutT=*/bool, test::DefaultInput<bool>(),
|
||||
test::DefaultInput<bool>(), baseline_logical_or,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
|
||||
/// Test `tf.Maximum`.
|
||||
|
||||
|
||||
@ -13,31 +13,17 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/base_ops_test.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class GpuUnaryOpTest : public OpsTestBase {
|
||||
// Test fixture `UnaryOpsTest` that sets the TF device is expected by the TEST
|
||||
// macros below.
|
||||
class UnaryOpsTest : public UnaryOpsTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::unique_ptr<tensorflow::Device> device_gpu(
|
||||
@ -45,152 +31,26 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
"/job:a/replica:0/task:0"));
|
||||
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void SetOpKernel(const std::string& op_name, const TensorShape& shape,
|
||||
const absl::InlinedVector<T, 10>& input, bool add_t,
|
||||
bool add_tout) {
|
||||
NodeDefBuilder builder("some_name", op_name);
|
||||
builder.Input(FakeInput(DataTypeToEnum<T>::v()));
|
||||
if (add_t) {
|
||||
builder.Attr("T", DataTypeToEnum<T>::v());
|
||||
}
|
||||
if (add_tout) {
|
||||
builder.Attr("Tout", DataTypeToEnum<OutT>::v());
|
||||
}
|
||||
TF_ASSERT_OK(builder.Finalize(node_def()));
|
||||
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<T>(shape, input);
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void RunAndExpectResult(const std::string& op_name, const TensorShape& shape,
|
||||
const absl::InlinedVector<T, 10>& input,
|
||||
const absl::InlinedVector<OutT, 10>& expected_output,
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
SetOpKernel<T, OutT>(op_name, shape, input, config.add_t, config.add_tout);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Assert buffer reuse if expected.
|
||||
if (config.expect_buffer_reuse) {
|
||||
void* arg_ptr_on_device = context_->input(0).data();
|
||||
void* result_ptr_on_device = context_->mutable_output(0)->data();
|
||||
ASSERT_EQ(arg_ptr_on_device, result_ptr_on_device);
|
||||
}
|
||||
|
||||
// Assert expected results.
|
||||
Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape);
|
||||
test::FillValues<OutT>(&expected_tensor, expected_output);
|
||||
if (config.expect_strictly_equal) {
|
||||
test::ExpectEqual(expected_tensor, *GetOutput(0));
|
||||
} else {
|
||||
test::ExpectClose(expected_tensor, *GetOutput(0), kAbsoluteTolerance,
|
||||
kRelativeTolerance);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
typename BaselineOutT>
|
||||
void Test(const std::string& op_name, const TensorShape& shape,
|
||||
const absl::InlinedVector<T, 10>& input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT),
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
// Prepare inputs and compute expected results.
|
||||
CHECK(input.size() <= shape.num_elements());
|
||||
auto repeated_input =
|
||||
test::RepeatInputToMatchShape(input, shape.num_elements());
|
||||
absl::InlinedVector<OutT, 10> expected_output =
|
||||
ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
|
||||
repeated_input, baseline_callback);
|
||||
|
||||
RunAndExpectResult<T, OutT>(op_name, shape, repeated_input, expected_output,
|
||||
config);
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void TestEmptyShape(const std::string& op_name,
|
||||
const test::GpuOpsTestConfig& config) {
|
||||
TensorShape shape{0, 1, 2};
|
||||
absl::InlinedVector<T, 10> empty_input = {};
|
||||
absl::InlinedVector<OutT, 10> expected_output = {};
|
||||
RunAndExpectResult<T, OutT>(op_name, shape, empty_input, expected_output,
|
||||
config);
|
||||
}
|
||||
|
||||
private:
|
||||
constexpr static double kAbsoluteTolerance = 0.001;
|
||||
constexpr static double kRelativeTolerance = 0.001;
|
||||
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
typename BaselineOutT>
|
||||
absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
|
||||
absl::InlinedVector<T, 10> input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT)) {
|
||||
absl::InlinedVector<OutT, 10> expected_output;
|
||||
for (int i = 0; i < input.size(); i++) {
|
||||
auto arg = static_cast<BaselineT>(input[i]);
|
||||
auto result = static_cast<OutT>(baseline_callback(arg));
|
||||
expected_output.push_back(result);
|
||||
}
|
||||
return expected_output;
|
||||
}
|
||||
};
|
||||
|
||||
// Macros to easily generate common test cases. For specific inputs, please
|
||||
// define your own test fixtures.
|
||||
|
||||
#define GENERATE_DEFAULT_TEST(op_name, InT, OutT, baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_2(op_name, InT, InT, OutT, OutT, baseline_callback, \
|
||||
config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_2(op_name, InT, BaselineT, OutT, BaselineOutT, \
|
||||
baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, BaselineT, OutT, BaselineOutT, \
|
||||
test::DefaultInput<NativeT>(), baseline_callback, config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( \
|
||||
op_name, InT, OutT, input_values, baseline_callback, config) \
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, InT, OutT, OutT, input_values, baseline_callback, config)
|
||||
|
||||
#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \
|
||||
op_name, InT, BaselineT, OutT, BaselineOutT, input_values, \
|
||||
baseline_callback, config) \
|
||||
TEST_F(GpuUnaryOpTest, op_name##InT) { \
|
||||
using NativeT = EnumToDataType<InT>::Type; \
|
||||
using NativeBaselineT = EnumToDataType<BaselineT>::Type; \
|
||||
using NativeOutT = EnumToDataType<OutT>::Type; \
|
||||
using NativeBaselineOutT = EnumToDataType<BaselineOutT>::Type; \
|
||||
Test<NativeT, NativeBaselineT, NativeOutT, NativeBaselineOutT>( \
|
||||
#op_name, test::DefaultInputShape(), input_values, baseline_callback, \
|
||||
config); \
|
||||
} \
|
||||
TEST_F(GpuUnaryOpTest, op_name##InT##EmptyShape) { \
|
||||
using NativeT = EnumToDataType<InT>::Type; \
|
||||
using NativeOutT = EnumToDataType<OutT>::Type; \
|
||||
TestEmptyShape<NativeT, NativeOutT>(#op_name, config); \
|
||||
}
|
||||
|
||||
/// Test `tf.Abs`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Abs, DT_FLOAT, DT_FLOAT, test::NearZeroAndExtremeInput<float>(), std::abs,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Abs, DT_DOUBLE, DT_DOUBLE, test::NearZeroAndExtremeInput<double>(),
|
||||
std::abs, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
std::abs, test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Abs, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::NearZeroAndExtremeInput<Eigen::half>(), std::abs,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Abs, DT_INT64, DT_INT64, test::NearZeroAndExtremeInput<int64>(), std::abs,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Acos`.
|
||||
|
||||
@ -198,21 +58,21 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
// fails comparison for equality.
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Acos, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
|
||||
std::acos, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
std::acos, test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Acos, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
|
||||
std::acos, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
std::acos, test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Acosh`.
|
||||
|
||||
// TODO(herhut): Give this better input once TF testing also supports NaN.
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Acosh, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterEqualOne<float>(),
|
||||
std::acosh, test::GpuOpsTestConfig())
|
||||
std::acosh, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Acosh, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterEqualOne<double>(),
|
||||
std::acosh, test::GpuOpsTestConfig())
|
||||
std::acosh, test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Angle`.
|
||||
|
||||
@ -225,14 +85,14 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Angle, DT_COMPLEX64, DT_FLOAT,
|
||||
test::ComplexInputFromValues<std::complex<float>>(
|
||||
test::DefaultInputNonZero<float>(), test::DefaultInputNonZero<float>()),
|
||||
baseline_angle, test::GpuOpsTestConfig().AddTout().NoBufferReuse())
|
||||
baseline_angle, test::OpsTestConfig().AddTout().NoBufferReuse())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Angle, DT_COMPLEX128, DT_DOUBLE,
|
||||
test::ComplexInputFromValues<std::complex<double>>(
|
||||
test::DefaultInputNonZero<double>(),
|
||||
test::DefaultInputNonZero<double>()),
|
||||
baseline_angle, test::GpuOpsTestConfig().AddTout().NoBufferReuse())
|
||||
baseline_angle, test::OpsTestConfig().AddTout().NoBufferReuse())
|
||||
|
||||
/// Test `tf.Asin`.
|
||||
|
||||
@ -240,47 +100,47 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
// fails comparison for equality.
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Asin, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
|
||||
std::asin, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
std::asin, test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Asin, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
|
||||
std::asin, test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
std::asin, test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Asinh`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Asinh, DT_FLOAT, DT_FLOAT, std::asinh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Asinh, DT_DOUBLE, DT_DOUBLE, std::asinh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Atan`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Atan, DT_FLOAT, DT_FLOAT, std::atan,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Atan, DT_DOUBLE, DT_DOUBLE, std::atan,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Atanh`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Atanh, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
|
||||
std::atanh, test::GpuOpsTestConfig())
|
||||
std::atanh, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Atanh, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
|
||||
std::atanh, test::GpuOpsTestConfig())
|
||||
std::atanh, test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Ceil`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Ceil, DT_DOUBLE, DT_DOUBLE, std::ceil,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Ceil, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::ceil,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.ComplexAbs`.
|
||||
|
||||
@ -290,11 +150,11 @@ typename T::value_type baseline_complex_abs(T x) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(ComplexAbs, DT_COMPLEX64, DT_FLOAT, baseline_complex_abs,
|
||||
test::GpuOpsTestConfig().AddTout().NoBufferReuse())
|
||||
test::OpsTestConfig().AddTout().NoBufferReuse())
|
||||
|
||||
GENERATE_DEFAULT_TEST(ComplexAbs, DT_COMPLEX128, DT_DOUBLE,
|
||||
baseline_complex_abs,
|
||||
test::GpuOpsTestConfig().AddTout().NoBufferReuse())
|
||||
test::OpsTestConfig().AddTout().NoBufferReuse())
|
||||
|
||||
/// Test `tf.Conj`.
|
||||
|
||||
@ -304,29 +164,28 @@ T baseline_conj(T x) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(Conj, DT_COMPLEX64, DT_COMPLEX64, baseline_conj,
|
||||
test::GpuOpsTestConfig().NoBufferReuse())
|
||||
test::OpsTestConfig().NoBufferReuse())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Conj, DT_COMPLEX128, DT_COMPLEX128, baseline_conj,
|
||||
test::GpuOpsTestConfig().NoBufferReuse())
|
||||
test::OpsTestConfig().NoBufferReuse())
|
||||
|
||||
/// Test `tf.Cos`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Cos, DT_FLOAT, DT_FLOAT, std::cos,
|
||||
test::GpuOpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST(Cos, DT_FLOAT, DT_FLOAT, std::cos, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Cos, DT_DOUBLE, DT_DOUBLE, std::cos,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Cos, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::cos,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Cosh`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Cosh, DT_FLOAT, DT_FLOAT, std::cosh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Cosh, DT_DOUBLE, DT_DOUBLE, std::cosh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Digamma`.
|
||||
|
||||
@ -351,16 +210,16 @@ constexpr std::initializer_list<double> kDigammaValues = {
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Digamma, DT_FLOAT, DT_DOUBLE, DT_FLOAT, DT_DOUBLE,
|
||||
test::InputAsVector<float>(kDigammaValues), baseline_digamma,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Digamma, DT_DOUBLE, DT_DOUBLE, test::InputAsVector<double>(kDigammaValues),
|
||||
baseline_digamma, test::GpuOpsTestConfig())
|
||||
baseline_digamma, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Digamma, DT_HALF, DT_DOUBLE, DT_HALF, DT_DOUBLE,
|
||||
test::InputAsVector<Eigen::half>(kDigammaValues), baseline_digamma,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Erf` and `tf.Erfc`.
|
||||
|
||||
@ -473,7 +332,7 @@ static constexpr std::initializer_list<float> kErfcF32Values = {
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(Erf, DT_DOUBLE, DT_DOUBLE,
|
||||
kErfcF64Values, std::erf,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
// Use specific values to cover the different intervals of the f32 erf
|
||||
// approximation.
|
||||
@ -487,54 +346,53 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
-0.1, 0.0, 0.1, 0.2, 0.5, 1.0, 1.1,
|
||||
1.2, 2.3, 3.4, 3.9, 4.0, 4.1, 6.7,
|
||||
8.9, 16.0, 100.0}),
|
||||
std::erf, test::GpuOpsTestConfig())
|
||||
std::erf, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Erf, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::erf,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Erfc, DT_DOUBLE, DT_DOUBLE, test::InputAsVector<double>(kErfcF64Values),
|
||||
std::erfc, test::GpuOpsTestConfig())
|
||||
std::erfc, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Erfc, DT_FLOAT, DT_FLOAT, test::InputAsVector<float>(kErfcF32Values),
|
||||
std::erfc, test::GpuOpsTestConfig())
|
||||
std::erfc, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Erfc, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::erfc,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Exp`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp,
|
||||
test::GpuOpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST(Exp, DT_FLOAT, DT_FLOAT, std::exp, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Exp, DT_DOUBLE, DT_DOUBLE, std::exp,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Exp, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::exp,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Expm1`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Expm1, DT_FLOAT, DT_FLOAT, std::expm1,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Expm1, DT_DOUBLE, DT_DOUBLE, std::expm1,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Expm1, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::expm1,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Floor`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Floor, DT_FLOAT, DT_FLOAT, std::floor,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Floor, DT_DOUBLE, DT_DOUBLE, std::floor,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Floor, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::floor,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Imag`.
|
||||
|
||||
@ -544,10 +402,10 @@ typename T::value_type baseline_imag(T x) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(Imag, DT_COMPLEX64, DT_FLOAT, baseline_imag,
|
||||
test::GpuOpsTestConfig().AddTout().NoBufferReuse())
|
||||
test::OpsTestConfig().AddTout().NoBufferReuse())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Imag, DT_COMPLEX128, DT_DOUBLE, baseline_imag,
|
||||
test::GpuOpsTestConfig().AddTout().NoBufferReuse())
|
||||
test::OpsTestConfig().AddTout().NoBufferReuse())
|
||||
|
||||
/// Test `tf.Invert`.
|
||||
|
||||
@ -558,40 +416,40 @@ T baseline_invert(T x) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(Invert, DT_INT8, DT_INT8, baseline_invert,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Invert, DT_INT16, DT_INT16, baseline_invert,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Invert, DT_INT32, DT_INT32, baseline_invert,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Invert, DT_INT64, DT_INT64, baseline_invert,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.IsFinite`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsFinite, DT_FLOAT, DT_FLOAT, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroAndExtremeInput<float>(), std::isfinite,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsFinite, DT_DOUBLE, DT_DOUBLE, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroAndExtremeInput<double>(), std::isfinite,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsFinite, DT_HALF, DT_FLOAT, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroAndExtremeInput<Eigen::half>(), std::isfinite,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
/// Test `tf.IsInf`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsInf, DT_FLOAT, DT_FLOAT, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroAndExtremeInput<float>(), std::isinf,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
// Workaround for gcc bug, it would fail with "unresolved overloaded function
|
||||
// type" if passing std::isinf with type double. So we use type float for
|
||||
@ -599,19 +457,19 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsInf, DT_DOUBLE, DT_FLOAT, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroAndExtremeInput<double>(), std::isinf,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsInf, DT_HALF, DT_FLOAT, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroAndExtremeInput<Eigen::half>(), std::isinf,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
/// Test `tf.IsNan`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsNan, DT_FLOAT, DT_FLOAT, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroInfAndNanInput<float>(), std::isnan,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
// Workaround for gcc bug, it would fail with "unresolved overloaded function
|
||||
// type" if passing std::isnan with type double. So we use type float for
|
||||
@ -619,12 +477,12 @@ GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsNan, DT_DOUBLE, DT_FLOAT, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroInfAndNanInput<double>(), std::isnan,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
IsNan, DT_HALF, DT_FLOAT, DT_BOOL, DT_BOOL,
|
||||
test::NearZeroInfAndNanInput<Eigen::half>(), std::isnan,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoBufferReuse());
|
||||
|
||||
/// Test `tf.Lgamma`.
|
||||
|
||||
@ -661,53 +519,53 @@ static constexpr std::initializer_list<double> kLgammaValues = {
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Lgamma, DT_FLOAT, DT_FLOAT, test::InputAsVector<float>(kLgammaValues),
|
||||
std::lgamma, test::GpuOpsTestConfig())
|
||||
std::lgamma, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Lgamma, DT_DOUBLE, DT_DOUBLE, test::InputAsVector<double>(kLgammaValues),
|
||||
std::lgamma, test::GpuOpsTestConfig())
|
||||
std::lgamma, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Lgamma, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::InputAsVector<Eigen::half>(kLgammaValues), std::lgamma,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Log`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Log, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
|
||||
std::log, test::GpuOpsTestConfig())
|
||||
std::log, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Log, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
|
||||
std::log, test::GpuOpsTestConfig())
|
||||
std::log, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Log, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::DefaultInputGreaterThanZero<Eigen::half>(), std::log,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Log1p`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Log1p, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
|
||||
std::log1p, test::GpuOpsTestConfig())
|
||||
std::log1p, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Log1p, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
|
||||
std::log1p, test::GpuOpsTestConfig())
|
||||
std::log1p, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Log1p, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::DefaultInputGreaterThanZero<Eigen::half>(), std::log1p,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.LogicalNot`
|
||||
|
||||
bool baseline_logical_not(bool x) { return !x; }
|
||||
|
||||
GENERATE_DEFAULT_TEST(LogicalNot, DT_BOOL, DT_BOOL, baseline_logical_not,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual().NoT())
|
||||
|
||||
/// Test `tf.Neg`.
|
||||
|
||||
@ -718,22 +576,22 @@ T baseline_neg(T x) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(Neg, DT_FLOAT, DT_FLOAT, baseline_neg,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Neg, DT_DOUBLE, DT_DOUBLE, baseline_neg,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Neg, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, baseline_neg,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Neg, DT_INT8, DT_INT8, baseline_neg,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Neg, DT_INT16, DT_INT16, baseline_neg,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Neg, DT_INT64, DT_INT64, baseline_neg,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Real`.
|
||||
|
||||
@ -743,10 +601,10 @@ typename T::value_type baseline_real(T x) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(Real, DT_COMPLEX64, DT_FLOAT, baseline_real,
|
||||
test::GpuOpsTestConfig().AddTout().NoBufferReuse())
|
||||
test::OpsTestConfig().AddTout().NoBufferReuse())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Real, DT_COMPLEX128, DT_DOUBLE, baseline_real,
|
||||
test::GpuOpsTestConfig().AddTout().NoBufferReuse())
|
||||
test::OpsTestConfig().AddTout().NoBufferReuse())
|
||||
|
||||
/// Test `tf.Rsqrt`.
|
||||
|
||||
@ -758,16 +616,16 @@ T baseline_rsqrt(T x) {
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Rsqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterThanZero<float>(),
|
||||
baseline_rsqrt, test::GpuOpsTestConfig())
|
||||
baseline_rsqrt, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Rsqrt, DT_DOUBLE, DT_DOUBLE, test::DefaultInputGreaterThanZero<double>(),
|
||||
baseline_rsqrt, test::GpuOpsTestConfig())
|
||||
baseline_rsqrt, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Rsqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::DefaultInputGreaterThanZero<Eigen::half>(), baseline_rsqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Sign`.
|
||||
|
||||
@ -780,75 +638,73 @@ T baseline_sign(T x) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sign, DT_FLOAT, DT_FLOAT, baseline_sign,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sign, DT_DOUBLE, DT_DOUBLE, baseline_sign,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
// TODO(b/162577610): We should actually use ExpectStrictlyEqual()
|
||||
// here. This requires returning 0.0 for input -0.0.
|
||||
GENERATE_DEFAULT_TEST_2(Sign, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
baseline_sign, test::GpuOpsTestConfig())
|
||||
baseline_sign, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sign, DT_INT64, DT_INT64, baseline_sign,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
/// Test `tf.Sin`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sin, DT_FLOAT, DT_FLOAT, std::sin,
|
||||
test::GpuOpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST(Sin, DT_FLOAT, DT_FLOAT, std::sin, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sin, DT_DOUBLE, DT_DOUBLE, std::sin,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Sin, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::sin,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Sinh`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sinh, DT_FLOAT, DT_FLOAT, std::sinh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Sinh, DT_DOUBLE, DT_DOUBLE, std::sinh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Sqrt`.
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Sqrt, DT_FLOAT, DT_FLOAT, test::DefaultInputGreaterOrEqualToZero<float>(),
|
||||
std::sqrt, test::GpuOpsTestConfig())
|
||||
std::sqrt, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
|
||||
Sqrt, DT_DOUBLE, DT_DOUBLE,
|
||||
test::DefaultInputGreaterOrEqualToZero<double>(), std::sqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2(
|
||||
Sqrt, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT,
|
||||
test::DefaultInputGreaterOrEqualToZero<Eigen::half>(), std::sqrt,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Tan`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Tan, DT_FLOAT, DT_FLOAT, std::tan,
|
||||
test::GpuOpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST(Tan, DT_FLOAT, DT_FLOAT, std::tan, test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Tan, DT_DOUBLE, DT_DOUBLE, std::tan,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Tan, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tan,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Tanh`.
|
||||
|
||||
GENERATE_DEFAULT_TEST(Tanh, DT_FLOAT, DT_FLOAT, std::tanh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST(Tanh, DT_DOUBLE, DT_DOUBLE, std::tanh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
GENERATE_DEFAULT_TEST_2(Tanh, DT_HALF, DT_FLOAT, DT_HALF, DT_FLOAT, std::tanh,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
|
||||
/// Test `tf.Square`.
|
||||
|
||||
@ -858,13 +714,13 @@ T baseline_square(T inp) {
|
||||
}
|
||||
|
||||
GENERATE_DEFAULT_TEST(Square, DT_HALF, DT_HALF, baseline_square,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST(Square, DT_FLOAT, DT_FLOAT, baseline_square,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST(Square, DT_DOUBLE, DT_DOUBLE, baseline_square,
|
||||
test::GpuOpsTestConfig())
|
||||
test::OpsTestConfig())
|
||||
GENERATE_DEFAULT_TEST(Square, DT_INT64, DT_INT64, baseline_square,
|
||||
test::GpuOpsTestConfig().ExpectStrictlyEqual())
|
||||
test::OpsTestConfig().ExpectStrictlyEqual())
|
||||
|
||||
} // namespace
|
||||
} // end namespace tensorflow
|
||||
} // namespace tensorflow
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user