Merge pull request #40710 from vnvo2409:gcs-random-access
PiperOrigin-RevId: 318880431 Change-Id: If7665e5407de8745bf3b61b32f3bdeca3c913c7e
This commit is contained in:
commit
0a36436b9f
@ -30,6 +30,7 @@ cc_library(
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -46,7 +47,6 @@ cc_library(
|
||||
tf_cc_test(
|
||||
name = "gcs_filesystem_test",
|
||||
srcs = [
|
||||
"gcs_filesystem.cc",
|
||||
"gcs_filesystem_test.cc",
|
||||
],
|
||||
tags = [
|
||||
@ -58,5 +58,6 @@ tf_cc_test(
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/platform:stacktrace_handler",
|
||||
"//tensorflow/core/platform:test",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "google/cloud/storage/client.h"
|
||||
#include "tensorflow/c/env.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
|
||||
@ -75,8 +76,42 @@ void ParseGCSPath(const std::string& fname, bool object_empty_ok,
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
typedef struct GCSFile {
|
||||
const std::string bucket;
|
||||
const std::string object;
|
||||
gcs::Client* gcs_client; // not owned
|
||||
} GCSFile;
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
void Cleanup(TF_RandomAccessFile* file) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
delete gcs_file;
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): Adding cache.
|
||||
// `google-cloud-cpp` is working on a feature that we may want to use.
|
||||
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
|
||||
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
|
||||
char* buffer, TF_Status* status) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
auto stream = gcs_file->gcs_client->ReadObject(
|
||||
gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
|
||||
TF_SetStatusFromGCSStatus(stream.status(), status);
|
||||
if ((TF_GetCode(status) != TF_OK) &&
|
||||
(TF_GetCode(status) != TF_OUT_OF_RANGE)) {
|
||||
return -1;
|
||||
}
|
||||
int64_t read;
|
||||
if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
|
||||
&read)) {
|
||||
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
|
||||
return -1;
|
||||
}
|
||||
if (read != n) {
|
||||
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
|
||||
}
|
||||
stream.read(buffer, read);
|
||||
return read;
|
||||
}
|
||||
|
||||
} // namespace tf_random_access_file
|
||||
|
||||
@ -251,6 +286,17 @@ void Cleanup(TF_Filesystem* filesystem) {
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_RandomAccessFile* file, TF_Status* status) {
|
||||
std::string bucket, object;
|
||||
ParseGCSPath(path, false, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
|
||||
file->plugin_file = new tf_random_access_file::GCSFile(
|
||||
{std::move(bucket), std::move(object), &gcs_file->gcs_client});
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status) {
|
||||
@ -322,6 +368,11 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
|
||||
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
|
||||
ops->random_access_file_ops->read = tf_random_access_file::Read;
|
||||
|
||||
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
|
||||
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
|
||||
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||
@ -330,6 +381,8 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
|
||||
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
|
||||
ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup;
|
||||
ops->filesystem_ops->new_random_access_file =
|
||||
tf_gcs_filesystem::NewRandomAccessFile;
|
||||
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
|
||||
ops->filesystem_ops->new_appendable_file =
|
||||
tf_gcs_filesystem::NewAppendableFile;
|
||||
|
@ -22,9 +22,17 @@
|
||||
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
|
||||
std::string* bucket, std::string* object, TF_Status* status);
|
||||
|
||||
namespace tf_random_access_file {
|
||||
void Cleanup(TF_RandomAccessFile* file);
|
||||
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
|
||||
char* buffer, TF_Status* status);
|
||||
} // namespace tf_random_access_file
|
||||
|
||||
namespace tf_gcs_filesystem {
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
||||
void Cleanup(TF_Filesystem* filesystem);
|
||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_RandomAccessFile* file, TF_Status* status);
|
||||
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status);
|
||||
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
|
@ -14,11 +14,21 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/stacktrace_handler.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
|
||||
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
|
||||
|
||||
static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890";
|
||||
// We will work with content_view instead of content.
|
||||
static const absl::string_view content_view = content;
|
||||
|
||||
namespace gcs = google::cloud::storage;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -26,10 +36,13 @@ namespace {
|
||||
class GCSFilesystemTest : public ::testing::Test {
|
||||
public:
|
||||
void SetUp() override {
|
||||
root_dir_ = io::JoinPath(
|
||||
tmp_dir_,
|
||||
::testing::UnitTest::GetInstance()->current_test_info()->name());
|
||||
status_ = TF_NewStatus();
|
||||
filesystem_ = new TF_Filesystem;
|
||||
tf_gcs_filesystem::Init(filesystem_, status_);
|
||||
ASSERT_TF_OK(status_) << "Can not initialize filesystem. "
|
||||
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
|
||||
<< TF_Message(status_);
|
||||
}
|
||||
void TearDown() override {
|
||||
@ -38,10 +51,82 @@ class GCSFilesystemTest : public ::testing::Test {
|
||||
delete filesystem_;
|
||||
}
|
||||
|
||||
static bool InitializeTmpDir() {
|
||||
// This env should be something like `gs://bucket/path`
|
||||
const char* test_dir = getenv("GCS_TEST_TMPDIR");
|
||||
if (test_dir != nullptr) {
|
||||
std::string bucket, object;
|
||||
TF_Status* status = TF_NewStatus();
|
||||
ParseGCSPath(test_dir, true, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
TF_DeleteStatus(status);
|
||||
return false;
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// We add a random value into `test_dir` to ensures that two consecutive
|
||||
// runs are unlikely to clash.
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> distribution;
|
||||
std::string rng_val = std::to_string(distribution(gen));
|
||||
tmp_dir_ = io::JoinPath(string(test_dir), rng_val);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetURIForPath(absl::string_view path) {
|
||||
const std::string translated_name =
|
||||
tensorflow::io::JoinPath(root_dir_, path);
|
||||
return translated_name;
|
||||
}
|
||||
|
||||
protected:
|
||||
TF_Filesystem* filesystem_;
|
||||
TF_Status* status_;
|
||||
|
||||
private:
|
||||
std::string root_dir_;
|
||||
static std::string tmp_dir_;
|
||||
};
|
||||
std::string GCSFilesystemTest::tmp_dir_;
|
||||
|
||||
::testing::AssertionResult WriteToServer(const std::string& path, size_t length,
|
||||
gcs::Client* gcs_client,
|
||||
TF_Status* status) {
|
||||
std::string bucket, object;
|
||||
ParseGCSPath(path, false, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
return ::testing::AssertionFailure() << TF_Message(status);
|
||||
}
|
||||
|
||||
auto writer = gcs_client->WriteObject(bucket, object);
|
||||
writer.write(content, length);
|
||||
writer.Close();
|
||||
if (writer.metadata()) {
|
||||
return ::testing::AssertionSuccess();
|
||||
} else {
|
||||
return ::testing::AssertionFailure()
|
||||
<< writer.metadata().status().message();
|
||||
}
|
||||
}
|
||||
|
||||
::testing::AssertionResult CompareSubString(int64_t offset, size_t n,
|
||||
absl::string_view result,
|
||||
size_t read) {
|
||||
// Result isn't a null-terminated string so we have to wrap it inside a
|
||||
// `string_view`
|
||||
if (n == read && content_view.substr(offset, n) ==
|
||||
absl::string_view(result).substr(0, read)) {
|
||||
return ::testing::AssertionSuccess();
|
||||
} else {
|
||||
return ::testing::AssertionFailure()
|
||||
<< "Result: " << absl::string_view(result).substr(0, read)
|
||||
<< " Read:" << read;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GCSFilesystemTest, ParseGCSPath) {
|
||||
std::string bucket, object;
|
||||
@ -65,11 +150,50 @@ TEST_F(GCSFilesystemTest, ParseGCSPath) {
|
||||
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
TEST_F(GCSFilesystemTest, RandomAccessFile) {
|
||||
std::string filepath = GetURIForPath("a_file");
|
||||
TF_RandomAccessFile* file = new TF_RandomAccessFile;
|
||||
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file,
|
||||
status_);
|
||||
ASSERT_TF_OK(status_);
|
||||
char* result = new char[content_view.length()];
|
||||
int64_t read = tf_random_access_file::Read(file, 0, 1, result, status_);
|
||||
ASSERT_EQ(read, -1) << "Read: " << read;
|
||||
ASSERT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_);
|
||||
TF_SetStatus(status_, TF_OK, "");
|
||||
|
||||
auto gcs_client = static_cast<gcs::Client*>(filesystem_->plugin_filesystem);
|
||||
ASSERT_TRUE(
|
||||
WriteToServer(filepath, content_view.length(), gcs_client, status_));
|
||||
|
||||
read = tf_random_access_file::Read(file, 0, content_view.length(), result,
|
||||
status_);
|
||||
ASSERT_TF_OK(status_);
|
||||
ASSERT_TRUE(CompareSubString(0, content_view.length(), result, read));
|
||||
|
||||
read = tf_random_access_file::Read(file, 0, 4, result, status_);
|
||||
ASSERT_TF_OK(status_);
|
||||
ASSERT_TRUE(CompareSubString(0, 4, result, read));
|
||||
|
||||
read = tf_random_access_file::Read(file, content_view.length() - 2, 4, result,
|
||||
status_);
|
||||
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
|
||||
ASSERT_TRUE(CompareSubString(content_view.length() - 2, 2, result, read));
|
||||
|
||||
delete result;
|
||||
tf_random_access_file::Cleanup(file);
|
||||
delete file;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
GTEST_API_ int main(int argc, char** argv) {
|
||||
tensorflow::testing::InstallStacktraceHandler();
|
||||
if (!tensorflow::GCSFilesystemTest::InitializeTmpDir()) {
|
||||
std::cerr << "Could not read GCS_TEST_TMPDIR env";
|
||||
return -1;
|
||||
}
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user