[XLA] Introduce ManifestCheckingTest

PiperOrigin-RevId: 317229603
Change-Id: Ibcc9ea3895d520024f5d80d52330aeb3b970585d
This commit is contained in:
David Majnemer 2020-06-18 19:11:40 -07:00 committed by TensorFlower Gardener
parent 13fe5862de
commit 4a14e778d6
9 changed files with 201 additions and 209 deletions

View File

@ -52,16 +52,26 @@ cc_library(
name = "test_macros_header", name = "test_macros_header",
testonly = True, testonly = True,
hdrs = ["test_macros.h"], hdrs = ["test_macros.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
],
) )
# Generate a test_macros_${BACKEND} library per backend with the proper copts. # Generate a test_macros_${BACKEND} library per backend with the proper copts.
generate_backend_test_macros() generate_backend_test_macros()
cc_library(
name = "manifest_checking_test",
testonly = True,
srcs = ["manifest_checking_test.cc"],
hdrs = ["manifest_checking_test.h"],
deps = [
":test_macros_header",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "test_utils", name = "test_utils",
srcs = ["test_utils.cc"], srcs = ["test_utils.cc"],
@ -136,6 +146,7 @@ cc_library(
hdrs = ["hlo_test_base.h"], hdrs = ["hlo_test_base.h"],
deps = [ deps = [
":literal_test_util", ":literal_test_util",
":manifest_checking_test",
":test_utils", ":test_utils",
":verified_hlo_module", ":verified_hlo_module",
"//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:debug_options_flags",
@ -193,6 +204,7 @@ cc_library(
srcs = ["client_library_test_base.cc"], srcs = ["client_library_test_base.cc"],
hdrs = ["client_library_test_base.h"], hdrs = ["client_library_test_base.h"],
deps = [ deps = [
":manifest_checking_test",
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
@ -273,6 +285,7 @@ cc_library(
hdrs = ["local_client_test_base.h"], hdrs = ["local_client_test_base.h"],
deps = [ deps = [
":client_library_test_base", ":client_library_test_base",
":manifest_checking_test",
":verified_hlo_module", ":verified_hlo_module",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",

View File

@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []):
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
], ],
deps = [ deps = [
"@com_google_absl//absl/container:flat_hash_map", "//tensorflow/core/platform:logging",
"@com_google_absl//absl/strings",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
], ],
) )

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/bitmap.h"
@ -62,7 +63,7 @@ std::vector<TestCase> ExpandUseBfloat16(
} }
// A client library test establishes an in-process XLA client connection. // A client library test establishes an in-process XLA client connection.
class ClientLibraryTestBase : public ::testing::Test { class ClientLibraryTestBase : public ManifestCheckingTest {
protected: protected:
explicit ClientLibraryTestBase(se::Platform* platform = nullptr); explicit ClientLibraryTestBase(se::Platform* platform = nullptr);

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
@ -67,7 +68,7 @@ namespace xla {
// ) // )
// //
// For a more detailed example, see "../tests/sample_text_test.cc". // For a more detailed example, see "../tests/sample_text_test.cc".
class HloTestBase : public ::testing::Test { class HloTestBase : public ManifestCheckingTest {
public: public:
// Creates a new HLO module for a test. The module created will have // Creates a new HLO module for a test. The module created will have
// TestName() for its name; it will also automatically populate its debug // TestName() for its name; it will also automatically populate its debug

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
}; };
// A base class for tests which exercise the LocalClient interface. // A base class for tests which exercise the LocalClient interface.
class LocalClientTestBase : public ::testing::Test { class LocalClientTestBase : public ManifestCheckingTest {
protected: protected:
struct EigenThreadPoolWrapper; struct EigenThreadPoolWrapper;
explicit LocalClientTestBase(se::Platform* platform = nullptr); explicit LocalClientTestBase(se::Platform* platform = nullptr);

View File

@ -0,0 +1,129 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include <fstream>
#include <iterator>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace {
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
// disabled - a sequence of regexps.
using ManifestT = absl::flat_hash_map<std::string, std::vector<std::string>>;
ManifestT ReadManifest() {
ManifestT manifest;
absl::string_view path = absl::NullSafeStringView(kDisabledManifestPath);
if (path.empty()) {
return manifest;
}
// Note: parens are required to disambiguate vs function decl.
std::ifstream file_stream((std::string(path)));
std::string contents((std::istreambuf_iterator<char>(file_stream)),
std::istreambuf_iterator<char>());
std::vector<std::string> lines = absl::StrSplit(contents, '\n');
for (std::string& line : lines) {
auto comment = line.find("//");
if (comment != std::string::npos) {
line = line.substr(0, comment);
}
if (line.empty()) {
continue;
}
absl::StripTrailingAsciiWhitespace(&line);
std::vector<std::string> pieces = absl::StrSplit(line, ' ');
CHECK_GE(pieces.size(), 1);
auto& platforms = manifest[pieces[0]];
for (size_t i = 1; i < pieces.size(); ++i) {
platforms.push_back(pieces[i]);
}
}
return manifest;
}
} // namespace
void ManifestCheckingTest::SetUp() {
const testing::TestInfo* test_info =
testing::UnitTest::GetInstance()->current_test_info();
absl::string_view test_case_name = test_info->test_suite_name();
absl::string_view test_name = test_info->name();
VLOG(1) << "test_case_name: " << test_case_name;
VLOG(1) << "test_name: " << test_name;
// Remove the type suffix from the test case name.
if (const char* type_param = test_info->type_param()) {
VLOG(1) << "type_param: " << type_param;
size_t last_slash = test_case_name.rfind('/');
test_case_name = test_case_name.substr(0, last_slash);
VLOG(1) << "test_case_name: " << test_case_name;
}
// Remove the test instantiation name if it is present.
auto first_slash = test_case_name.find('/');
if (first_slash != test_case_name.npos) {
test_case_name.remove_prefix(first_slash + 1);
VLOG(1) << "test_case_name: " << test_case_name;
}
ManifestT manifest = ReadManifest();
// If the test name ends with a slash followed by one or more characters,
// strip that off.
auto last_slash = test_name.rfind('/');
if (last_slash != test_name.npos) {
test_name = test_name.substr(0, last_slash);
VLOG(1) << "test_name: " << test_name;
}
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
return;
}
}
// Expect a full match vs. one of the platform regexps to disable the test.
const std::vector<std::string>& disabled_platforms = it->second;
auto platform_string = kTestPlatform;
for (const auto& s : disabled_platforms) {
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
GTEST_SKIP();
return;
}
}
// We didn't hit in the disabled manifest entries, so don't disable it.
}
} // namespace xla

View File

@ -0,0 +1,35 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
#include "tensorflow/core/platform/test.h"
namespace xla {
// This class allows us to intercept the test name and use an arbitrary
// heuristic to decide whether the test case should be disabled. We
// determine whether the test case should be disabled by resolving the (test
// case name, test name) in a manifest file.
class ManifestCheckingTest : public ::testing::Test {
protected:
// This method runs before each test runs.
void SetUp() override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_

View File

@ -15,93 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_macros.h"
#include <fstream>
#include <streambuf>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla { namespace xla {
namespace {
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is static bool InitModule() {
// disabled - a sequence of regexps. kDisabledManifestPath = XLA_DISABLED_MANIFEST;
using ManifestT = absl::flat_hash_map<string, std::vector<string>>; VLOG(1) << "kDisabledManifestPath: " << kDisabledManifestPath;
kTestPlatform = XLA_PLATFORM;
ManifestT ReadManifest() { VLOG(1) << "kTestPlatform: " << kTestPlatform;
ManifestT manifest; return false;
string path = XLA_DISABLED_MANIFEST;
if (path.empty()) {
return manifest;
} }
std::ifstream file_stream(path); static bool module_initialized = InitModule();
// Note: parens are required to disambiguate vs function decl.
string contents((std::istreambuf_iterator<char>(file_stream)),
std::istreambuf_iterator<char>());
std::vector<string> lines = absl::StrSplit(contents, '\n');
for (string& line : lines) {
auto comment = line.find("//");
if (comment != string::npos) {
line = line.substr(0, comment);
}
if (line.empty()) {
continue;
}
absl::StripTrailingAsciiWhitespace(&line);
std::vector<string> pieces = absl::StrSplit(line, ' ');
CHECK_GE(pieces.size(), 1);
auto& platforms = manifest[pieces[0]];
for (int64 i = 1; i < pieces.size(); ++i) {
platforms.push_back(pieces[i]);
}
}
return manifest;
}
} // namespace
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
absl::string_view test_name) {
ManifestT manifest = ReadManifest();
// If the test name ends with a slash followed by one or more digits, strip
// that off; this is just a shard number, and matching on this would be
// unstable even if someone wanted to do it.
static LazyRE2 shard_num_pattern = {R"(/\d+$)"};
absl::string_view suffix;
if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) {
test_name.remove_suffix(suffix.size());
}
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
return std::string(test_name);
}
}
// Expect a full match vs. one of the platform regexps to disable the test.
const std::vector<string>& disabled_platforms = it->second;
string platform_string = XLA_PLATFORM;
for (const auto& s : disabled_platforms) {
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
return absl::StrCat("DISABLED_", test_name);
}
}
// We didn't hit in the disabled manifest entries, so don't disable it.
return std::string(test_name);
}
} // namespace xla } // namespace xla

View File

@ -28,12 +28,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ #ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ #define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"
#define DISABLED_ON_CPU(X) X #define DISABLED_ON_CPU(X) X
#define DISABLED_ON_GPU(X) X #define DISABLED_ON_GPU(X) X
#define DISABLED_ON_GPU_ROCM(X) X #define DISABLED_ON_GPU_ROCM(X) X
@ -79,117 +73,15 @@ limitations under the License.
namespace xla { namespace xla {
// Reads a disabled manifest file to resolve whether test cases should be inline const char *kDisabledManifestPath = nullptr;
// disabled on a particular platform. For a test that should be disabled, inline const char *kTestPlatform = nullptr;
// returns DISABLED_ prepended to its name; otherwise returns the test name
// unmodified.
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
absl::string_view test_name);
} // namespace xla } // namespace xla
// This is the internal "gtest" class instantiation -- it is identical to the #define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
// GTEST_TEST_ macro, except that we intercept the test name for potential
// modification by PrependDisabledIfIndicated. That file can use an arbitrary
// heuristic to decide whether the test case should be disabled, and we
// determine whether the test case should be disabled by resolving the (test
// case name, test name) in a manifest file.
#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
: public parent_class { \
public: \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
\
private: \
virtual void TestBody(); \
static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)); \
}; \
\
::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)::test_info_ = \
::testing::RegisterTest( \
#test_case_name, \
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
.c_str(), \
nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \
return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \
}); \
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
// This is identical to the TEST_F macro from "gtest", but it potentially #define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
// disables the test based on an external manifest file, DISABLED_MANIFEST.
//
// Per usual, you can see what tests are available via --gunit_list_tests and
// choose to run tests that have been disabled via the manifest via
// --gunit_also_run_disabled_tests.
#define XLA_TEST_F(test_fixture, test_name) \
XLA_GTEST_TEST_(test_fixture, test_name, test_fixture)
// Likewise, this is identical to the TEST_P macro from "gtest", but #define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName)
// potentially disables the test based on the DISABLED_MANIFEST file.
//
// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will
// be properly expanded before the stringification occurs.
#define XLA_TEST_P_IMPL_(test_case_name, test_name) \
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
: public test_case_name { \
public: \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
virtual void TestBody(); \
\
private: \
static int AddToRegistry() { \
::testing::UnitTest::GetInstance() \
->parameterized_test_registry() \
.GetTestCasePatternHolder<test_case_name>( \
#test_case_name, \
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
->AddTestPattern( \
#test_case_name, \
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
.c_str(), \
new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
test_case_name, test_name)>()); \
return 0; \
} \
static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)); \
}; \
int GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)::gtest_registering_dummy_ = \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
#define XLA_TEST_P(test_case_name, test_name) \
XLA_TEST_P_IMPL_(test_case_name, test_name)
// This is identical to the TEST_F macro from "gtest", but it potentially
// disables the test based on an external manifest file, DISABLED_MANIFEST.
#define XLA_TYPED_TEST(CaseName, TestName) \
template <typename gtest_TypeParam_> \
class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
: public CaseName<gtest_TypeParam_> { \
private: \
typedef CaseName<gtest_TypeParam_> TestFixture; \
typedef gtest_TypeParam_ TypeParam; \
virtual void TestBody(); \
}; \
bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
::testing::internal::TypeParameterizedTest< \
CaseName, \
::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
TestName)>, \
GTEST_TYPE_PARAMS_(CaseName)>:: \
Register( \
"", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
#CaseName, \
::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
0); \
template <typename gtest_TypeParam_> \
void GTEST_TEST_CLASS_NAME_(CaseName, \
TestName)<gtest_TypeParam_>::TestBody()
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_