STT-tensorflow/tensorflow/compiler/xla/tests/manifest_checking_test.cc
David Majnemer 728a4a4405 Unbreak the MSVC build
Don't use inline variables, our MSVC builds do not like them.

Instead, simulate them using inline functions + function-scope static variables.

PiperOrigin-RevId: 317428053
Change-Id: Icf8838c159ab9f132ad32360633046f4c2224a79
2020-06-19 21:15:11 -07:00

130 lines
4.2 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include <fstream>
#include <iterator>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace {
// Maps a test name to the platforms on which it is disabled. The key is either
// "MyTest.MyTestCase" (a single test) or "MyTest" (a whole test suite). The
// value is a list of regular expressions to match against the platform name.
using ManifestT = absl::flat_hash_map<std::string, std::vector<std::string>>;

// Parses the disabled-tests manifest whose path is *DisabledManifestPath().
//
// Each line has the form
//   <test name> [<platform regexp> ...]
// Fields are separated by runs of spaces or tabs. Everything from "//" to the
// end of a line is a comment. Blank and comment-only lines are ignored. When
// several lines name the same test, their regexps are accumulated.
//
// Returns an empty manifest in two cases: no path is configured, or the file
// cannot be opened (this case is logged).
ManifestT ReadManifest() {
  ManifestT manifest;
  absl::string_view path = absl::NullSafeStringView(*DisabledManifestPath());
  if (path.empty()) {
    return manifest;
  }
  // Note: parens are required to disambiguate vs function decl.
  std::ifstream file_stream((std::string(path)));
  if (!file_stream) {
    // Don't fail silently: a mistyped path would otherwise look exactly like
    // "nothing is disabled".
    LOG(WARNING) << "Unable to open disabled manifest file: " << path;
    return manifest;
  }
  std::string line;
  while (std::getline(file_stream, line)) {
    // Strip a trailing "//" comment.
    const std::string::size_type comment = line.find("//");
    if (comment != std::string::npos) {
      line.erase(comment);
    }
    // Tolerate leading, trailing and repeated whitespace, tabs, and CRLF line
    // endings. None of these may produce an empty test name or an empty regexp.
    // An empty regexp would fully match an empty platform name.
    std::vector<std::string> pieces =
        absl::StrSplit(line, absl::ByAnyChar(" \t\r"), absl::SkipEmpty());
    if (pieces.empty()) {
      continue;  // Blank or comment-only line.
    }
    std::vector<std::string>& platforms = manifest[pieces.front()];
    platforms.insert(platforms.end(), pieces.begin() + 1, pieces.end());
  }
  return manifest;
}
} // namespace
void ManifestCheckingTest::SetUp() {
const testing::TestInfo* test_info =
testing::UnitTest::GetInstance()->current_test_info();
absl::string_view test_case_name = test_info->test_suite_name();
absl::string_view test_name = test_info->name();
VLOG(1) << "test_case_name: " << test_case_name;
VLOG(1) << "test_name: " << test_name;
// Remove the type suffix from the test case name.
if (const char* type_param = test_info->type_param()) {
VLOG(1) << "type_param: " << type_param;
size_t last_slash = test_case_name.rfind('/');
test_case_name = test_case_name.substr(0, last_slash);
VLOG(1) << "test_case_name: " << test_case_name;
}
// Remove the test instantiation name if it is present.
auto first_slash = test_case_name.find('/');
if (first_slash != test_case_name.npos) {
test_case_name.remove_prefix(first_slash + 1);
VLOG(1) << "test_case_name: " << test_case_name;
}
ManifestT manifest = ReadManifest();
// If the test name ends with a slash followed by one or more characters,
// strip that off.
auto last_slash = test_name.rfind('/');
if (last_slash != test_name.npos) {
test_name = test_name.substr(0, last_slash);
VLOG(1) << "test_name: " << test_name;
}
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
return;
}
}
// Expect a full match vs. one of the platform regexps to disable the test.
const std::vector<std::string>& disabled_platforms = it->second;
auto platform_string = *TestPlatform();
for (const auto& s : disabled_platforms) {
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
GTEST_SKIP();
return;
}
}
// We didn't hit in the disabled manifest entries, so don't disable it.
}
} // namespace xla