[XLA] Introduce ManifestCheckingTest
PiperOrigin-RevId: 317229603 Change-Id: Ibcc9ea3895d520024f5d80d52330aeb3b970585d
This commit is contained in:
parent
13fe5862de
commit
4a14e778d6
|
@ -52,16 +52,26 @@ cc_library(
|
|||
name = "test_macros_header",
|
||||
testonly = True,
|
||||
hdrs = ["test_macros.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
# Generate a test_macros_${BACKEND} library per backend with the proper copts.
|
||||
generate_backend_test_macros()
|
||||
|
||||
cc_library(
|
||||
name = "manifest_checking_test",
|
||||
testonly = True,
|
||||
srcs = ["manifest_checking_test.cc"],
|
||||
hdrs = ["manifest_checking_test.h"],
|
||||
deps = [
|
||||
":test_macros_header",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_utils",
|
||||
srcs = ["test_utils.cc"],
|
||||
|
@ -136,6 +146,7 @@ cc_library(
|
|||
hdrs = ["hlo_test_base.h"],
|
||||
deps = [
|
||||
":literal_test_util",
|
||||
":manifest_checking_test",
|
||||
":test_utils",
|
||||
":verified_hlo_module",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
|
@ -193,6 +204,7 @@ cc_library(
|
|||
srcs = ["client_library_test_base.cc"],
|
||||
hdrs = ["client_library_test_base.h"],
|
||||
deps = [
|
||||
":manifest_checking_test",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
|
@ -273,6 +285,7 @@ cc_library(
|
|||
hdrs = ["local_client_test_base.h"],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":manifest_checking_test",
|
||||
":verified_hlo_module",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
|
|
|
@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []):
|
|||
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
|
||||
],
|
||||
deps = [
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/platform:logging",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/bitmap.h"
|
||||
|
@ -62,7 +63,7 @@ std::vector<TestCase> ExpandUseBfloat16(
|
|||
}
|
||||
|
||||
// A client library test establishes an in-process XLA client connection.
|
||||
class ClientLibraryTestBase : public ::testing::Test {
|
||||
class ClientLibraryTestBase : public ManifestCheckingTest {
|
||||
protected:
|
||||
explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
|
||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
@ -67,7 +68,7 @@ namespace xla {
|
|||
// )
|
||||
//
|
||||
// For a more detailed example, see "../tests/sample_text_test.cc".
|
||||
class HloTestBase : public ::testing::Test {
|
||||
class HloTestBase : public ManifestCheckingTest {
|
||||
public:
|
||||
// Creates a new HLO module for a test. The module created will have
|
||||
// TestName() for its name; it will also automatically populate its debug
|
||||
|
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
|
||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
|
|||
};
|
||||
|
||||
// A base class for tests which exercise the LocalClient interface.
|
||||
class LocalClientTestBase : public ::testing::Test {
|
||||
class LocalClientTestBase : public ManifestCheckingTest {
|
||||
protected:
|
||||
struct EigenThreadPoolWrapper;
|
||||
explicit LocalClientTestBase(se::Platform* platform = nullptr);
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
|
||||
// disabled - a sequence of regexps.
|
||||
using ManifestT = absl::flat_hash_map<std::string, std::vector<std::string>>;
|
||||
|
||||
ManifestT ReadManifest() {
|
||||
ManifestT manifest;
|
||||
|
||||
absl::string_view path = absl::NullSafeStringView(kDisabledManifestPath);
|
||||
if (path.empty()) {
|
||||
return manifest;
|
||||
}
|
||||
|
||||
// Note: parens are required to disambiguate vs function decl.
|
||||
std::ifstream file_stream((std::string(path)));
|
||||
std::string contents((std::istreambuf_iterator<char>(file_stream)),
|
||||
std::istreambuf_iterator<char>());
|
||||
|
||||
std::vector<std::string> lines = absl::StrSplit(contents, '\n');
|
||||
for (std::string& line : lines) {
|
||||
auto comment = line.find("//");
|
||||
if (comment != std::string::npos) {
|
||||
line = line.substr(0, comment);
|
||||
}
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
absl::StripTrailingAsciiWhitespace(&line);
|
||||
std::vector<std::string> pieces = absl::StrSplit(line, ' ');
|
||||
CHECK_GE(pieces.size(), 1);
|
||||
auto& platforms = manifest[pieces[0]];
|
||||
for (size_t i = 1; i < pieces.size(); ++i) {
|
||||
platforms.push_back(pieces[i]);
|
||||
}
|
||||
}
|
||||
return manifest;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ManifestCheckingTest::SetUp() {
|
||||
const testing::TestInfo* test_info =
|
||||
testing::UnitTest::GetInstance()->current_test_info();
|
||||
absl::string_view test_case_name = test_info->test_suite_name();
|
||||
absl::string_view test_name = test_info->name();
|
||||
VLOG(1) << "test_case_name: " << test_case_name;
|
||||
VLOG(1) << "test_name: " << test_name;
|
||||
|
||||
// Remove the type suffix from the test case name.
|
||||
if (const char* type_param = test_info->type_param()) {
|
||||
VLOG(1) << "type_param: " << type_param;
|
||||
size_t last_slash = test_case_name.rfind('/');
|
||||
test_case_name = test_case_name.substr(0, last_slash);
|
||||
VLOG(1) << "test_case_name: " << test_case_name;
|
||||
}
|
||||
|
||||
// Remove the test instantiation name if it is present.
|
||||
auto first_slash = test_case_name.find('/');
|
||||
if (first_slash != test_case_name.npos) {
|
||||
test_case_name.remove_prefix(first_slash + 1);
|
||||
VLOG(1) << "test_case_name: " << test_case_name;
|
||||
}
|
||||
|
||||
ManifestT manifest = ReadManifest();
|
||||
|
||||
// If the test name ends with a slash followed by one or more characters,
|
||||
// strip that off.
|
||||
auto last_slash = test_name.rfind('/');
|
||||
if (last_slash != test_name.npos) {
|
||||
test_name = test_name.substr(0, last_slash);
|
||||
VLOG(1) << "test_name: " << test_name;
|
||||
}
|
||||
|
||||
// First try full match: test_case_name.test_name
|
||||
// If that fails, try to find just the test_case_name; this would disable all
|
||||
// tests in the test case.
|
||||
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
|
||||
if (it == manifest.end()) {
|
||||
it = manifest.find(test_case_name);
|
||||
if (it == manifest.end()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Expect a full match vs. one of the platform regexps to disable the test.
|
||||
const std::vector<std::string>& disabled_platforms = it->second;
|
||||
auto platform_string = kTestPlatform;
|
||||
for (const auto& s : disabled_platforms) {
|
||||
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
|
||||
GTEST_SKIP();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// We didn't hit in the disabled manifest entries, so don't disable it.
|
||||
}
|
||||
|
||||
} // namespace xla
|
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// This class allows us to intercept the test name and use an arbitrary
|
||||
// heuristic to decide whether the test case should be disabled. We
|
||||
// determine whether the test case should be disabled by resolving the (test
|
||||
// case name, test name) in a manifest file.
|
||||
class ManifestCheckingTest : public ::testing::Test {
|
||||
protected:
|
||||
// This method runs before each test runs.
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
|
|
@ -15,93 +15,18 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <streambuf>
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
|
||||
// disabled - a sequence of regexps.
|
||||
using ManifestT = absl::flat_hash_map<string, std::vector<string>>;
|
||||
|
||||
ManifestT ReadManifest() {
|
||||
ManifestT manifest;
|
||||
|
||||
string path = XLA_DISABLED_MANIFEST;
|
||||
if (path.empty()) {
|
||||
return manifest;
|
||||
}
|
||||
|
||||
std::ifstream file_stream(path);
|
||||
// Note: parens are required to disambiguate vs function decl.
|
||||
string contents((std::istreambuf_iterator<char>(file_stream)),
|
||||
std::istreambuf_iterator<char>());
|
||||
|
||||
std::vector<string> lines = absl::StrSplit(contents, '\n');
|
||||
for (string& line : lines) {
|
||||
auto comment = line.find("//");
|
||||
if (comment != string::npos) {
|
||||
line = line.substr(0, comment);
|
||||
}
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
absl::StripTrailingAsciiWhitespace(&line);
|
||||
std::vector<string> pieces = absl::StrSplit(line, ' ');
|
||||
CHECK_GE(pieces.size(), 1);
|
||||
auto& platforms = manifest[pieces[0]];
|
||||
for (int64 i = 1; i < pieces.size(); ++i) {
|
||||
platforms.push_back(pieces[i]);
|
||||
}
|
||||
}
|
||||
return manifest;
|
||||
static bool InitModule() {
|
||||
kDisabledManifestPath = XLA_DISABLED_MANIFEST;
|
||||
VLOG(1) << "kDisabledManifestPath: " << kDisabledManifestPath;
|
||||
kTestPlatform = XLA_PLATFORM;
|
||||
VLOG(1) << "kTestPlatform: " << kTestPlatform;
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
|
||||
absl::string_view test_name) {
|
||||
ManifestT manifest = ReadManifest();
|
||||
|
||||
// If the test name ends with a slash followed by one or more digits, strip
|
||||
// that off; this is just a shard number, and matching on this would be
|
||||
// unstable even if someone wanted to do it.
|
||||
static LazyRE2 shard_num_pattern = {R"(/\d+$)"};
|
||||
absl::string_view suffix;
|
||||
if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) {
|
||||
test_name.remove_suffix(suffix.size());
|
||||
}
|
||||
|
||||
// First try full match: test_case_name.test_name
|
||||
// If that fails, try to find just the test_case_name; this would disable all
|
||||
// tests in the test case.
|
||||
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
|
||||
if (it == manifest.end()) {
|
||||
it = manifest.find(test_case_name);
|
||||
if (it == manifest.end()) {
|
||||
return std::string(test_name);
|
||||
}
|
||||
}
|
||||
|
||||
// Expect a full match vs. one of the platform regexps to disable the test.
|
||||
const std::vector<string>& disabled_platforms = it->second;
|
||||
string platform_string = XLA_PLATFORM;
|
||||
for (const auto& s : disabled_platforms) {
|
||||
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
|
||||
return absl::StrCat("DISABLED_", test_name);
|
||||
}
|
||||
}
|
||||
|
||||
// We didn't hit in the disabled manifest entries, so don't disable it.
|
||||
return std::string(test_name);
|
||||
}
|
||||
static bool module_initialized = InitModule();
|
||||
|
||||
} // namespace xla
|
||||
|
|
|
@ -28,12 +28,6 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#define DISABLED_ON_CPU(X) X
|
||||
#define DISABLED_ON_GPU(X) X
|
||||
#define DISABLED_ON_GPU_ROCM(X) X
|
||||
|
@ -79,117 +73,15 @@ limitations under the License.
|
|||
|
||||
namespace xla {
|
||||
|
||||
// Reads a disabled manifest file to resolve whether test cases should be
|
||||
// disabled on a particular platform. For a test that should be disabled,
|
||||
// returns DISABLED_ prepended to its name; otherwise returns the test name
|
||||
// unmodified.
|
||||
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
|
||||
absl::string_view test_name);
|
||||
inline const char *kDisabledManifestPath = nullptr;
|
||||
inline const char *kTestPlatform = nullptr;
|
||||
|
||||
} // namespace xla
|
||||
|
||||
// This is the internal "gtest" class instantiation -- it is identical to the
|
||||
// GTEST_TEST_ macro, except that we intercept the test name for potential
|
||||
// modification by PrependDisabledIfIndicated. That file can use an arbitrary
|
||||
// heuristic to decide whether the test case should be disabled, and we
|
||||
// determine whether the test case should be disabled by resolving the (test
|
||||
// case name, test name) in a manifest file.
|
||||
#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \
|
||||
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
|
||||
: public parent_class { \
|
||||
public: \
|
||||
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
|
||||
\
|
||||
private: \
|
||||
virtual void TestBody(); \
|
||||
static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
|
||||
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
|
||||
test_name)); \
|
||||
}; \
|
||||
\
|
||||
::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \
|
||||
test_name)::test_info_ = \
|
||||
::testing::RegisterTest( \
|
||||
#test_case_name, \
|
||||
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
|
||||
.c_str(), \
|
||||
nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \
|
||||
return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \
|
||||
}); \
|
||||
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
|
||||
#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
|
||||
|
||||
// This is identical to the TEST_F macro from "gtest", but it potentially
|
||||
// disables the test based on an external manifest file, DISABLED_MANIFEST.
|
||||
//
|
||||
// Per usual, you can see what tests are available via --gunit_list_tests and
|
||||
// choose to run tests that have been disabled via the manifest via
|
||||
// --gunit_also_run_disabled_tests.
|
||||
#define XLA_TEST_F(test_fixture, test_name) \
|
||||
XLA_GTEST_TEST_(test_fixture, test_name, test_fixture)
|
||||
#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
|
||||
|
||||
// Likewise, this is identical to the TEST_P macro from "gtest", but
|
||||
// potentially disables the test based on the DISABLED_MANIFEST file.
|
||||
//
|
||||
// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will
|
||||
// be properly expanded before the stringification occurs.
|
||||
#define XLA_TEST_P_IMPL_(test_case_name, test_name) \
|
||||
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
|
||||
: public test_case_name { \
|
||||
public: \
|
||||
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
|
||||
virtual void TestBody(); \
|
||||
\
|
||||
private: \
|
||||
static int AddToRegistry() { \
|
||||
::testing::UnitTest::GetInstance() \
|
||||
->parameterized_test_registry() \
|
||||
.GetTestCasePatternHolder<test_case_name>( \
|
||||
#test_case_name, \
|
||||
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
|
||||
->AddTestPattern( \
|
||||
#test_case_name, \
|
||||
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
|
||||
.c_str(), \
|
||||
new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
|
||||
test_case_name, test_name)>()); \
|
||||
return 0; \
|
||||
} \
|
||||
static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
|
||||
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
|
||||
test_name)); \
|
||||
}; \
|
||||
int GTEST_TEST_CLASS_NAME_(test_case_name, \
|
||||
test_name)::gtest_registering_dummy_ = \
|
||||
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \
|
||||
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
|
||||
|
||||
#define XLA_TEST_P(test_case_name, test_name) \
|
||||
XLA_TEST_P_IMPL_(test_case_name, test_name)
|
||||
|
||||
// This is identical to the TEST_F macro from "gtest", but it potentially
|
||||
// disables the test based on an external manifest file, DISABLED_MANIFEST.
|
||||
#define XLA_TYPED_TEST(CaseName, TestName) \
|
||||
template <typename gtest_TypeParam_> \
|
||||
class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
|
||||
: public CaseName<gtest_TypeParam_> { \
|
||||
private: \
|
||||
typedef CaseName<gtest_TypeParam_> TestFixture; \
|
||||
typedef gtest_TypeParam_ TypeParam; \
|
||||
virtual void TestBody(); \
|
||||
}; \
|
||||
bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
|
||||
::testing::internal::TypeParameterizedTest< \
|
||||
CaseName, \
|
||||
::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
|
||||
TestName)>, \
|
||||
GTEST_TYPE_PARAMS_(CaseName)>:: \
|
||||
Register( \
|
||||
"", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
|
||||
#CaseName, \
|
||||
::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
|
||||
0); \
|
||||
template <typename gtest_TypeParam_> \
|
||||
void GTEST_TEST_CLASS_NAME_(CaseName, \
|
||||
TestName)<gtest_TypeParam_>::TestBody()
|
||||
#define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName)
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
||||
|
|
Loading…
Reference in New Issue