A. Unique TensorFlower ae06854dec Some const declarations changed to constexpr
PiperOrigin-RevId: 307514970
Change-Id: Ib27a5e36e23c8658e8bd517a06809eb831f452d8
2020-04-20 18:15:21 -07:00

233 lines
8.1 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <vector>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
// Tests the handling of the basic mathematics operations with F16 operands.
namespace xla {
namespace {
// Shared base fixture for the F16 tests in this file. It supplies the
// constants the individual tests read: the fixed length of the predicate
// tests' `expected` arrays and the error bounds used when comparing F16
// results.
class HalfTestBase : public ClientLibraryTestBase {
 protected:
  // Number of elements in the input buffers of the predicate tests, which
  // size their stack-allocated `expected` arrays with it.
  static constexpr int kNumElements = 4;

  // Error bounds handed to ComputeAndCompareR1 by the value-returning tests.
  // NOTE(review): presumably {absolute, relative} tolerance -- confirm against
  // the ErrorSpec definition.
  const ErrorSpec error_spec_{0.001, 0.001};
};
using UnaryBuildFuncTy = std::function<void(const xla::XlaOp& src)>;
// One UnaryOpTest case: a host-side reference function paired with the
// callback that emits the op under test.
struct UnaryOpTestParam {
  // Produces the expected result for a single input element.
  std::function<half(half)> compute_func;
  // Called with the parameter operand to add the op to the computation.
  UnaryBuildFuncTy build_func;
};
// Value-parameterized fixture for unary ops that return F16 values; the
// parameter is a UnaryOpTestParam.
class UnaryOpTest
    : public HalfTestBase,
      public ::testing::WithParamInterface<UnaryOpTestParam> {};
// Runs the parameterized unary op over a fixed F16 vector and compares the
// result with the host-side reference function applied element by element,
// within error_spec_.
//
// NOTE(review): the input contains negative values, which lie outside the real
// domain of Log (one of the instantiated ops); confirm the comparison treats
// the resulting NaNs as intended.
XLA_TEST_P(UnaryOpTest, Ops) {
  const std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1),
                             half(9.0), half(42.0), half(-9.0),
                             half(-100.0)});
  XlaBuilder builder(TestName());
  XlaOp x_opnd;
  auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
                                        &builder, &x_opnd);

  // Bind the test parameter by reference rather than copying its two
  // std::function members.
  const auto& param = GetParam();

  // Build the expected output on the host. Iterating by element avoids the
  // signed/unsigned comparison an int64 index against size() would need.
  std::vector<half> expected;
  expected.reserve(x.size());
  for (const half& value : x) {
    expected.push_back(param.compute_func(value));
  }

  // Emit the op under test; it becomes the root of the computation.
  param.build_func(x_opnd);
  ComputeAndCompareR1<half>(&builder, expected, {x_data.get()}, error_spec_);
}
// Host-side reference for the Sign op. The input is widened to float and
// classified: negative values give -1, positive values give 1, and anything
// else (zero, or a value for which both comparisons are false) gives 0.
half sign_imp(half value) {
  const float as_float = static_cast<float>(value);
  int sign = 0;
  if (as_float < .0) {
    sign = -1;
  } else if (as_float > .0) {
    sign = 1;
  }
  return half(sign);
}
// Host-side reference for the Round op: widens the input to float, applies
// std::round, and converts the result back to half.
half round_imp(half value) {
  const float widened = static_cast<float>(value);
  return half(std::round(widened));
}
// Registers the UnaryOpTest cases. Each entry pairs a host-side reference
// computation on half with the XLA builder function for the same op. The
// order of the entries is kept stable so the generated case indices do not
// change.
INSTANTIATE_TEST_CASE_P(
    half, UnaryOpTest,
    ::testing::Values(
        UnaryOpTestParam{[](half v) { return abs(v); }, &Abs},
        // Round and Sign use the file-local reference helpers directly.
        UnaryOpTestParam{&round_imp, &Round},
        UnaryOpTestParam{[](half v) { return ceil(v); }, &Ceil},
        UnaryOpTestParam{[](half v) { return cos(v); }, &Cos},
        UnaryOpTestParam{[](half v) { return exp(v); }, &Exp},
        UnaryOpTestParam{[](half v) { return floor(v); }, &Floor},
        UnaryOpTestParam{[](half v) { return log(v); }, &Log},
        UnaryOpTestParam{[](half v) { return -v; }, &Neg},
        UnaryOpTestParam{&sign_imp, &Sign},
        UnaryOpTestParam{[](half v) { return sin(v); }, &Sin},
        UnaryOpTestParam{[](half v) { return tanh(v); }, &Tanh}));
// One UnaryPredTest case: a host-side reference predicate paired with the
// callback that emits the op under test.
struct UnaryPredTestParam {
  // Produces the expected boolean result for a single input element.
  std::function<bool(half)> compute_func;
  // Called with the parameter operand to add the op to the computation.
  UnaryBuildFuncTy build_func;
};
// Value-parameterized fixture for unary ops that return booleans; the
// parameter is a UnaryPredTestParam.
class UnaryPredTest
    : public HalfTestBase,
      public ::testing::WithParamInterface<UnaryPredTestParam> {};
// Runs the parameterized unary predicate over a fixed F16 vector and compares
// the boolean result exactly with the host-side reference predicate applied
// element by element.
XLA_TEST_P(UnaryPredTest, Ops) {
  const std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1)});
  XlaBuilder builder(TestName());
  XlaOp x_opnd;
  auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
                                        &builder, &x_opnd);

  // Bind the test parameter by reference rather than copying its two
  // std::function members.
  const auto& param = GetParam();

  // `expected` is a fixed-size array, so the input must have exactly
  // kNumElements entries.
  CHECK_EQ(kNumElements, x.size());
  bool expected[kNumElements];
  // The check above makes kNumElements the exact element count, so loop on it
  // directly; this avoids comparing a signed index with the unsigned size().
  for (int i = 0; i < kNumElements; ++i) {
    expected[i] = param.compute_func(x[i]);
  }

  // Emit the op under test; it becomes the root of the computation.
  param.build_func(x_opnd);
  ComputeAndCompareR1<bool>(&builder, expected, {x_data.get()});
}
// Registers the UnaryPredTest cases: currently only IsFinite, checked against
// isfinite evaluated on the half value.
INSTANTIATE_TEST_CASE_P(
    half, UnaryPredTest,
    ::testing::Values(
        UnaryPredTestParam{[](half v) { return isfinite(v); }, &IsFinite}));
using BinaryBuildFuncTy = std::function<void(
const xla::XlaOp& x, const xla::XlaOp& y, absl::Span<const int64>)>;
// One BinaryOpTest case: a host-side reference function paired with the
// callback that emits the op under test.
struct BinaryOpTestParam {
  // Produces the expected result for one pair of input elements.
  std::function<half(half, half)> compute_func;
  // Called with both parameter operands to add the op to the computation.
  BinaryBuildFuncTy build_func;
};
// Value-parameterized fixture for binary ops that return F16 values; the
// parameter is a BinaryOpTestParam.
class BinaryOpTest
    : public HalfTestBase,
      public ::testing::WithParamInterface<BinaryOpTestParam> {};
// Runs the parameterized binary op over two fixed F16 vectors and compares
// the result with the host-side reference function applied pairwise, within
// error_spec_.
XLA_TEST_P(BinaryOpTest, Ops) {
  // x and y are consumed pairwise below and must stay the same length.
  const std::vector<half> x({half(1.0), half(2.0), half(3.0), half(-4.0)});
  const std::vector<half> y({half(0.4), half(-0.3), half(0.2), half(0.1)});
  XlaBuilder builder(TestName());
  XlaOp x_opnd;
  auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
                                        &builder, &x_opnd);
  XlaOp y_opnd;
  auto y_data = CreateR1Parameter<half>(y, /*parameter_number=*/1, "y",
                                        &builder, &y_opnd);

  // Bind the test parameter by reference rather than copying its two
  // std::function members.
  const auto& param = GetParam();

  // Convert the element count once so the loop compares signed with signed.
  const int64 num_elements = static_cast<int64>(x.size());
  std::vector<half> expected;
  expected.reserve(x.size());
  for (int64 i = 0; i < num_elements; ++i) {
    expected.push_back(param.compute_func(x[i], y[i]));
  }

  // Emit the op under test with an empty third argument; it becomes the root
  // of the computation.
  param.build_func(x_opnd, y_opnd, {});
  ComputeAndCompareR1<half>(&builder, expected, {x_data.get(), y_data.get()},
                            error_spec_);
}
// Host-side reference for the Atan2 op: widens both inputs to float, calls
// std::atan2 with x as the first argument and y as the second, and converts
// the result back to half.
half atan2_imp(half x, half y) {
  const float first = static_cast<float>(x);
  const float second = static_cast<float>(y);
  return half(std::atan2(first, second));
}
// Registers the BinaryOpTest cases. Each entry pairs a host-side reference
// computation on half with the XLA builder function for the same op. The
// order of the entries is kept stable so the generated case indices do not
// change.
INSTANTIATE_TEST_CASE_P(
    half, BinaryOpTest,
    ::testing::Values(
        BinaryOpTestParam{[](half lhs, half rhs) { return lhs + rhs; }, &Add},
        // Atan2 uses the file-local reference helper directly.
        BinaryOpTestParam{&atan2_imp, &Atan2},
        BinaryOpTestParam{[](half lhs, half rhs) { return lhs / rhs; }, &Div},
        BinaryOpTestParam{[](half lhs, half rhs) { return max(lhs, rhs); },
                          &Max},
        BinaryOpTestParam{[](half lhs, half rhs) { return min(lhs, rhs); },
                          &Min},
        BinaryOpTestParam{[](half lhs, half rhs) { return lhs * rhs; }, &Mul},
        BinaryOpTestParam{[](half lhs, half rhs) { return pow(lhs, rhs); },
                          &Pow},
        BinaryOpTestParam{[](half lhs, half rhs) { return lhs - rhs; },
                          &Sub}));
// One BinaryPredTest case: a host-side reference predicate paired with the
// callback that emits the op under test.
struct BinaryPredTestParam {
  // Produces the expected boolean result for one pair of input elements.
  std::function<bool(half, half)> compute_func;
  // Called with both parameter operands to add the op to the computation.
  BinaryBuildFuncTy build_func;
};
// Value-parameterized fixture for binary ops that return booleans; the
// parameter is a BinaryPredTestParam.
class BinaryPredTest
    : public HalfTestBase,
      public ::testing::WithParamInterface<BinaryPredTestParam> {};
// Runs the parameterized binary predicate over two fixed F16 vectors and
// compares the boolean result exactly with the host-side reference predicate
// applied pairwise.
XLA_TEST_P(BinaryPredTest, Ops) {
  // The third pair holds equal values (0.2, 0.2), so the inputs include an
  // equal pair as well as unequal ones.
  const std::vector<half> x({half(1.0), half(2.0), half(0.2), half(-4.0)});
  const std::vector<half> y({half(0.4), half(-0.3), half(0.2), half(0.1)});
  XlaBuilder builder(TestName());
  XlaOp x_opnd;
  auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
                                        &builder, &x_opnd);
  XlaOp y_opnd;
  auto y_data = CreateR1Parameter<half>(y, /*parameter_number=*/1, "y",
                                        &builder, &y_opnd);

  // Bind the test parameter by reference rather than copying its two
  // std::function members.
  const auto& param = GetParam();

  // `expected` is a fixed-size array and both inputs are indexed over the
  // same range, so each must have exactly kNumElements entries.
  CHECK_EQ(kNumElements, x.size());
  CHECK_EQ(kNumElements, y.size());
  bool expected[kNumElements];
  // Loop on kNumElements directly to avoid comparing a signed index with the
  // unsigned size().
  for (int i = 0; i < kNumElements; ++i) {
    expected[i] = param.compute_func(x[i], y[i]);
  }

  // Emit the op under test with an empty third argument; it becomes the root
  // of the computation.
  param.build_func(x_opnd, y_opnd, {});
  ComputeAndCompareR1<bool>(&builder, expected, {x_data.get(), y_data.get()});
}
// Registers the BinaryPredTest cases: each comparison operator on half is
// paired with the XLA builder function for the same comparison. The order of
// the entries is kept stable so the generated case indices do not change.
INSTANTIATE_TEST_CASE_P(
    half, BinaryPredTest,
    ::testing::Values(
        BinaryPredTestParam{[](half lhs, half rhs) { return lhs == rhs; },
                            &Eq},
        BinaryPredTestParam{[](half lhs, half rhs) { return lhs != rhs; },
                            &Ne},
        BinaryPredTestParam{[](half lhs, half rhs) { return lhs >= rhs; },
                            &Ge},
        BinaryPredTestParam{[](half lhs, half rhs) { return lhs > rhs; },
                            &Gt},
        BinaryPredTestParam{[](half lhs, half rhs) { return lhs <= rhs; },
                            &Le},
        BinaryPredTestParam{[](half lhs, half rhs) { return lhs < rhs; },
                            &Lt}));
} // namespace
} // namespace xla