Merge pull request #40710 from vnvo2409:gcs-random-access

PiperOrigin-RevId: 318880431
Change-Id: If7665e5407de8745bf3b61b32f3bdeca3c913c7e
This commit is contained in:
TensorFlower Gardener 2020-06-29 13:28:19 -07:00
commit 0a36436b9f
4 changed files with 190 additions and 4 deletions

View File

@ -30,6 +30,7 @@ cc_library(
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)
@ -46,7 +47,6 @@ cc_library(
tf_cc_test(
name = "gcs_filesystem_test",
srcs = [
"gcs_filesystem.cc",
"gcs_filesystem_test.cc",
],
tags = [
@ -58,5 +58,6 @@ tf_cc_test(
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
"@com_google_absl//absl/strings",
],
)

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include "absl/strings/numbers.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
@ -75,8 +76,42 @@ void ParseGCSPath(const std::string& fname, bool object_empty_ok,
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct GCSFile {
const std::string bucket;
const std::string object;
gcs::Client* gcs_client; // not owned
} GCSFile;
// TODO(vnvo2409): Implement later
void Cleanup(TF_RandomAccessFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// TODO(vnvo2409): Adding cache.
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
auto stream = gcs_file->gcs_client->ReadObject(
gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
TF_SetStatusFromGCSStatus(stream.status(), status);
if ((TF_GetCode(status) != TF_OK) &&
(TF_GetCode(status) != TF_OUT_OF_RANGE)) {
return -1;
}
int64_t read;
if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
&read)) {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
if (read != n) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
}
stream.read(buffer, read);
return read;
}
} // namespace tf_random_access_file
@ -251,6 +286,17 @@ void Cleanup(TF_Filesystem* filesystem) {
}
// TODO(vnvo2409): Implement later
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
file->plugin_file = new tf_random_access_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client});
TF_SetStatus(status, TF_OK, "");
}
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
@ -322,6 +368,11 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
@ -330,6 +381,8 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_gcs_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;

View File

@ -22,9 +22,17 @@
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
std::string* bucket, std::string* object, TF_Status* status);
namespace tf_random_access_file {
void Cleanup(TF_RandomAccessFile* file);
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status);
} // namespace tf_random_access_file
namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,

View File

@ -14,11 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include <random>
#include "absl/strings/string_view.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890";
// We will work with content_view instead of content.
static const absl::string_view content_view = content;
namespace gcs = google::cloud::storage;
namespace tensorflow {
namespace {
@ -26,10 +36,13 @@ namespace {
class GCSFilesystemTest : public ::testing::Test {
public:
void SetUp() override {
root_dir_ = io::JoinPath(
tmp_dir_,
::testing::UnitTest::GetInstance()->current_test_info()->name());
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Can not initialize filesystem. "
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
}
void TearDown() override {
@ -38,10 +51,82 @@ class GCSFilesystemTest : public ::testing::Test {
delete filesystem_;
}
static bool InitializeTmpDir() {
// This env should be something like `gs://bucket/path`
const char* test_dir = getenv("GCS_TEST_TMPDIR");
if (test_dir != nullptr) {
std::string bucket, object;
TF_Status* status = TF_NewStatus();
ParseGCSPath(test_dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) {
TF_DeleteStatus(status);
return false;
}
TF_DeleteStatus(status);
// We add a random value into `test_dir` to ensures that two consecutive
// runs are unlikely to clash.
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> distribution;
std::string rng_val = std::to_string(distribution(gen));
tmp_dir_ = io::JoinPath(string(test_dir), rng_val);
return true;
} else {
return false;
}
}
std::string GetURIForPath(absl::string_view path) {
const std::string translated_name =
tensorflow::io::JoinPath(root_dir_, path);
return translated_name;
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
private:
std::string root_dir_;
static std::string tmp_dir_;
};
std::string GCSFilesystemTest::tmp_dir_;
::testing::AssertionResult WriteToServer(const std::string& path, size_t length,
gcs::Client* gcs_client,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) {
return ::testing::AssertionFailure() << TF_Message(status);
}
auto writer = gcs_client->WriteObject(bucket, object);
writer.write(content, length);
writer.Close();
if (writer.metadata()) {
return ::testing::AssertionSuccess();
} else {
return ::testing::AssertionFailure()
<< writer.metadata().status().message();
}
}
::testing::AssertionResult CompareSubString(int64_t offset, size_t n,
absl::string_view result,
size_t read) {
// Result isn't a null-terminated string so we have to wrap it inside a
// `string_view`
if (n == read && content_view.substr(offset, n) ==
absl::string_view(result).substr(0, read)) {
return ::testing::AssertionSuccess();
} else {
return ::testing::AssertionFailure()
<< "Result: " << absl::string_view(result).substr(0, read)
<< " Read:" << read;
}
}
TEST_F(GCSFilesystemTest, ParseGCSPath) {
std::string bucket, object;
@ -65,11 +150,50 @@ TEST_F(GCSFilesystemTest, ParseGCSPath) {
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
}
TEST_F(GCSFilesystemTest, RandomAccessFile) {
std::string filepath = GetURIForPath("a_file");
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file,
status_);
ASSERT_TF_OK(status_);
char* result = new char[content_view.length()];
int64_t read = tf_random_access_file::Read(file, 0, 1, result, status_);
ASSERT_EQ(read, -1) << "Read: " << read;
ASSERT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_);
TF_SetStatus(status_, TF_OK, "");
auto gcs_client = static_cast<gcs::Client*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(
WriteToServer(filepath, content_view.length(), gcs_client, status_));
read = tf_random_access_file::Read(file, 0, content_view.length(), result,
status_);
ASSERT_TF_OK(status_);
ASSERT_TRUE(CompareSubString(0, content_view.length(), result, read));
read = tf_random_access_file::Read(file, 0, 4, result, status_);
ASSERT_TF_OK(status_);
ASSERT_TRUE(CompareSubString(0, 4, result, read));
read = tf_random_access_file::Read(file, content_view.length() - 2, 4, result,
status_);
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
ASSERT_TRUE(CompareSubString(content_view.length() - 2, 2, result, read));
delete result;
tf_random_access_file::Cleanup(file);
delete file;
}
} // namespace
} // namespace tensorflow
GTEST_API_ int main(int argc, char** argv) {
tensorflow::testing::InstallStacktraceHandler();
if (!tensorflow::GCSFilesystemTest::InitializeTmpDir()) {
std::cerr << "Could not read GCS_TEST_TMPDIR env";
return -1;
}
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}