Add tmp_dir

This commit is contained in:
Vo Van Nghia 2020-05-04 14:13:58 +07:00
parent 579bbe61a9
commit 43ff3c1e94
1 changed files with 33 additions and 17 deletions

View File

@ -85,21 +85,23 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
const std::string test_name = tensorflow::str_util::StringReplace(
::testing::UnitTest::GetInstance()->current_test_info()->name(), "/",
"_", /*replace_all=*/true);
// Since we need the tests for cloud filesystem to run on all OSs (Windows,
// MacOS, Linux, ...) The path to temp directory must not be dependent on
// the OS which runs the tests.
const std::string tmp_dir_ =
cloud_path_.empty() ? ::testing::TempDir() : "/tmp/";
root_dir_ = tensorflow::io::JoinPath(
tmp_dir_,
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
if (!cloud_path_.empty()) {
// We have to join path for non-local filesystem manually to make sure
// that this test will run on Windows since `tensorflow::io::JoinPath`
// behaves differently on Windows. `tmp_dir` should be something like
// `path/to/tmp/dir/`. After joining path, we will have
// /path/to/tmp/dir/tf_fs_rng_name/`
root_dir_ = tensorflow::strings::StrCat(
"/", tmp_dir_,
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name), "/");
} else {
root_dir_ = tensorflow::io::JoinPath(
tmp_dir_,
tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name));
}
if (!GetParam().empty()) {
if (!cloud_path_.empty()) {
root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", cloud_path_,
root_dir_, "/");
} else {
root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", root_dir_);
}
root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", cloud_path_,
root_dir_);
}
env_ = Env::Default();
}
@ -151,8 +153,13 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
rng_val_ = distribution(gen);
}
static void SetCloudPath(const std::string& cloud_path_) {
ModularFileSystemTest::cloud_path_ = cloud_path_;
static void SetCloudPath(const std::string& cloud_path) {
cloud_path_ = cloud_path;
if (cloud_path_.back() == '/') cloud_path_.pop_back();
}
static void SetTmpDir(const std::string& tmp_dir) {
tmp_dir_ = tmp_dir.empty() ? ::testing::TempDir() : tmp_dir;
}
protected:
@ -162,10 +169,12 @@ class ModularFileSystemTest : public ::testing::TestWithParam<std::string> {
std::string root_dir_;
static int rng_val_;
static std::string cloud_path_;
static std::string tmp_dir_;
};
int ModularFileSystemTest::rng_val_;
std::string ModularFileSystemTest::cloud_path_;
std::string ModularFileSystemTest::tmp_dir_;
// As some of the implementations might be missing, the tests should still pass
// if the returned `Status` signals the unimplemented state.
@ -1762,6 +1771,11 @@ static bool SetCloudPath(const std::string& cloud_path_) {
return true;
}
static bool SetTmpDir(const std::string& tmp_dir_) {
ModularFileSystemTest::SetTmpDir(tmp_dir_);
return true;
}
} // namespace
} // namespace tensorflow
@ -1777,7 +1791,9 @@ GTEST_API_ int main(int argc, char** argv) {
"URI scheme to test"),
tensorflow::Flag("cloud_path", tensorflow::SetCloudPath, "",
"Path for cloud filesystem (namenode for hdfs, "
"bucketname for s3/gcs)")};
"bucketname for s3/gcs)"),
tensorflow::Flag("tmp_dir", tensorflow::SetTmpDir, "",
"Temporary directory to store test data.")};
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
std::cout << tensorflow::Flags::Usage(argv[0], flag_list);
return -1;