Merge pull request #41477 from vnvo2409:s3-appendable

PiperOrigin-RevId: 321822368
Change-Id: I8f3813efc374a522305656c95e3bcd0cdb072b42
This commit is contained in:
TensorFlower Gardener 2020-07-17 11:50:59 -07:00
commit fdf1095dcd

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h"
#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/utils/FileSystemUtils.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <stdlib.h>
@ -39,6 +40,9 @@ constexpr int kExecutorPoolSize = 25;
// Chunk sizes for multipart transfers. Their consumers are not visible in this
// part of the file -- presumably the transfer manager configuration; confirm.
constexpr uint64_t kS3MultiPartUploadChunkSize = 50 * 1024 * 1024; // 50 MB
constexpr uint64_t kS3MultiPartDownloadChunkSize = 50 * 1024 * 1024; // 50 MB
// NOTE(review): looks like the download-side counterpart of `kUploadRetries`;
// its use is not visible here -- confirm.
constexpr size_t kDownloadRetries = 3;
// Maximum number of `RetryUpload` attempts `tf_writable_file::Sync` makes after
// the initial upload has failed.
constexpr size_t kUploadRetries = 3;
// Number of bytes `NewAppendableFile` requests per read while copying the
// existing object into the temporary file.
constexpr size_t kS3ReadAppendableFileBufferSize = 1024 * 1024; // 1 MB
// Allocates `size` zero-initialized bytes with `calloc`. The result is whatever
// `calloc` returns (a null pointer when the allocation fails) and is released
// with `plugin_memory_free`.
static void* plugin_memory_allocate(size_t size) {
  void* zeroed_block = calloc(1, size);
  return zeroed_block;
}
// Releases memory with `free`; the counterpart of `plugin_memory_allocate`.
// Passing a null pointer is a no-op, as with `free`.
static void plugin_memory_free(void* ptr) {
  free(ptr);
}
@ -54,6 +58,9 @@ static inline void TF_SetStatusFromAWSError(
case Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE:
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
break;
case Aws::Http::HttpResponseCode::NOT_FOUND:
TF_SetStatus(status, TF_NOT_FOUND, error.GetMessage().c_str());
break;
default:
TF_SetStatus(
status, TF_UNKNOWN,
@ -331,8 +338,104 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// State backing a `TF_WritableFile`: data is appended to a local temporary
// file which `Sync` uploads to `bucket`/`object` through `transfer_manager`.
typedef struct S3File {
  // Destination of the upload.
  Aws::String bucket;
  Aws::String object;
  std::shared_ptr<Aws::S3::S3Client> s3_client;
  std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager;
  // Whether the temporary file still has to be uploaded. Set by `Append`,
  // cleared by a successful `Sync`.
  bool sync_needed;
  // Local temporary file holding the content; reset to null by `Close`.
  std::shared_ptr<Aws::Utils::TempFile> outfile;

  S3File(Aws::String bucket, Aws::String object,
         std::shared_ptr<Aws::S3::S3Client> s3_client,
         std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager)
      : bucket(bucket),
        object(object),
        s3_client(s3_client),
        transfer_manager(transfer_manager),
        // Must be initialized: `Sync` reads this flag, and reading an
        // indeterminate bool is undefined behavior. Starting at `true` makes
        // the first `Sync`/`Close` upload the temporary file even when nothing
        // was appended, so a newly created empty file reaches the bucket.
        sync_needed(true),
        outfile(Aws::MakeShared<Aws::Utils::TempFile>(
            kS3FileSystemAllocationTag, nullptr, "_s3_filesystem_XXXXXX",
            std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
                std::ios_base::out)) {}
} S3File;
// TODO(vnvo2409): Implement later
// Destroys the `S3File` stored in `file->plugin_file`. The pointer itself is
// left untouched, so it must not be used afterwards.
void Cleanup(TF_WritableFile* file) {
  delete static_cast<S3File*>(file->plugin_file);
}
// Writes `n` bytes from `buffer` to the internal temporary file and marks the
// file as needing an upload.
//
// Status:
//   TF_FAILED_PRECONDITION - the temporary file was already released (closed).
//   TF_INTERNAL            - the stream is not in a good state after the write.
//   TF_OK                  - otherwise.
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
            TF_Status* status) {
  auto* writable = static_cast<S3File*>(file->plugin_file);
  const auto& temp_file = writable->outfile;
  if (temp_file == nullptr) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "The internal temporary file is not writable.");
    return;
  }

  // The flag is raised before the write, so it stays set even when the write
  // itself fails.
  writable->sync_needed = true;
  temp_file->write(buffer, n);

  if (temp_file->good()) {
    TF_SetStatus(status, TF_OK, "");
    return;
  }
  TF_SetStatus(status, TF_INTERNAL,
               "Could not append to the internal temporary file.");
}
int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
if (position == -1)
TF_SetStatus(status, TF_INTERNAL,
"tellp on the internal temporary file failed");
else
TF_SetStatus(status, TF_OK, "");
return position;
}
void Sync(const TF_WritableFile* file, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
if (!s3_file->outfile) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"The internal temporary file is not writable.");
return;
}
if (!s3_file->sync_needed) {
TF_SetStatus(status, TF_OK, "");
return;
}
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
auto handle = s3_file->transfer_manager->UploadFile(
s3_file->outfile, s3_file->bucket, s3_file->object,
"application/octet-stream", Aws::Map<Aws::String, Aws::String>());
handle->WaitUntilFinished();
size_t retries = 0;
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
retries++ < kUploadRetries) {
// if multipart upload was used, only the failed parts will be re-sent
s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
handle->WaitUntilFinished();
}
if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED)
return TF_SetStatusFromAWSError(handle->GetLastError(), status);
s3_file->outfile->clear();
s3_file->outfile->seekp(position);
s3_file->sync_needed = false;
TF_SetStatus(status, TF_OK, "");
}
// Flushing is implemented as a full `Sync`: it forwards `file` and `status`
// unchanged, so `status` is whatever `Sync` reports.
void Flush(const TF_WritableFile* file, TF_Status* status) {
Sync(file, status);
}
// Syncs any pending data and releases the internal temporary file.
//
// Closing a file whose temporary file is already released reports TF_OK. If
// the final `Sync` fails, its status is returned and the temporary file is
// kept.
void Close(const TF_WritableFile* file, TF_Status* status) {
  auto* writable = static_cast<S3File*>(file->plugin_file);

  // Already closed: nothing left to upload or release.
  if (writable->outfile == nullptr) {
    TF_SetStatus(status, TF_OK, "");
    return;
  }

  Sync(file, status);
  const bool synced = TF_GetCode(status) == TF_OK;
  if (!synced) return;

  writable->outfile.reset();
  TF_SetStatus(status, TF_OK, "");
}
} // namespace tf_writable_file
@ -397,6 +500,79 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_SetStatus(status, TF_OK, "");
}
// Creates the writable-file state for the S3 object named by `path` and stores
// it in `file->plugin_file`.
//
// `status` carries the error from `ParseS3Path` when the path cannot be
// parsed (in which case `file` is left untouched), and TF_OK otherwise.
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                     TF_WritableFile* file, TF_Status* status) {
  Aws::String bucket;
  Aws::String object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  // Set up the S3 client and the upload transfer manager on the filesystem
  // state before they are handed to the new file.
  auto* fs_state = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(fs_state);
  GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, fs_state);

  const auto& upload_manager =
      fs_state->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD];
  file->plugin_file = new tf_writable_file::S3File(
      bucket, object, fs_state->s3_client, upload_manager);
  TF_SetStatus(status, TF_OK, "");
}
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
// We need to delete `file->plugin_file` in case of errors.
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile*)> writer(
file, [](TF_WritableFile* file) {
if (file != nullptr && file->plugin_file != nullptr) {
tf_writable_file::Cleanup(file);
}
});
writer->plugin_file = new tf_writable_file::S3File(
bucket, object, s3_file->s3_client,
s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]);
TF_SetStatus(status, TF_OK, "");
// Wraping inside a `std::unique_ptr` to prevent memory-leaking.
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
tf_random_access_file::Cleanup(file);
delete file;
}
});
NewRandomAccessFile(filesystem, path, reader.get(), status);
if (TF_GetCode(status) != TF_OK) return;
uint64_t offset = 0;
std::string buffer(kS3ReadAppendableFileBufferSize, {});
while (true) {
auto read = tf_random_access_file::Read(reader.get(), offset,
kS3ReadAppendableFileBufferSize,
&buffer[0], status);
if (TF_GetCode(status) == TF_NOT_FOUND) {
break;
} else if (TF_GetCode(status) == TF_OK) {
offset += read;
tf_writable_file::Append(file, buffer.c_str(), read, status);
if (TF_GetCode(status) != TF_OK) return;
} else if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
offset += read;
tf_writable_file::Append(file, buffer.c_str(), read, status);
if (TF_GetCode(status) != TF_OK) return;
break;
} else {
return;
}
}
writer.release();
TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later
} // namespace tf_s3_filesystem