Split up exhaustive_binary_test.cc into two files.
This also allows us to get rid of #define flags, and we can fix the issue that required the usage of the GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST macro. PiperOrigin-RevId: 293584602 Change-Id: Idd6723a3304200c8df5b6066a08159b5c45cb470
This commit is contained in:
parent
bc28d49ce2
commit
e0374aaae8
@ -836,14 +836,12 @@ xla_test(
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_binary_test_f16",
|
||||
srcs = ["exhaustive_binary_test.cc"],
|
||||
name = "exhaustive_binary_16_bit_test",
|
||||
srcs = ["exhaustive_binary_16_bit_test.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DBINARY_TEST_TARGET_F16"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
"optonly",
|
||||
@ -857,56 +855,12 @@ xla_test(
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_binary_test_bf16",
|
||||
srcs = ["exhaustive_binary_test.cc"],
|
||||
name = "exhaustive_binary_test_f32_f64",
|
||||
srcs = ["exhaustive_binary_test_f32_f64.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DBINARY_TEST_TARGET_BF16"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
"optonly",
|
||||
# This is a big test that we skip for capacity reasons in OSS testing.
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
],
|
||||
deps = [
|
||||
":exhaustive_op_test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_binary_test_f32",
|
||||
srcs = ["exhaustive_binary_test.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DBINARY_TEST_TARGET_F32"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
"optonly",
|
||||
# This is a big test that we skip for capacity reasons in OSS testing.
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
],
|
||||
deps = [
|
||||
":exhaustive_op_test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "exhaustive_binary_test_f64",
|
||||
srcs = ["exhaustive_binary_test.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DBINARY_TEST_TARGET_F64"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
tags = [
|
||||
"optonly",
|
||||
|
143
tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc
Normal file
143
tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc
Normal file
@ -0,0 +1,143 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
|
||||
|
||||
#ifdef __FAST_MATH__
|
||||
#error("Can't be compiled with fast math on");
|
||||
#endif
|
||||
|
||||
namespace xla {
|
||||
namespace exhaustive_op_test {
|
||||
namespace {
|
||||
|
||||
// Exhaustive test for binary operations for 16 bit floating point types,
|
||||
// including float16 and bfloat.
|
||||
//
|
||||
// Test parameter is a pair of (begin, end) for range under test.
|
||||
template <PrimitiveType T>
|
||||
class Exhaustive16BitBinaryTest
|
||||
: public ExhaustiveBinaryTest<T>,
|
||||
public ::testing::WithParamInterface<std::pair<int64, int64>> {
|
||||
public:
|
||||
int64 GetInputSize() override {
|
||||
int64 begin, end;
|
||||
std::tie(begin, end) = GetParam();
|
||||
return end - begin;
|
||||
}
|
||||
|
||||
// Given a range of uint64 representation, uses bits 0..15 and bits 16..31 for
|
||||
// the values of src0 and src1 for a 16 bit binary operation being tested,
|
||||
// and generates the cartesian product of the two sets as the two inputs for
|
||||
// the test.
|
||||
void FillInput(std::array<Literal, 2>* input_literals) override {
|
||||
int64 input_size = GetInputSize();
|
||||
CHECK_EQ(input_size, (*input_literals)[0].element_count());
|
||||
CHECK_EQ(input_size, (*input_literals)[1].element_count());
|
||||
|
||||
int64 begin, end;
|
||||
std::tie(begin, end) = GetParam();
|
||||
VLOG(2) << "Checking range [" << begin << ", " << end << "]";
|
||||
|
||||
absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
|
||||
absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
|
||||
for (int64 i = 0; i < input_size; i++) {
|
||||
uint32 input_val = i + begin;
|
||||
// Convert the lower 16 bits to the NativeT and replaced known incorrect
|
||||
// input values with 0.
|
||||
input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0);
|
||||
input_arr_1[i] =
|
||||
ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using typename ExhaustiveBinaryTest<T>::NativeT;
|
||||
using ExhaustiveBinaryTest<T>::ConvertAndReplaceKnownIncorrectValueWith;
|
||||
};
|
||||
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||
using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest<F16>;
|
||||
#define BINARY_TEST_F16(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#else
|
||||
#define BINARY_TEST_F16(test_name, ...)
|
||||
#endif
|
||||
|
||||
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||
using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest<BF16>;
|
||||
#define BINARY_TEST_BF16(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#else
|
||||
#define BINARY_TEST_BF16(test_name, ...)
|
||||
#endif
|
||||
|
||||
#define BINARY_TEST_16BIT(test_name, ...) \
|
||||
BINARY_TEST_F16(test_name, __VA_ARGS__) \
|
||||
BINARY_TEST_BF16(test_name, __VA_ARGS__)
|
||||
|
||||
BINARY_TEST_16BIT(Add, {
|
||||
auto host_add = [](float x, float y) { return x + y; };
|
||||
Run(AddEmptyBroadcastDimension(Add), host_add);
|
||||
})
|
||||
|
||||
BINARY_TEST_16BIT(Sub, {
|
||||
auto host_sub = [](float x, float y) { return x - y; };
|
||||
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
||||
})
|
||||
|
||||
// TODO(bixia): Mul fails with bfloat16 on CPU.
|
||||
BINARY_TEST_16BIT(DISABLED_ON_CPU(Mul), {
|
||||
auto host_mul = [](float x, float y) { return x * y; };
|
||||
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
||||
})
|
||||
|
||||
// TODO(bixia): Div fails with bfloat16 on CPU.
|
||||
BINARY_TEST_16BIT(DISABLED_ON_CPU(Div), {
|
||||
auto host_div = [](float x, float y) { return x / y; };
|
||||
Run(AddEmptyBroadcastDimension(Div), host_div);
|
||||
})
|
||||
|
||||
BINARY_TEST_16BIT(Max, {
|
||||
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>);
|
||||
})
|
||||
|
||||
BINARY_TEST_16BIT(Min, {
|
||||
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
|
||||
})
|
||||
|
||||
// TODO(bixia): Pow fails with bfloat16 on CPU.
|
||||
BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow),
|
||||
{ Run(AddEmptyBroadcastDimension(Pow), std::powf); })
|
||||
|
||||
// TODO(bixia): Atan2 fails with bfloat16 on CPU.
|
||||
BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2),
|
||||
{ Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); })
|
||||
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||
INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
|
||||
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
||||
#endif
|
||||
|
||||
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||
INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest,
|
||||
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace exhaustive_op_test
|
||||
} // namespace xla
|
@ -23,171 +23,6 @@ namespace xla {
|
||||
namespace exhaustive_op_test {
|
||||
namespace {
|
||||
|
||||
template <PrimitiveType T>
|
||||
using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;
|
||||
|
||||
// Exhaustive test for binary operations for 16 bit floating point types,
|
||||
// including float16 and bfloat.
|
||||
//
|
||||
// Test parameter is a pair of (begin, end) for range under test.
|
||||
template <
|
||||
PrimitiveType T,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename primitive_util::PrimitiveTypeToNative<T>::type,
|
||||
half>::value ||
|
||||
std::is_same<typename primitive_util::PrimitiveTypeToNative<T>::type,
|
||||
bfloat16>::value>::type* = nullptr>
|
||||
class Exhaustive16BitBinaryTest
|
||||
: public ExhaustiveBinaryTest<T>,
|
||||
public ::testing::WithParamInterface<std::pair<int64, int64>> {
|
||||
public:
|
||||
int64 GetInputSize() override {
|
||||
int64 begin, end;
|
||||
std::tie(begin, end) = GetParam();
|
||||
return end - begin;
|
||||
}
|
||||
|
||||
// Given a range of uint64 representation, uses bits 0..15 and bits 16..31 for
|
||||
// the values of src0 and src1 for a 16 bit binary operation being tested,
|
||||
// and generates the cartesian product of the two sets as the two inputs for
|
||||
// the test.
|
||||
void FillInput(std::array<Literal, 2>* input_literals) override {
|
||||
int64 input_size = GetInputSize();
|
||||
CHECK_EQ(input_size, (*input_literals)[0].element_count());
|
||||
CHECK_EQ(input_size, (*input_literals)[1].element_count());
|
||||
|
||||
int64 begin, end;
|
||||
std::tie(begin, end) = GetParam();
|
||||
VLOG(2) << "Checking range [" << begin << ", " << end << "]";
|
||||
|
||||
absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
|
||||
absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
|
||||
for (int64 i = 0; i < input_size; i++) {
|
||||
uint32 input_val = i + begin;
|
||||
// Convert the lower 16 bits to the NativeT and replaced known incorrect
|
||||
// input values with 0.
|
||||
input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0);
|
||||
input_arr_1[i] =
|
||||
ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using typename ExhaustiveBinaryTest<T>::NativeT;
|
||||
using ExhaustiveBinaryTest<T>::ConvertAndReplaceKnownIncorrectValueWith;
|
||||
};
|
||||
|
||||
using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest<F16>;
|
||||
using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest<BF16>;
|
||||
|
||||
// Returns a wrapper of the given build method, which build an HLO operation
|
||||
// with an empty broadcast dimension.
|
||||
inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
|
||||
std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64>)> build_method) {
|
||||
return [&](XlaOp src0, XlaOp src1) -> XlaOp {
|
||||
return build_method(src0, src1, {});
|
||||
};
|
||||
}
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_F16) && defined(BINARY_TEST_TARGET_BF16)
|
||||
#error "Can't define both BINARY_TEST_TARGET_F16 and BINARY_TEST_TARGET_BF16"
|
||||
#endif
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_F16) && \
|
||||
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||
#define BINARY_TEST_16BIT(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#elif defined(BINARY_TEST_TARGET_BF16) && defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||
#define BINARY_TEST_16BIT(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#else
|
||||
#define BINARY_TEST_16BIT(test_name, ...)
|
||||
#endif
|
||||
|
||||
BINARY_TEST_16BIT(Add, {
|
||||
auto host_add = [](float x, float y) { return x + y; };
|
||||
Run(AddEmptyBroadcastDimension(Add), host_add);
|
||||
})
|
||||
|
||||
BINARY_TEST_16BIT(Sub, {
|
||||
auto host_sub = [](float x, float y) { return x - y; };
|
||||
Run(AddEmptyBroadcastDimension(Sub), host_sub);
|
||||
})
|
||||
|
||||
// TODO(bixia): Mul fails with bfloat16 on CPU.
|
||||
BINARY_TEST_16BIT(DISABLED_ON_CPU(Mul), {
|
||||
auto host_mul = [](float x, float y) { return x * y; };
|
||||
Run(AddEmptyBroadcastDimension(Mul), host_mul);
|
||||
})
|
||||
|
||||
// TODO(bixia): Div fails with bfloat16 on CPU.
|
||||
BINARY_TEST_16BIT(DISABLED_ON_CPU(Div), {
|
||||
auto host_div = [](float x, float y) { return x / y; };
|
||||
Run(AddEmptyBroadcastDimension(Div), host_div);
|
||||
})
|
||||
|
||||
template <typename T, typename std::enable_if<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value>::type* = nullptr>
|
||||
T ReferenceMax(T x, T y) {
|
||||
// We need to propagate NAN here because std::max may not propagate NAN.
|
||||
if (std::fpclassify(x) == FP_NAN) {
|
||||
return x;
|
||||
}
|
||||
if (std::fpclassify(y) == FP_NAN) {
|
||||
return y;
|
||||
}
|
||||
|
||||
return std::max<T>(x, y);
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value>::type* = nullptr>
|
||||
T ReferenceMin(T x, T y) {
|
||||
// We need to propagate NAN here because std::max may not propagate NAN.
|
||||
if (std::fpclassify(x) == FP_NAN) {
|
||||
return x;
|
||||
}
|
||||
if (std::fpclassify(y) == FP_NAN) {
|
||||
return y;
|
||||
}
|
||||
|
||||
return std::min<T>(x, y);
|
||||
}
|
||||
|
||||
BINARY_TEST_16BIT(Max, {
|
||||
Run(AddEmptyBroadcastDimension(Max), ReferenceMax<float>);
|
||||
})
|
||||
|
||||
BINARY_TEST_16BIT(Min, {
|
||||
Run(AddEmptyBroadcastDimension(Min), ReferenceMin<float>);
|
||||
})
|
||||
|
||||
// TODO(bixia): Pow fails with bfloat16 on CPU.
|
||||
BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow),
|
||||
{ Run(AddEmptyBroadcastDimension(Pow), std::powf); })
|
||||
|
||||
// TODO(bixia): Atan2 fails with bfloat16 on CPU.
|
||||
BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2),
|
||||
{ Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); })
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_F16)
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
|
||||
INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
|
||||
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_BF16)
|
||||
#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
|
||||
INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest,
|
||||
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Exhaustive test for binary operations for float and double.
|
||||
//
|
||||
// Test parameter is a tuple of (FpValues, FpValues) describing the possible
|
||||
@ -236,20 +71,10 @@ class Exhaustive32BitOrMoreBinaryTest
|
||||
};
|
||||
|
||||
using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>;
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
|
||||
ExhaustiveF32BinaryTest); // TODO(b/139702016) go/are-your-tests-running
|
||||
|
||||
using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
|
||||
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(
|
||||
ExhaustiveF64BinaryTest); // TODO(b/139702016) go/are-your-tests-running
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_F32)
|
||||
#define BINARY_TEST_FLOAT_32(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
#else
|
||||
#define BINARY_TEST_FLOAT_32(test_name, ...)
|
||||
#endif
|
||||
|
||||
BINARY_TEST_FLOAT_32(Add, {
|
||||
auto host_add = [](float x, float y) { return x + y; };
|
||||
@ -332,8 +157,8 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::ValuesIn(
|
||||
GetFpValuesForMagnitudeExtremeNormals<float>(40000, 2000))));
|
||||
|
||||
#if defined(BINARY_TEST_TARGET_F64) && \
|
||||
!defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||
using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
|
||||
#define BINARY_TEST_FLOAT_64(test_name, ...) \
|
||||
XLA_TEST_P(ExhaustiveF64BinaryTest, test_name) \
|
||||
__VA_ARGS__
|
||||
@ -381,6 +206,7 @@ BINARY_TEST_FLOAT_64(DISABLED_ON_CPU(AbsComplex), {
|
||||
Run(device_abs_complex, host_abs_complex);
|
||||
})
|
||||
|
||||
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
SpecialValues, ExhaustiveF64BinaryTest,
|
||||
::testing::Combine(
|
||||
@ -414,6 +240,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
|
||||
::testing::ValuesIn(
|
||||
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
||||
#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||
|
||||
} // namespace
|
||||
} // namespace exhaustive_op_test
|
@ -1004,6 +1004,45 @@ typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
|
||||
return DefaultSpecGenerator<T, N>;
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value>::type* = nullptr>
|
||||
T ReferenceMax(T x, T y) {
|
||||
// We need to propagate NAN here because std::max may not propagate NAN.
|
||||
if (std::fpclassify(x) == FP_NAN) {
|
||||
return x;
|
||||
}
|
||||
if (std::fpclassify(y) == FP_NAN) {
|
||||
return y;
|
||||
}
|
||||
|
||||
return std::max<T>(x, y);
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<
|
||||
std::is_same<T, float>::value ||
|
||||
std::is_same<T, double>::value>::type* = nullptr>
|
||||
T ReferenceMin(T x, T y) {
|
||||
// We need to propagate NAN here because std::max may not propagate NAN.
|
||||
if (std::fpclassify(x) == FP_NAN) {
|
||||
return x;
|
||||
}
|
||||
if (std::fpclassify(y) == FP_NAN) {
|
||||
return y;
|
||||
}
|
||||
|
||||
return std::min<T>(x, y);
|
||||
}
|
||||
|
||||
// Returns a wrapper of the given build method, which build an HLO operation
|
||||
// with an empty broadcast dimension.
|
||||
inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
|
||||
std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64>)> build_method) {
|
||||
return [&](XlaOp src0, XlaOp src1) -> XlaOp {
|
||||
return build_method(src0, src1, {});
|
||||
};
|
||||
}
|
||||
|
||||
template <PrimitiveType T>
|
||||
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
|
||||
public:
|
||||
@ -1013,6 +1052,9 @@ class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
|
||||
}
|
||||
};
|
||||
|
||||
template <PrimitiveType T>
|
||||
using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;
|
||||
|
||||
} // namespace exhaustive_op_test
|
||||
} // namespace xla
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
|
||||
|
Loading…
x
Reference in New Issue
Block a user