Merge pull request #41592 from vnvo2409:refactor-gcs-part-2-random-access-file

PiperOrigin-RevId: 322436927
Change-Id: I3d4c2bd3900dd2225a678bad893805fbcd69ded6
This commit is contained in:
TensorFlower Gardener 2020-07-21 14:25:15 -07:00
commit 286e101e8f

View File

@ -106,11 +106,12 @@ static void MaybeAppendSlash(std::string* name) {
// A helper function to actually read the data from GCS.
static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
size_t buffer_size, char* buffer,
gcs::Client* gcs_client, TF_Status* status) {
tf_gcs_filesystem::GCSFile* gcs_file,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return -1;
auto stream = gcs_client->ReadObject(
auto stream = gcs_file->gcs_client.ReadObject(
bucket, object, gcs::ReadRange(offset, offset + buffer_size));
TF_SetStatusFromGCSStatus(stream.status(), status);
if ((TF_GetCode(status) != TF_OK) &&
@ -120,11 +121,33 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
int64_t read;
if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
&read)) {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
// When we read a file with an offset that is bigger than the actual file
// size, GCS will return an empty header (i.e. no `content-length` header). In
// this case, we set `read` to `0` and continue.
if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
read = 0;
} else {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
}
// `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here.
TF_SetStatus(status, TF_OK, "");
stream.read(buffer, read);
return stream.gcount();
read = stream.gcount();
if (read < buffer_size) {
// Check stat cache to see if we encountered an interrupted read.
tf_gcs_filesystem::GcsFileStat stat;
if (gcs_file->stat_cache->Lookup(path, &stat)) {
if (offset + read < stat.base.length) {
TF_SetStatus(status, TF_INTERNAL,
absl::StrCat("File contents are inconsistent for file: ",
path, " @ ", offset)
.c_str());
}
}
}
return read;
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
@ -198,13 +221,13 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
(std::min)(n - copy_size, gcs_file->buffer.size());
memcpy(buffer + copy_size, gcs_file->buffer.data(), remaining_copy);
copy_size += remaining_copy;
if (copy_size < n) {
// Forget the end-of-file flag to allow for clients that poll on the
// same file.
gcs_file->buffer_end_is_past_eof = false;
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
return copy_size;
}
}
if (copy_size < n) {
// Forget the end-of-file flag to allow for clients that poll on the
// same file.
gcs_file->buffer_end_is_past_eof = false;
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
return copy_size;
}
TF_SetStatus(status, TF_OK, "");
return copy_size;
@ -405,13 +428,12 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
max_staleness = value;
}
auto gcs_client_ptr = &this->gcs_client;
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
[gcs_client_ptr](const std::string& filename, size_t offset,
size_t buffer_size, char* buffer, TF_Status* status) {
return LoadBufferFromGCS(filename, offset, buffer_size, buffer,
gcs_client_ptr, status);
[this](const std::string& filename, size_t offset, size_t buffer_size,
char* buffer, TF_Status* status) {
return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this,
status);
});
uint64_t stat_cache_max_age = kStatCacheDefaultMaxAge;
@ -443,6 +465,19 @@ void Cleanup(TF_Filesystem* filesystem) {
delete gcs_file;
}
// Fetches the metadata of `object` in `bucket` directly through `gcs_client`
// and copies the relevant fields into `stat`.
//
// On success, `stat` receives the object's generation number, its size, the
// storage-class-updated timestamp and a directory flag, and `status` is set to
// `TF_OK`. If the metadata request fails, `stat` is left untouched and
// `status` is filled in from the GCS error.
static void UncachedStatForObject(const std::string& bucket,
                                  const std::string& object, GcsFileStat* stat,
                                  gcs::Client* gcs_client, TF_Status* status) {
  auto metadata = gcs_client->GetObjectMetadata(bucket, object);
  if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status);
  stat->generation_number = metadata->generation();
  stat->base.length = metadata->size();
  // NOTE(review): this stores the raw tick count of the returned time point.
  // That equals nanoseconds only if the clock's period is 1ns - confirm on all
  // supported platforms, or convert explicitly to nanoseconds.
  stat->base.mtime_nsec =
      metadata->time_storage_class_updated().time_since_epoch().count();
  // `std::string::back()` is undefined on an empty string, so an empty object
  // name is treated as "not a directory" instead of reading out of bounds.
  stat->base.is_directory = !object.empty() && object.back() == '/';
  return TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
@ -456,17 +491,31 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
absl::MutexLock l(&gcs_file->block_cache_lock);
is_cache_enabled = gcs_file->file_block_cache->IsCacheEnabled();
}
auto read_fn = [gcs_file, is_cache_enabled](
auto read_fn = [gcs_file, is_cache_enabled, bucket, object](
const std::string& path, uint64_t offset, size_t n,
char* buffer, TF_Status* status) -> int64_t {
// TODO(vnvo2409): Check for `stat_cache`.
int64_t read = 0;
if (is_cache_enabled) {
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
GcsFileStat stat;
gcs_file->stat_cache->LookupOrCompute(
path, &stat,
[gcs_file, bucket, object](const std::string& path, GcsFileStat* stat,
TF_Status* status) {
UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client,
status);
},
status);
if (TF_GetCode(status) != TF_OK) return -1;
if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
path, stat.generation_number)) {
std::cout
<< "File signature has been changed. Refreshing the cache. Path: "
<< path;
}
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
} else {
read = LoadBufferFromGCS(path, offset, n, buffer, &gcs_file->gcs_client,
status);
read = LoadBufferFromGCS(path, offset, n, buffer, gcs_file, status);
}
if (TF_GetCode(status) != TF_OK) return -1;
if (read < n)