Merge pull request #41477 from vnvo2409:s3-appendable
PiperOrigin-RevId: 321822368 Change-Id: I8f3813efc374a522305656c95e3bcd0cdb072b42
This commit is contained in:
commit
fdf1095dcd
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h"
|
||||
|
||||
#include <aws/core/config/AWSProfileConfigLoader.h>
|
||||
#include <aws/core/utils/FileSystemUtils.h>
|
||||
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
|
||||
#include <aws/s3/model/GetObjectRequest.h>
|
||||
#include <stdlib.h>
|
||||
@ -39,6 +40,9 @@ constexpr int kExecutorPoolSize = 25;
|
||||
// Chunk size, in bytes, for multipart uploads.
// NOTE(review): the consumer is not visible in this hunk; presumably the
// upload transfer manager configuration -- confirm.
constexpr uint64_t kS3MultiPartUploadChunkSize = 50 * 1024 * 1024; // 50 MB
|
||||
// Chunk size, in bytes, for multipart downloads.
// NOTE(review): the consumer is not visible in this hunk; presumably the
// download transfer manager configuration -- confirm.
constexpr uint64_t kS3MultiPartDownloadChunkSize = 50 * 1024 * 1024; // 50 MB
|
||||
// Retry budget for downloads.
// NOTE(review): usage is not visible in this hunk; presumably the download
// counterpart of `kUploadRetries` -- confirm.
constexpr size_t kDownloadRetries = 3;
|
||||
// Maximum number of times `tf_writable_file::Sync` re-submits a failed upload
// before reporting the error to the caller.
constexpr size_t kUploadRetries = 3;
|
||||
|
||||
// Size, in bytes, of the buffer `NewAppendableFile` uses for each read when
// copying the existing object into the local temporary file.
constexpr size_t kS3ReadAppendableFileBufferSize = 1024 * 1024; // 1 MB
|
||||
|
||||
// Allocates `size` bytes of zero-initialized memory for the plugin.
// Returns whatever `calloc` returns (null on allocation failure).
static void* plugin_memory_allocate(size_t size) {
  void* const block = calloc(1, size);
  return block;
}
|
||||
// Releases memory previously obtained from `plugin_memory_allocate`.
// Passing a null pointer is a no-op, as with `free`.
static void plugin_memory_free(void* ptr) {
  free(ptr);
}
|
||||
@ -54,6 +58,9 @@ static inline void TF_SetStatusFromAWSError(
|
||||
case Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE:
|
||||
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
|
||||
break;
|
||||
case Aws::Http::HttpResponseCode::NOT_FOUND:
|
||||
TF_SetStatus(status, TF_NOT_FOUND, error.GetMessage().c_str());
|
||||
break;
|
||||
default:
|
||||
TF_SetStatus(
|
||||
status, TF_UNKNOWN,
|
||||
@ -331,8 +338,104 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
|
||||
// SECTION 2. Implementation for `TF_WritableFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_writable_file {
|
||||
typedef struct S3File {
|
||||
Aws::String bucket;
|
||||
Aws::String object;
|
||||
std::shared_ptr<Aws::S3::S3Client> s3_client;
|
||||
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager;
|
||||
bool sync_needed;
|
||||
std::shared_ptr<Aws::Utils::TempFile> outfile;
|
||||
S3File(Aws::String bucket, Aws::String object,
|
||||
std::shared_ptr<Aws::S3::S3Client> s3_client,
|
||||
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager)
|
||||
: bucket(bucket),
|
||||
object(object),
|
||||
s3_client(s3_client),
|
||||
transfer_manager(transfer_manager),
|
||||
outfile(Aws::MakeShared<Aws::Utils::TempFile>(
|
||||
kS3FileSystemAllocationTag, nullptr, "_s3_filesystem_XXXXXX",
|
||||
std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
|
||||
std::ios_base::out)) {}
|
||||
} S3File;
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
// Destroys the plugin state attached to `file`.
// The `plugin_file` pointer itself is left untouched.
void Cleanup(TF_WritableFile* file) {
  delete static_cast<S3File*>(file->plugin_file);
}
|
||||
|
||||
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
|
||||
TF_Status* status) {
|
||||
auto s3_file = static_cast<S3File*>(file->plugin_file);
|
||||
if (!s3_file->outfile) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"The internal temporary file is not writable.");
|
||||
return;
|
||||
}
|
||||
s3_file->sync_needed = true;
|
||||
s3_file->outfile->write(buffer, n);
|
||||
if (!s3_file->outfile->good())
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Could not append to the internal temporary file.");
|
||||
else
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
|
||||
auto s3_file = static_cast<S3File*>(file->plugin_file);
|
||||
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
|
||||
if (position == -1)
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"tellp on the internal temporary file failed");
|
||||
else
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return position;
|
||||
}
|
||||
|
||||
void Sync(const TF_WritableFile* file, TF_Status* status) {
|
||||
auto s3_file = static_cast<S3File*>(file->plugin_file);
|
||||
if (!s3_file->outfile) {
|
||||
TF_SetStatus(status, TF_FAILED_PRECONDITION,
|
||||
"The internal temporary file is not writable.");
|
||||
return;
|
||||
}
|
||||
if (!s3_file->sync_needed) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
return;
|
||||
}
|
||||
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
|
||||
auto handle = s3_file->transfer_manager->UploadFile(
|
||||
s3_file->outfile, s3_file->bucket, s3_file->object,
|
||||
"application/octet-stream", Aws::Map<Aws::String, Aws::String>());
|
||||
handle->WaitUntilFinished();
|
||||
|
||||
size_t retries = 0;
|
||||
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
|
||||
retries++ < kUploadRetries) {
|
||||
// if multipart upload was used, only the failed parts will be re-sent
|
||||
s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
|
||||
handle->WaitUntilFinished();
|
||||
}
|
||||
if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED)
|
||||
return TF_SetStatusFromAWSError(handle->GetLastError(), status);
|
||||
s3_file->outfile->clear();
|
||||
s3_file->outfile->seekp(position);
|
||||
s3_file->sync_needed = false;
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
// Flushing is implemented as a `Sync`: it delegates directly and reports
// whatever status `Sync` sets.
void Flush(const TF_WritableFile* file, TF_Status* status) {
  return Sync(file, status);
}
|
||||
|
||||
void Close(const TF_WritableFile* file, TF_Status* status) {
|
||||
auto s3_file = static_cast<S3File*>(file->plugin_file);
|
||||
if (s3_file->outfile) {
|
||||
Sync(file, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
s3_file->outfile.reset();
|
||||
}
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
} // namespace tf_writable_file
|
||||
|
||||
@ -397,6 +500,79 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status) {
|
||||
Aws::String bucket, object;
|
||||
ParseS3Path(path, false, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
|
||||
GetS3Client(s3_file);
|
||||
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
|
||||
file->plugin_file = new tf_writable_file::S3File(
|
||||
bucket, object, s3_file->s3_client,
|
||||
s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||
TF_WritableFile* file, TF_Status* status) {
|
||||
Aws::String bucket, object;
|
||||
ParseS3Path(path, false, &bucket, &object, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
|
||||
GetS3Client(s3_file);
|
||||
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
|
||||
|
||||
// We need to delete `file->plugin_file` in case of errors.
|
||||
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile*)> writer(
|
||||
file, [](TF_WritableFile* file) {
|
||||
if (file != nullptr && file->plugin_file != nullptr) {
|
||||
tf_writable_file::Cleanup(file);
|
||||
}
|
||||
});
|
||||
writer->plugin_file = new tf_writable_file::S3File(
|
||||
bucket, object, s3_file->s3_client,
|
||||
s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
|
||||
// Wraping inside a `std::unique_ptr` to prevent memory-leaking.
|
||||
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
|
||||
new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
|
||||
if (file != nullptr) {
|
||||
tf_random_access_file::Cleanup(file);
|
||||
delete file;
|
||||
}
|
||||
});
|
||||
NewRandomAccessFile(filesystem, path, reader.get(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
|
||||
uint64_t offset = 0;
|
||||
std::string buffer(kS3ReadAppendableFileBufferSize, {});
|
||||
while (true) {
|
||||
auto read = tf_random_access_file::Read(reader.get(), offset,
|
||||
kS3ReadAppendableFileBufferSize,
|
||||
&buffer[0], status);
|
||||
if (TF_GetCode(status) == TF_NOT_FOUND) {
|
||||
break;
|
||||
} else if (TF_GetCode(status) == TF_OK) {
|
||||
offset += read;
|
||||
tf_writable_file::Append(file, buffer.c_str(), read, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
} else if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
|
||||
offset += read;
|
||||
tf_writable_file::Append(file, buffer.c_str(), read, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
break;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
writer.release();
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
} // namespace tf_s3_filesystem
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user