[XLA] Introduce ManifestCheckingTest
PiperOrigin-RevId: 317229603 Change-Id: Ibcc9ea3895d520024f5d80d52330aeb3b970585d
This commit is contained in:
parent
13fe5862de
commit
4a14e778d6
|
@ -52,16 +52,26 @@ cc_library(
|
||||||
name = "test_macros_header",
|
name = "test_macros_header",
|
||||||
testonly = True,
|
testonly = True,
|
||||||
hdrs = ["test_macros.h"],
|
hdrs = ["test_macros.h"],
|
||||||
deps = [
|
|
||||||
"//tensorflow/compiler/xla:types",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate a test_macros_${BACKEND} library per backend with the proper copts.
|
# Generate a test_macros_${BACKEND} library per backend with the proper copts.
|
||||||
generate_backend_test_macros()
|
generate_backend_test_macros()
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "manifest_checking_test",
|
||||||
|
testonly = True,
|
||||||
|
srcs = ["manifest_checking_test.cc"],
|
||||||
|
hdrs = ["manifest_checking_test.h"],
|
||||||
|
deps = [
|
||||||
|
":test_macros_header",
|
||||||
|
"//tensorflow/core:regexp_internal",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core/platform:logging",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "test_utils",
|
name = "test_utils",
|
||||||
srcs = ["test_utils.cc"],
|
srcs = ["test_utils.cc"],
|
||||||
|
@ -136,6 +146,7 @@ cc_library(
|
||||||
hdrs = ["hlo_test_base.h"],
|
hdrs = ["hlo_test_base.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":literal_test_util",
|
":literal_test_util",
|
||||||
|
":manifest_checking_test",
|
||||||
":test_utils",
|
":test_utils",
|
||||||
":verified_hlo_module",
|
":verified_hlo_module",
|
||||||
"//tensorflow/compiler/xla:debug_options_flags",
|
"//tensorflow/compiler/xla:debug_options_flags",
|
||||||
|
@ -193,6 +204,7 @@ cc_library(
|
||||||
srcs = ["client_library_test_base.cc"],
|
srcs = ["client_library_test_base.cc"],
|
||||||
hdrs = ["client_library_test_base.h"],
|
hdrs = ["client_library_test_base.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":manifest_checking_test",
|
||||||
"//tensorflow/compiler/xla:array2d",
|
"//tensorflow/compiler/xla:array2d",
|
||||||
"//tensorflow/compiler/xla:array3d",
|
"//tensorflow/compiler/xla:array3d",
|
||||||
"//tensorflow/compiler/xla:array4d",
|
"//tensorflow/compiler/xla:array4d",
|
||||||
|
@ -273,6 +285,7 @@ cc_library(
|
||||||
hdrs = ["local_client_test_base.h"],
|
hdrs = ["local_client_test_base.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":client_library_test_base",
|
":client_library_test_base",
|
||||||
|
":manifest_checking_test",
|
||||||
":verified_hlo_module",
|
":verified_hlo_module",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
|
|
@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []):
|
||||||
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
|
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"//tensorflow/core/platform:logging",
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"//tensorflow/compiler/xla:types",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:regexp_internal",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
|
||||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/bitmap.h"
|
#include "tensorflow/core/lib/core/bitmap.h"
|
||||||
|
@ -62,7 +63,7 @@ std::vector<TestCase> ExpandUseBfloat16(
|
||||||
}
|
}
|
||||||
|
|
||||||
// A client library test establishes an in-process XLA client connection.
|
// A client library test establishes an in-process XLA client connection.
|
||||||
class ClientLibraryTestBase : public ::testing::Test {
|
class ClientLibraryTestBase : public ManifestCheckingTest {
|
||||||
protected:
|
protected:
|
||||||
explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
|
explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
|
||||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
@ -67,7 +68,7 @@ namespace xla {
|
||||||
// )
|
// )
|
||||||
//
|
//
|
||||||
// For a more detailed example, see "../tests/sample_text_test.cc".
|
// For a more detailed example, see "../tests/sample_text_test.cc".
|
||||||
class HloTestBase : public ::testing::Test {
|
class HloTestBase : public ManifestCheckingTest {
|
||||||
public:
|
public:
|
||||||
// Creates a new HLO module for a test. The module created will have
|
// Creates a new HLO module for a test. The module created will have
|
||||||
// TestName() for its name; it will also automatically populate its debug
|
// TestName() for its name; it will also automatically populate its debug
|
||||||
|
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
|
||||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
|
||||||
};
|
};
|
||||||
|
|
||||||
// A base class for tests which exercise the LocalClient interface.
|
// A base class for tests which exercise the LocalClient interface.
|
||||||
class LocalClientTestBase : public ::testing::Test {
|
class LocalClientTestBase : public ManifestCheckingTest {
|
||||||
protected:
|
protected:
|
||||||
struct EigenThreadPoolWrapper;
|
struct EigenThreadPoolWrapper;
|
||||||
explicit LocalClientTestBase(se::Platform* platform = nullptr);
|
explicit LocalClientTestBase(se::Platform* platform = nullptr);
|
||||||
|
|
|
@ -0,0 +1,129 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iterator>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/strings/ascii.h"
|
||||||
|
#include "absl/strings/str_split.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/regexp.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
|
||||||
|
// disabled - a sequence of regexps.
|
||||||
|
using ManifestT = absl::flat_hash_map<std::string, std::vector<std::string>>;
|
||||||
|
|
||||||
|
ManifestT ReadManifest() {
|
||||||
|
ManifestT manifest;
|
||||||
|
|
||||||
|
absl::string_view path = absl::NullSafeStringView(kDisabledManifestPath);
|
||||||
|
if (path.empty()) {
|
||||||
|
return manifest;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: parens are required to disambiguate vs function decl.
|
||||||
|
std::ifstream file_stream((std::string(path)));
|
||||||
|
std::string contents((std::istreambuf_iterator<char>(file_stream)),
|
||||||
|
std::istreambuf_iterator<char>());
|
||||||
|
|
||||||
|
std::vector<std::string> lines = absl::StrSplit(contents, '\n');
|
||||||
|
for (std::string& line : lines) {
|
||||||
|
auto comment = line.find("//");
|
||||||
|
if (comment != std::string::npos) {
|
||||||
|
line = line.substr(0, comment);
|
||||||
|
}
|
||||||
|
if (line.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
absl::StripTrailingAsciiWhitespace(&line);
|
||||||
|
std::vector<std::string> pieces = absl::StrSplit(line, ' ');
|
||||||
|
CHECK_GE(pieces.size(), 1);
|
||||||
|
auto& platforms = manifest[pieces[0]];
|
||||||
|
for (size_t i = 1; i < pieces.size(); ++i) {
|
||||||
|
platforms.push_back(pieces[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return manifest;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ManifestCheckingTest::SetUp() {
|
||||||
|
const testing::TestInfo* test_info =
|
||||||
|
testing::UnitTest::GetInstance()->current_test_info();
|
||||||
|
absl::string_view test_case_name = test_info->test_suite_name();
|
||||||
|
absl::string_view test_name = test_info->name();
|
||||||
|
VLOG(1) << "test_case_name: " << test_case_name;
|
||||||
|
VLOG(1) << "test_name: " << test_name;
|
||||||
|
|
||||||
|
// Remove the type suffix from the test case name.
|
||||||
|
if (const char* type_param = test_info->type_param()) {
|
||||||
|
VLOG(1) << "type_param: " << type_param;
|
||||||
|
size_t last_slash = test_case_name.rfind('/');
|
||||||
|
test_case_name = test_case_name.substr(0, last_slash);
|
||||||
|
VLOG(1) << "test_case_name: " << test_case_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the test instantiation name if it is present.
|
||||||
|
auto first_slash = test_case_name.find('/');
|
||||||
|
if (first_slash != test_case_name.npos) {
|
||||||
|
test_case_name.remove_prefix(first_slash + 1);
|
||||||
|
VLOG(1) << "test_case_name: " << test_case_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
ManifestT manifest = ReadManifest();
|
||||||
|
|
||||||
|
// If the test name ends with a slash followed by one or more characters,
|
||||||
|
// strip that off.
|
||||||
|
auto last_slash = test_name.rfind('/');
|
||||||
|
if (last_slash != test_name.npos) {
|
||||||
|
test_name = test_name.substr(0, last_slash);
|
||||||
|
VLOG(1) << "test_name: " << test_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
// First try full match: test_case_name.test_name
|
||||||
|
// If that fails, try to find just the test_case_name; this would disable all
|
||||||
|
// tests in the test case.
|
||||||
|
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
|
||||||
|
if (it == manifest.end()) {
|
||||||
|
it = manifest.find(test_case_name);
|
||||||
|
if (it == manifest.end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect a full match vs. one of the platform regexps to disable the test.
|
||||||
|
const std::vector<std::string>& disabled_platforms = it->second;
|
||||||
|
auto platform_string = kTestPlatform;
|
||||||
|
for (const auto& s : disabled_platforms) {
|
||||||
|
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
|
||||||
|
GTEST_SKIP();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We didn't hit in the disabled manifest entries, so don't disable it.
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// This class allows us to intercept the test name and use an arbitrary
|
||||||
|
// heuristic to decide whether the test case should be disabled. We
|
||||||
|
// determine whether the test case should be disabled by resolving the (test
|
||||||
|
// case name, test name) in a manifest file.
|
||||||
|
class ManifestCheckingTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
// This method runs before each test runs.
|
||||||
|
void SetUp() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
|
|
@ -15,93 +15,18 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
|
|
||||||
#include <fstream>
|
|
||||||
#include <streambuf>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/str_split.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/regexp.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
|
static bool InitModule() {
|
||||||
// disabled - a sequence of regexps.
|
kDisabledManifestPath = XLA_DISABLED_MANIFEST;
|
||||||
using ManifestT = absl::flat_hash_map<string, std::vector<string>>;
|
VLOG(1) << "kDisabledManifestPath: " << kDisabledManifestPath;
|
||||||
|
kTestPlatform = XLA_PLATFORM;
|
||||||
ManifestT ReadManifest() {
|
VLOG(1) << "kTestPlatform: " << kTestPlatform;
|
||||||
ManifestT manifest;
|
return false;
|
||||||
|
|
||||||
string path = XLA_DISABLED_MANIFEST;
|
|
||||||
if (path.empty()) {
|
|
||||||
return manifest;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ifstream file_stream(path);
|
static bool module_initialized = InitModule();
|
||||||
// Note: parens are required to disambiguate vs function decl.
|
|
||||||
string contents((std::istreambuf_iterator<char>(file_stream)),
|
|
||||||
std::istreambuf_iterator<char>());
|
|
||||||
|
|
||||||
std::vector<string> lines = absl::StrSplit(contents, '\n');
|
|
||||||
for (string& line : lines) {
|
|
||||||
auto comment = line.find("//");
|
|
||||||
if (comment != string::npos) {
|
|
||||||
line = line.substr(0, comment);
|
|
||||||
}
|
|
||||||
if (line.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
absl::StripTrailingAsciiWhitespace(&line);
|
|
||||||
std::vector<string> pieces = absl::StrSplit(line, ' ');
|
|
||||||
CHECK_GE(pieces.size(), 1);
|
|
||||||
auto& platforms = manifest[pieces[0]];
|
|
||||||
for (int64 i = 1; i < pieces.size(); ++i) {
|
|
||||||
platforms.push_back(pieces[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return manifest;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
|
|
||||||
absl::string_view test_name) {
|
|
||||||
ManifestT manifest = ReadManifest();
|
|
||||||
|
|
||||||
// If the test name ends with a slash followed by one or more digits, strip
|
|
||||||
// that off; this is just a shard number, and matching on this would be
|
|
||||||
// unstable even if someone wanted to do it.
|
|
||||||
static LazyRE2 shard_num_pattern = {R"(/\d+$)"};
|
|
||||||
absl::string_view suffix;
|
|
||||||
if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) {
|
|
||||||
test_name.remove_suffix(suffix.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
// First try full match: test_case_name.test_name
|
|
||||||
// If that fails, try to find just the test_case_name; this would disable all
|
|
||||||
// tests in the test case.
|
|
||||||
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
|
|
||||||
if (it == manifest.end()) {
|
|
||||||
it = manifest.find(test_case_name);
|
|
||||||
if (it == manifest.end()) {
|
|
||||||
return std::string(test_name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expect a full match vs. one of the platform regexps to disable the test.
|
|
||||||
const std::vector<string>& disabled_platforms = it->second;
|
|
||||||
string platform_string = XLA_PLATFORM;
|
|
||||||
for (const auto& s : disabled_platforms) {
|
|
||||||
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
|
|
||||||
return absl::StrCat("DISABLED_", test_name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We didn't hit in the disabled manifest entries, so don't disable it.
|
|
||||||
return std::string(test_name);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
@ -28,12 +28,6 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
|
|
||||||
#define DISABLED_ON_CPU(X) X
|
#define DISABLED_ON_CPU(X) X
|
||||||
#define DISABLED_ON_GPU(X) X
|
#define DISABLED_ON_GPU(X) X
|
||||||
#define DISABLED_ON_GPU_ROCM(X) X
|
#define DISABLED_ON_GPU_ROCM(X) X
|
||||||
|
@ -79,117 +73,15 @@ limitations under the License.
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// Reads a disabled manifest file to resolve whether test cases should be
|
inline const char *kDisabledManifestPath = nullptr;
|
||||||
// disabled on a particular platform. For a test that should be disabled,
|
inline const char *kTestPlatform = nullptr;
|
||||||
// returns DISABLED_ prepended to its name; otherwise returns the test name
|
|
||||||
// unmodified.
|
|
||||||
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
|
|
||||||
absl::string_view test_name);
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
// This is the internal "gtest" class instantiation -- it is identical to the
|
#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
|
||||||
// GTEST_TEST_ macro, except that we intercept the test name for potential
|
|
||||||
// modification by PrependDisabledIfIndicated. That file can use an arbitrary
|
|
||||||
// heuristic to decide whether the test case should be disabled, and we
|
|
||||||
// determine whether the test case should be disabled by resolving the (test
|
|
||||||
// case name, test name) in a manifest file.
|
|
||||||
#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \
|
|
||||||
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
|
|
||||||
: public parent_class { \
|
|
||||||
public: \
|
|
||||||
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
|
|
||||||
\
|
|
||||||
private: \
|
|
||||||
virtual void TestBody(); \
|
|
||||||
static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
|
|
||||||
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
|
|
||||||
test_name)); \
|
|
||||||
}; \
|
|
||||||
\
|
|
||||||
::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \
|
|
||||||
test_name)::test_info_ = \
|
|
||||||
::testing::RegisterTest( \
|
|
||||||
#test_case_name, \
|
|
||||||
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
|
|
||||||
.c_str(), \
|
|
||||||
nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \
|
|
||||||
return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \
|
|
||||||
}); \
|
|
||||||
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
|
|
||||||
|
|
||||||
// This is identical to the TEST_F macro from "gtest", but it potentially
|
#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
|
||||||
// disables the test based on an external manifest file, DISABLED_MANIFEST.
|
|
||||||
//
|
|
||||||
// Per usual, you can see what tests are available via --gunit_list_tests and
|
|
||||||
// choose to run tests that have been disabled via the manifest via
|
|
||||||
// --gunit_also_run_disabled_tests.
|
|
||||||
#define XLA_TEST_F(test_fixture, test_name) \
|
|
||||||
XLA_GTEST_TEST_(test_fixture, test_name, test_fixture)
|
|
||||||
|
|
||||||
// Likewise, this is identical to the TEST_P macro from "gtest", but
|
#define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName)
|
||||||
// potentially disables the test based on the DISABLED_MANIFEST file.
|
|
||||||
//
|
|
||||||
// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will
|
|
||||||
// be properly expanded before the stringification occurs.
|
|
||||||
#define XLA_TEST_P_IMPL_(test_case_name, test_name) \
|
|
||||||
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
|
|
||||||
: public test_case_name { \
|
|
||||||
public: \
|
|
||||||
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
|
|
||||||
virtual void TestBody(); \
|
|
||||||
\
|
|
||||||
private: \
|
|
||||||
static int AddToRegistry() { \
|
|
||||||
::testing::UnitTest::GetInstance() \
|
|
||||||
->parameterized_test_registry() \
|
|
||||||
.GetTestCasePatternHolder<test_case_name>( \
|
|
||||||
#test_case_name, \
|
|
||||||
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
|
|
||||||
->AddTestPattern( \
|
|
||||||
#test_case_name, \
|
|
||||||
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
|
|
||||||
.c_str(), \
|
|
||||||
new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
|
|
||||||
test_case_name, test_name)>()); \
|
|
||||||
return 0; \
|
|
||||||
} \
|
|
||||||
static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
|
|
||||||
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
|
|
||||||
test_name)); \
|
|
||||||
}; \
|
|
||||||
int GTEST_TEST_CLASS_NAME_(test_case_name, \
|
|
||||||
test_name)::gtest_registering_dummy_ = \
|
|
||||||
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \
|
|
||||||
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
|
|
||||||
|
|
||||||
#define XLA_TEST_P(test_case_name, test_name) \
|
|
||||||
XLA_TEST_P_IMPL_(test_case_name, test_name)
|
|
||||||
|
|
||||||
// This is identical to the TEST_F macro from "gtest", but it potentially
|
|
||||||
// disables the test based on an external manifest file, DISABLED_MANIFEST.
|
|
||||||
#define XLA_TYPED_TEST(CaseName, TestName) \
|
|
||||||
template <typename gtest_TypeParam_> \
|
|
||||||
class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
|
|
||||||
: public CaseName<gtest_TypeParam_> { \
|
|
||||||
private: \
|
|
||||||
typedef CaseName<gtest_TypeParam_> TestFixture; \
|
|
||||||
typedef gtest_TypeParam_ TypeParam; \
|
|
||||||
virtual void TestBody(); \
|
|
||||||
}; \
|
|
||||||
bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
|
|
||||||
::testing::internal::TypeParameterizedTest< \
|
|
||||||
CaseName, \
|
|
||||||
::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
|
|
||||||
TestName)>, \
|
|
||||||
GTEST_TYPE_PARAMS_(CaseName)>:: \
|
|
||||||
Register( \
|
|
||||||
"", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
|
|
||||||
#CaseName, \
|
|
||||||
::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
|
|
||||||
0); \
|
|
||||||
template <typename gtest_TypeParam_> \
|
|
||||||
void GTEST_TEST_CLASS_NAME_(CaseName, \
|
|
||||||
TestName)<gtest_TypeParam_>::TestBody()
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
|
||||||
|
|
Loading…
Reference in New Issue