diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e1863a8a4cf..9b36117602b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -52,16 +52,26 @@ cc_library( name = "test_macros_header", testonly = True, hdrs = ["test_macros.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:test", - "@com_google_absl//absl/strings", - ], ) # Generate a test_macros_${BACKEND} library per backend with the proper copts. generate_backend_test_macros() +cc_library( + name = "manifest_checking_test", + testonly = True, + srcs = ["manifest_checking_test.cc"], + hdrs = ["manifest_checking_test.h"], + deps = [ + ":test_macros_header", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + "//tensorflow/core/platform:logging", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "test_utils", srcs = ["test_utils.cc"], @@ -136,6 +146,7 @@ cc_library( hdrs = ["hlo_test_base.h"], deps = [ ":literal_test_util", + ":manifest_checking_test", ":test_utils", ":verified_hlo_module", "//tensorflow/compiler/xla:debug_options_flags", @@ -193,6 +204,7 @@ cc_library( srcs = ["client_library_test_base.cc"], hdrs = ["client_library_test_base.h"], deps = [ + ":manifest_checking_test", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", @@ -273,6 +285,7 @@ cc_library( hdrs = ["local_client_test_base.h"], deps = [ ":client_library_test_base", + ":manifest_checking_test", ":verified_hlo_module", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index c0c0751b0de..94d870aa2ef 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []): "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, ], deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:test", + "//tensorflow/core/platform:logging", ], ) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 790497f888e..17bb70bdb42 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/manifest_checking_test.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" @@ -62,7 +63,7 @@ std::vector ExpandUseBfloat16( } // A client library test establishes an in-process XLA client connection. -class ClientLibraryTestBase : public ::testing::Test { +class ClientLibraryTestBase : public ManifestCheckingTest { protected: explicit ClientLibraryTestBase(se::Platform* platform = nullptr); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 85b1876dd3c..17c2a55ba5b 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/manifest_checking_test.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -67,7 +68,7 @@ namespace xla { // ) // // For a more detailed example, see "../tests/sample_text_test.cc". -class HloTestBase : public ::testing::Test { +class HloTestBase : public ManifestCheckingTest { public: // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index ea457024618..c1951ad1021 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/manifest_checking_test.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/mutex.h" @@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator { }; // A base class for tests which exercise the LocalClient interface. -class LocalClientTestBase : public ::testing::Test { +class LocalClientTestBase : public ManifestCheckingTest { protected: struct EigenThreadPoolWrapper; explicit LocalClientTestBase(se::Platform* platform = nullptr); diff --git a/tensorflow/compiler/xla/tests/manifest_checking_test.cc b/tensorflow/compiler/xla/tests/manifest_checking_test.cc new file mode 100644 index 00000000000..8806290472d --- /dev/null +++ b/tensorflow/compiler/xla/tests/manifest_checking_test.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/manifest_checking_test.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { + +namespace { + +// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is +// disabled - a sequence of regexps. +using ManifestT = absl::flat_hash_map>; + +ManifestT ReadManifest() { + ManifestT manifest; + + absl::string_view path = absl::NullSafeStringView(kDisabledManifestPath); + if (path.empty()) { + return manifest; + } + + // Note: parens are required to disambiguate vs function decl. + std::ifstream file_stream((std::string(path))); + std::string contents((std::istreambuf_iterator(file_stream)), + std::istreambuf_iterator()); + + std::vector lines = absl::StrSplit(contents, '\n'); + for (std::string& line : lines) { + auto comment = line.find("//"); + if (comment != std::string::npos) { + line = line.substr(0, comment); + } + if (line.empty()) { + continue; + } + absl::StripTrailingAsciiWhitespace(&line); + std::vector pieces = absl::StrSplit(line, ' '); + CHECK_GE(pieces.size(), 1); + auto& platforms = manifest[pieces[0]]; + for (size_t i = 1; i < pieces.size(); ++i) { + platforms.push_back(pieces[i]); + } + } + return manifest; +} + +} // namespace + +void ManifestCheckingTest::SetUp() { + const testing::TestInfo* test_info = + testing::UnitTest::GetInstance()->current_test_info(); + absl::string_view test_case_name = test_info->test_suite_name(); + absl::string_view test_name = test_info->name(); + VLOG(1) << "test_case_name: " << test_case_name; + VLOG(1) << "test_name: " << test_name; + + // Remove the type suffix from the test case name. + if (const char* type_param = test_info->type_param()) { + VLOG(1) << "type_param: " << type_param; + size_t last_slash = test_case_name.rfind('/'); + test_case_name = test_case_name.substr(0, last_slash); + VLOG(1) << "test_case_name: " << test_case_name; + } + + // Remove the test instantiation name if it is present. + auto first_slash = test_case_name.find('/'); + if (first_slash != test_case_name.npos) { + test_case_name.remove_prefix(first_slash + 1); + VLOG(1) << "test_case_name: " << test_case_name; + } + + ManifestT manifest = ReadManifest(); + + // If the test name ends with a slash followed by one or more characters, + // strip that off. + auto last_slash = test_name.rfind('/'); + if (last_slash != test_name.npos) { + test_name = test_name.substr(0, last_slash); + VLOG(1) << "test_name: " << test_name; + } + + // First try full match: test_case_name.test_name + // If that fails, try to find just the test_case_name; this would disable all + // tests in the test case. + auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name)); + if (it == manifest.end()) { + it = manifest.find(test_case_name); + if (it == manifest.end()) { + return; + } + } + + // Expect a full match vs. one of the platform regexps to disable the test. + const std::vector& disabled_platforms = it->second; + auto platform_string = kTestPlatform; + for (const auto& s : disabled_platforms) { + if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { + GTEST_SKIP(); + return; + } + } + + // We didn't hit in the disabled manifest entries, so don't disable it. +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/manifest_checking_test.h b/tensorflow/compiler/xla/tests/manifest_checking_test.h new file mode 100644 index 00000000000..4f44ed76a3e --- /dev/null +++ b/tensorflow/compiler/xla/tests/manifest_checking_test.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_ + +#include "tensorflow/core/platform/test.h" + +namespace xla { + +// This class allows us to intercept the test name and use an arbitrary +// heuristic to decide whether the test case should be disabled. We +// determine whether the test case should be disabled by resolving the (test +// case name, test name) in a manifest file. +class ManifestCheckingTest : public ::testing::Test { + protected: + // This method runs before each test runs. + void SetUp() override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_ diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index dc9ac7b684a..9e85af76e89 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -15,93 +15,18 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/regexp.h" namespace xla { -namespace { -// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is -// disabled - a sequence of regexps. -using ManifestT = absl::flat_hash_map>; - -ManifestT ReadManifest() { - ManifestT manifest; - - string path = XLA_DISABLED_MANIFEST; - if (path.empty()) { - return manifest; - } - - std::ifstream file_stream(path); - // Note: parens are required to disambiguate vs function decl. - string contents((std::istreambuf_iterator(file_stream)), - std::istreambuf_iterator()); - - std::vector lines = absl::StrSplit(contents, '\n'); - for (string& line : lines) { - auto comment = line.find("//"); - if (comment != string::npos) { - line = line.substr(0, comment); - } - if (line.empty()) { - continue; - } - absl::StripTrailingAsciiWhitespace(&line); - std::vector pieces = absl::StrSplit(line, ' '); - CHECK_GE(pieces.size(), 1); - auto& platforms = manifest[pieces[0]]; - for (int64 i = 1; i < pieces.size(); ++i) { - platforms.push_back(pieces[i]); - } - } - return manifest; +static bool InitModule() { + kDisabledManifestPath = XLA_DISABLED_MANIFEST; + VLOG(1) << "kDisabledManifestPath: " << kDisabledManifestPath; + kTestPlatform = XLA_PLATFORM; + VLOG(1) << "kTestPlatform: " << kTestPlatform; + return false; } -} // namespace - -std::string PrependDisabledIfIndicated(absl::string_view test_case_name, - absl::string_view test_name) { - ManifestT manifest = ReadManifest(); - - // If the test name ends with a slash followed by one or more digits, strip - // that off; this is just a shard number, and matching on this would be - // unstable even if someone wanted to do it. - static LazyRE2 shard_num_pattern = {R"(/\d+$)"}; - absl::string_view suffix; - if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) { - test_name.remove_suffix(suffix.size()); - } - - // First try full match: test_case_name.test_name - // If that fails, try to find just the test_case_name; this would disable all - // tests in the test case. - auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name)); - if (it == manifest.end()) { - it = manifest.find(test_case_name); - if (it == manifest.end()) { - return std::string(test_name); - } - } - - // Expect a full match vs. one of the platform regexps to disable the test. - const std::vector& disabled_platforms = it->second; - string platform_string = XLA_PLATFORM; - for (const auto& s : disabled_platforms) { - if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { - return absl::StrCat("DISABLED_", test_name); - } - } - - // We didn't hit in the disabled manifest entries, so don't disable it. - return std::string(test_name); -} +static bool module_initialized = InitModule(); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 33d2dff9721..f62bccbe850 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -28,12 +28,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ #define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/test.h" - #define DISABLED_ON_CPU(X) X #define DISABLED_ON_GPU(X) X #define DISABLED_ON_GPU_ROCM(X) X @@ -79,117 +73,15 @@ limitations under the License. namespace xla { -// Reads a disabled manifest file to resolve whether test cases should be -// disabled on a particular platform. For a test that should be disabled, -// returns DISABLED_ prepended to its name; otherwise returns the test name -// unmodified. -std::string PrependDisabledIfIndicated(absl::string_view test_case_name, - absl::string_view test_name); +inline const char *kDisabledManifestPath = nullptr; +inline const char *kTestPlatform = nullptr; } // namespace xla -// This is the internal "gtest" class instantiation -- it is identical to the -// GTEST_TEST_ macro, except that we intercept the test name for potential -// modification by PrependDisabledIfIndicated. That file can use an arbitrary -// heuristic to decide whether the test case should be disabled, and we -// determine whether the test case should be disabled by resolving the (test -// case name, test name) in a manifest file. -#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \ - class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ - : public parent_class { \ - public: \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ - \ - private: \ - virtual void TestBody(); \ - static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ - GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)); \ - }; \ - \ - ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)::test_info_ = \ - ::testing::RegisterTest( \ - #test_case_name, \ - ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ - .c_str(), \ - nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \ - return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \ - }); \ - void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() +#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name) -// This is identical to the TEST_F macro from "gtest", but it potentially -// disables the test based on an external manifest file, DISABLED_MANIFEST. -// -// Per usual, you can see what tests are available via --gunit_list_tests and -// choose to run tests that have been disabled via the manifest via -// --gunit_also_run_disabled_tests. -#define XLA_TEST_F(test_fixture, test_name) \ - XLA_GTEST_TEST_(test_fixture, test_name, test_fixture) +#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name) -// Likewise, this is identical to the TEST_P macro from "gtest", but -// potentially disables the test based on the DISABLED_MANIFEST file. -// -// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will -// be properly expanded before the stringification occurs. -#define XLA_TEST_P_IMPL_(test_case_name, test_name) \ - class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ - : public test_case_name { \ - public: \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ - virtual void TestBody(); \ - \ - private: \ - static int AddToRegistry() { \ - ::testing::UnitTest::GetInstance() \ - ->parameterized_test_registry() \ - .GetTestCasePatternHolder( \ - #test_case_name, \ - ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ - ->AddTestPattern( \ - #test_case_name, \ - ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ - .c_str(), \ - new ::testing::internal::TestMetaFactory()); \ - return 0; \ - } \ - static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \ - GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)); \ - }; \ - int GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)::gtest_registering_dummy_ = \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \ - void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() - -#define XLA_TEST_P(test_case_name, test_name) \ - XLA_TEST_P_IMPL_(test_case_name, test_name) - -// This is identical to the TEST_F macro from "gtest", but it potentially -// disables the test based on an external manifest file, DISABLED_MANIFEST. -#define XLA_TYPED_TEST(CaseName, TestName) \ - template \ - class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ - : public CaseName { \ - private: \ - typedef CaseName TestFixture; \ - typedef gtest_TypeParam_ TypeParam; \ - virtual void TestBody(); \ - }; \ - bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \ - ::testing::internal::TypeParameterizedTest< \ - CaseName, \ - ::testing::internal::TemplateSel, \ - GTEST_TYPE_PARAMS_(CaseName)>:: \ - Register( \ - "", ::testing::internal::CodeLocation(__FILE__, __LINE__), \ - #CaseName, \ - ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \ - 0); \ - template \ - void GTEST_TEST_CLASS_NAME_(CaseName, \ - TestName)::TestBody() +#define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName) #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_