Add missing tests for Neg kernel.
Also add tests for LogicalNot. PiperOrigin-RevId: 346955795 Change-Id: I79409b967ca0bf8b03499b7247c40fe245427850
This commit is contained in:
parent
06b99811c1
commit
32d5b268e2
@ -49,18 +49,20 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
// function. In most cases it is enough to just provide the input type,
|
||||
// because all the types are the same.
|
||||
template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT>
|
||||
void Run(std::vector<int64> input_shape, std::vector<T> input,
|
||||
void Run(std::vector<int64> input_shape, absl::InlinedVector<T, 10> input,
|
||||
const std::string op_name, ROutT (*expected_callback)(RT),
|
||||
bool expect_equal = true, bool add_tout = false,
|
||||
bool expect_buffer_reuse = true) {
|
||||
bool expect_buffer_reuse = true, bool add_t = true) {
|
||||
assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
|
||||
std::multiplies<int64>()) == input.size() &&
|
||||
"Expected input length to equal to shape's number of elements.");
|
||||
|
||||
TensorShape shape(input_shape);
|
||||
NodeDefBuilder builder("some_name", op_name);
|
||||
builder.Input(FakeInput(DataTypeToEnum<T>::v()))
|
||||
.Attr("T", DataTypeToEnum<T>::v());
|
||||
builder.Input(FakeInput(DataTypeToEnum<T>::v()));
|
||||
if (add_t) {
|
||||
builder.Attr("T", DataTypeToEnum<T>::v());
|
||||
}
|
||||
if (add_tout) {
|
||||
builder.Attr("Tout", DataTypeToEnum<OutT>::v());
|
||||
}
|
||||
@ -98,15 +100,15 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
std::vector<int64> DefaultInputShape() { return std::vector<int64>{2, 7}; }
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> DefaultInput() {
|
||||
absl::InlinedVector<T, 10> DefaultInput() {
|
||||
return InputAsVector<T>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3,
|
||||
0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::complex<T>> DefaultComplexInput() {
|
||||
absl::InlinedVector<std::complex<T>, 10> DefaultComplexInput() {
|
||||
auto input = DefaultInput<T>();
|
||||
std::vector<std::complex<T>> complex_input;
|
||||
absl::InlinedVector<std::complex<T>, 10> complex_input;
|
||||
for (T value : input) {
|
||||
complex_input.emplace_back(value, -value);
|
||||
}
|
||||
@ -114,21 +116,22 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> DefaultInputGreaterThanZero() {
|
||||
absl::InlinedVector<T, 10> DefaultInputGreaterThanZero() {
|
||||
return InputAsVector<T>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
|
||||
0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> DefaultInputGreaterOrEqualToZero() {
|
||||
absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
|
||||
return InputAsVector<T>({18.0, 9.0, 1e-6, 0.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
|
||||
0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
std::vector<T> InputAsVector(std::initializer_list<double> input) {
|
||||
std::vector<T> result;
|
||||
absl::InlinedVector<T, 10> InputAsVector(
|
||||
std::initializer_list<double> input) {
|
||||
absl::InlinedVector<T, 10> result;
|
||||
result.reserve(input.size());
|
||||
for (const auto& value : input) {
|
||||
result.push_back(static_cast<T>(value));
|
||||
@ -386,6 +389,19 @@ TEST_F(GpuUnaryOpTest, LogHalf) {
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.LogicalNot`
|
||||
|
||||
TEST_F(GpuUnaryOpTest, LogicalNot) {
|
||||
Run<bool, bool, bool, bool>(
|
||||
DefaultInputShape(), DefaultInput<bool>(),
|
||||
/*op_name=*/"LogicalNot",
|
||||
/*expected_callback=*/[](bool v) { return !v; },
|
||||
/*expect_equal=*/true,
|
||||
/*add_tout=*/false,
|
||||
/*expect_buffer_reuse=*/true,
|
||||
/*add_t=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Neg`.
|
||||
|
||||
/// Reference implementation.
|
||||
@ -415,6 +431,27 @@ TEST_F(GpuUnaryOpTest, NegHalf) {
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegInt8) {
|
||||
Run<int8>(DefaultInputShape(), DefaultInput<int8>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegInt16) {
|
||||
Run<int16>(DefaultInputShape(), DefaultInput<int16>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegInt64) {
|
||||
Run<int64>(DefaultInputShape(), DefaultInput<int64>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
/// Test `tf.Real`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, RealFloat) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user