Merge pull request #39791 from rahul003:multipartdownload

PiperOrigin-RevId: 313436323
Change-Id: I0c082f8b74cdaedbb9e03d7998e738239e9a9a5f
TensorFlower Gardener committed 2020-05-27 12:13:07 -07:00
commit 2902e2b24a
5 changed files with 312 additions and 41 deletions


@ -39,7 +39,9 @@ namespace data {
constexpr char kCurrentFileIndex[] = "current_file_index";
constexpr char kOffset[] = "offset";
constexpr char kGcsFsPrefix[] = "gs://";
constexpr char kS3FsPrefix[] = "s3://";
constexpr int64 kCloudTpuBlockSize = 127LL << 20; // 127MB.
constexpr int64 kS3BlockSize = kCloudTpuBlockSize;
bool is_cloud_tpu_gcs_fs() {
#if defined(PLATFORM_CLOUD_TPU) && defined(TPU_GCS_FS)
@ -237,12 +239,14 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
errors::InvalidArgument("`filenames` must be a scalar or a vector."));
bool is_gcs_fs = true;
bool is_s3_fs = true;
std::vector<string> filenames;
filenames.reserve(filenames_tensor->NumElements());
for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
VLOG(2) << "Reading file: " << filenames_tensor->flat<tstring>()(i);
filenames.push_back(filenames_tensor->flat<tstring>()(i));
is_gcs_fs &= absl::StartsWith(filenames[i], kGcsFsPrefix);
is_s3_fs &= absl::StartsWith(filenames[i], kS3FsPrefix);
}
tstring compression_type;
@ -264,6 +268,13 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
buffer_size = kCloudTpuBlockSize;
}
if (is_s3_fs && buffer_size < kS3BlockSize) {
VLOG(2) << "User buffer size is too small for reading "
<< "TFRecords stored in S3. Overriding " << buffer_size
<< " to the minimum recommended buffer_size = " << kS3BlockSize;
buffer_size = kS3BlockSize;
}
*output =
new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
}
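
Taken together, this hunk means that when every filename in the input tensor carries the s3:// prefix and the requested buffer size is below 127MB, the dataset silently raises the buffer to kS3BlockSize, mirroring the existing Cloud TPU GCS behaviour. A minimal standalone sketch of that decision (the kSketchS3BlockSize constant and function name are illustrative, not part of the change):

// Sketch only: mirrors the override applied in TFRecordDatasetOp::MakeDataset.
#include <cstdint>
#include <string>
#include <vector>
#include "absl/strings/match.h"

constexpr int64_t kSketchS3BlockSize = 127LL << 20;  // 127MB, same as kCloudTpuBlockSize

int64_t EffectiveBufferSize(const std::vector<std::string>& filenames,
                            int64_t requested_buffer_size) {
  bool is_s3_fs = true;
  for (const std::string& f : filenames) {
    is_s3_fs &= absl::StartsWith(f, "s3://");
  }
  if (is_s3_fs && requested_buffer_size < kSketchS3BlockSize) {
    return kSketchS3BlockSize;  // override to the minimum recommended size
  }
  return requested_buffer_size;
}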


@ -20,6 +20,7 @@ limitations under the License.
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/logging/AWSLogging.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/S3Errors.h>
#include <aws/s3/model/AbortMultipartUploadRequest.h>
@ -58,10 +59,16 @@ static const char* kS3TempFileTemplate = "/tmp/s3_filesystem_XXXXXX";
static const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
static const size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
static const int64 kS3TimeoutMsec = 300000; // 5 min
static const uint64 kS3MultiPartCopyPartSize = 50 * 1024 * 1024; // 50MB
static const uint64 kS3MultiPartUploadChunkSize = 50 * 1024 * 1024; // 50 MB
static const uint64 kS3MultiPartDownloadChunkSize = 2 * 1024 * 1024; // 2 MB
static const int kS3GetChildrenMaxKeys = 100;
static const int kExecutorPoolSize = 5;
static const int kUploadRetries = 5;
// With this change, multiple threads are used within a single download.
// The thread pool size is increased since multiple downloads and uploads
// can now run in parallel.
static const int kExecutorPoolSize = 25;
static const int kUploadRetries = 3;
static const int kDownloadRetries = 3;
static const char* kExecutorTag = "TransferManagerExecutor";
Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
@ -223,9 +230,16 @@ static Status CreateStatusFromAwsError(
class S3RandomAccessFile : public RandomAccessFile {
public:
S3RandomAccessFile(const string& bucket, const string& object,
std::shared_ptr<Aws::S3::S3Client> s3_client)
: bucket_(bucket), object_(object), s3_client_(s3_client) {}
S3RandomAccessFile(
const string& bucket, const string& object,
const bool use_multi_part_download,
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager,
std::shared_ptr<Aws::S3::S3Client> s3_client)
: bucket_(bucket),
object_(object),
use_multi_part_download_(use_multi_part_download),
transfer_manager_(transfer_manager),
s3_client_(s3_client) {}
Status Name(StringPiece* result) const override {
return errors::Unimplemented("S3RandomAccessFile does not support Name()");
@ -235,6 +249,66 @@ class S3RandomAccessFile : public RandomAccessFile {
char* scratch) const override {
VLOG(1) << "ReadFilefromS3 s3://" << bucket_ << "/" << object_ << " from "
<< offset << " for n:" << n;
if (use_multi_part_download_) {
return ReadS3TransferManager(offset, n, result, scratch);
} else {
return ReadS3Client(offset, n, result, scratch);
}
}
Status ReadS3TransferManager(uint64 offset, size_t n, StringPiece* result,
char* scratch) const {
VLOG(3) << "Using TransferManager";
auto create_stream_fn = [&]() {  // stream factory over the caller's scratch buffer
return Aws::New<TFS3UnderlyingStream>(
"S3ReadStream",
Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
"S3ReadStream", reinterpret_cast<unsigned char*>(scratch), n));
};
VLOG(3) << "Created stream to read with transferManager";
std::shared_ptr<Aws::Transfer::TransferHandle> handle =
transfer_manager_.get()->DownloadFile(bucket_.c_str(), object_.c_str(),
offset, n, create_stream_fn);
handle->WaitUntilFinished();
// todo change this
int retries = 0;
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
handle->GetLastError().GetResponseCode() !=
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
retries++ < kDownloadRetries) {
// only failed parts will be downloaded again
VLOG(1) << "Retrying read of s3://" << bucket_ << "/" << object_
<< " after failure. Current retry count:" << retries;
transfer_manager_.get()->RetryDownload(handle);
handle->WaitUntilFinished();
}
if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED) {
auto error = handle->GetLastError();
if (error.GetResponseCode() ==
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE) {
// expected when end of file is reached
n = 0;
*result = StringPiece(scratch, n);
return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
}
return CreateStatusFromAwsError(error);
} else {
n = handle->GetBytesTotalSize();
*result = StringPiece(scratch, handle->GetBytesTransferred());
return Status::OK();
}
}
Status ReadS3Client(uint64 offset, size_t n, StringPiece* result,
char* scratch) const {
VLOG(3) << "ReadFile using S3Client s3://" << bucket_ << "/" << object_;
Aws::S3::Model::GetObjectRequest getObjectRequest;
getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1);
@ -242,6 +316,7 @@ class S3RandomAccessFile : public RandomAccessFile {
getObjectRequest.SetResponseStreamFactory([]() {
return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag);
});
auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest);
if (!getObjectOutcome.IsSuccess()) {
auto error = getObjectOutcome.GetError();
@ -252,18 +327,21 @@ class S3RandomAccessFile : public RandomAccessFile {
return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
}
return CreateStatusFromAwsError(error);
}
n = getObjectOutcome.GetResult().GetContentLength();
getObjectOutcome.GetResult().GetBody().read(scratch, n);
} else {
n = getObjectOutcome.GetResult().GetContentLength();
getObjectOutcome.GetResult().GetBody().read(scratch, n);
*result = StringPiece(scratch, n);
return Status::OK();
*result = StringPiece(scratch, n);
return Status::OK();
}
}
private:
string bucket_;
string object_;
std::shared_ptr<Aws::S3::S3Client> s3_client_;
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager_;
bool use_multi_part_download_;
};
class S3WritableFile : public WritableFile {
@ -375,16 +453,53 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
S3FileSystem::S3FileSystem()
: s3_client_(nullptr, ShutdownClient),
initialization_lock_(),
transfer_manager_(nullptr, ShutdownTransferManager),
executor_(nullptr, ShutdownExecutor) {
const char* part_size_str = getenv("S3_MULTI_PART_COPY_PART_SIZE");
multi_part_copy_part_size_ = kS3MultiPartCopyPartSize;
const char* part_size_str = getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE");
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] =
kS3MultiPartUploadChunkSize;
if (part_size_str) {
uint64 part_size_num;
if (strings::safe_strtou64(part_size_str, &part_size_num)) {
multi_part_copy_part_size_ = part_size_num;
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] =
part_size_num;
}
}
// Different TensorFlow APIs call the download API with different buffer
// sizes; download performance depends on both that size and this chunk size.
part_size_str = getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE");
multi_part_chunk_size_[Aws::Transfer::TransferDirection::DOWNLOAD] =
kS3MultiPartDownloadChunkSize;
if (part_size_str) {
uint64 part_size_num;
if (strings::safe_strtou64(part_size_str, &part_size_num)) {
multi_part_chunk_size_[Aws::Transfer::TransferDirection::DOWNLOAD] =
part_size_num;
}
}
use_multi_part_download_ = true;
const char* disable_transfer_mgr = getenv("S3_DISABLE_MULTI_PART_DOWNLOAD");
if (disable_transfer_mgr) {
if (disable_transfer_mgr[0] == '1') {
use_multi_part_download_ = false;
}
}
auto upload_pair = std::pair<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager>>(
Aws::Transfer::TransferDirection::UPLOAD,
std::shared_ptr<Aws::Transfer::TransferManager>(nullptr,
ShutdownTransferManager));
auto download_pair =
std::pair<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager>>(
Aws::Transfer::TransferDirection::DOWNLOAD,
std::shared_ptr<Aws::Transfer::TransferManager>(
nullptr, ShutdownTransferManager));
this->transfer_managers_.insert(upload_pair);
this->transfer_managers_.insert(download_pair);
}
S3FileSystem::~S3FileSystem() {}
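
The constructor above wires everything to three environment variables: S3_MULTI_PART_UPLOAD_CHUNK_SIZE, S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE, and S3_DISABLE_MULTI_PART_DOWNLOAD. A hedged usage sketch of setting them from the embedding process before the filesystem is first touched (the helper name and byte values are illustrative):

// Sketch only: these env vars are read once, in the S3FileSystem constructor
// shown above. The helper name and the values chosen here are illustrative.
#include <cstdlib>

void ConfigureS3MultiPart() {
  // Chunk size (bytes) for multi-part uploads/copies; default is 50MB.
  setenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE", "52428800", /*overwrite=*/1);
  // Chunk size (bytes) for multi-part downloads; default is 2MB.
  setenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE", "8388608", /*overwrite=*/1);
  // Uncomment to fall back to single GetObject reads instead of the
  // TransferManager download path.
  // setenv("S3_DISABLE_MULTI_PART_DOWNLOAD", "1", /*overwrite=*/1);
}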
@ -424,20 +539,22 @@ std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
}
std::shared_ptr<Aws::Transfer::TransferManager>
S3FileSystem::GetTransferManager() {
S3FileSystem::GetTransferManager(
const Aws::Transfer::TransferDirection& direction) {
std::shared_ptr<Aws::S3::S3Client> s3_client = this->GetS3Client();
std::lock_guard<mutex> lock(this->initialization_lock_);
if (this->transfer_manager_.get() == nullptr) {
if (this->transfer_managers_[direction].get() == nullptr) {
Aws::Transfer::TransferManagerConfiguration config(
this->GetExecutor().get());
config.s3Client = s3_client;
config.bufferSize = this->multi_part_copy_part_size_;
// must be larger than pool size * multi_part_copy_part_size
config.bufferSize = this->multi_part_chunk_size_[direction];
// must be larger than pool size * multi part chunk size
config.transferBufferMaxHeapSize =
(kExecutorPoolSize + 1) * this->multi_part_copy_part_size_;
this->transfer_manager_ = Aws::Transfer::TransferManager::Create(config);
(kExecutorPoolSize + 1) * this->multi_part_chunk_size_[direction];
this->transfer_managers_[direction] =
Aws::Transfer::TransferManager::Create(config);
}
return this->transfer_manager_;
return this->transfer_managers_[direction];
}
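
For a sense of scale, the (kExecutorPoolSize + 1) * chunk-size heap ceiling set above differs a lot between the two directions under this change's defaults; a small worked sketch (the constant names here are mine):

// Sketch only: transferBufferMaxHeapSize = (kExecutorPoolSize + 1) * chunk size.
#include <cstdint>

constexpr uint64_t kPool = 25;                                          // kExecutorPoolSize
constexpr uint64_t kUploadHeap = (kPool + 1) * (50ULL * 1024 * 1024);   // 1300MB for uploads/copies
constexpr uint64_t kDownloadHeap = (kPool + 1) * (2ULL * 1024 * 1024);  // 52MB for downloads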
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor>
@ -452,9 +569,21 @@ S3FileSystem::GetExecutor() {
Status S3FileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
return NewRandomAccessFile(fname, result, true);
}
Status S3FileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result,
bool use_multi_part_download) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
result->reset(new S3RandomAccessFile(bucket, object, this->GetS3Client()));
// Check whether multi-part download was overridden for this file; used for testing.
bool use_mpd = this->use_multi_part_download_ && use_multi_part_download;
result->reset(new S3RandomAccessFile(
bucket, object, use_mpd,
this->GetTransferManager(Aws::Transfer::TransferDirection::DOWNLOAD),
this->GetS3Client()));
return Status::OK();
}
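
Callers that need to pin the read path can use the new three-argument overload directly; the updated test below does exactly this to compare both paths. A short usage sketch (the bucket/object path is hypothetical):

// Sketch only: force the plain S3Client read path for one file, independent of
// S3_DISABLE_MULTI_PART_DOWNLOAD. The s3:// path below is made up.
#include <memory>
#include "tensorflow/core/platform/s3/s3_file_system.h"

void ReadWithoutTransferManager() {
  tensorflow::S3FileSystem s3fs;
  std::unique_ptr<tensorflow::RandomAccessFile> reader;
  TF_CHECK_OK(s3fs.NewRandomAccessFile("s3://my-bucket/some/object", &reader,
                                       /*use_multi_part_download=*/false));
}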
@ -462,8 +591,11 @@ Status S3FileSystem::NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
result->reset(new S3WritableFile(bucket, object, this->GetTransferManager(),
this->GetS3Client()));
result->reset(new S3WritableFile(
bucket, object,
this->GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD),
this->GetS3Client()));
return Status::OK();
}
@ -478,8 +610,10 @@ Status S3FileSystem::NewAppendableFile(const string& fname,
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
result->reset(new S3WritableFile(bucket, object, this->GetTransferManager(),
this->GetS3Client()));
result->reset(new S3WritableFile(
bucket, object,
this->GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD),
this->GetS3Client()));
while (true) {
status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk,
@ -773,10 +907,13 @@ Status S3FileSystem::CopyFile(const Aws::String& source_bucket,
TF_RETURN_IF_ERROR(
this->GetFileSize(string(source_full_path.c_str()), &file_length));
int num_parts;
if (file_length <= multi_part_copy_part_size_) {
if (file_length <=
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]) {
num_parts = 1;
} else {
num_parts = ceil((float)file_length / multi_part_copy_part_size_);
num_parts =
ceil((float)file_length /
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]);
}
if (num_parts == 1) {
@ -786,7 +923,8 @@ Status S3FileSystem::CopyFile(const Aws::String& source_bucket,
"MultiPartCopy with number of parts more than 10000 is not supported. "
"Your object ",
source, " required ", num_parts,
" as multi_part_copy_part_size is set to ", multi_part_copy_part_size_,
" as multi_part_copy_part_size is set to ",
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD],
". You can control this part size using the environment variable ",
"S3_MULTI_PART_COPY_PART_SIZE to increase it.");
return tensorflow::errors::Unimplemented(message);
@ -831,7 +969,9 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,
Aws::String uploadID = multipartUploadOutcome.GetResult().GetUploadId();
VLOG(1) << "Copying from " << source << " in " << num_parts
<< " parts of size " << multi_part_copy_part_size_ << " each";
<< " parts of size "
<< multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]
<< " each";
Aws::S3::Model::CompletedMultipartUpload completedMPURequest;
// passed to each callback keyed by partNumber
@ -859,8 +999,12 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,
for (std::map<int, PartState>::iterator it = incompletePartStates.begin();
it != incompletePartStates.end(); it++) {
int partNumber = it->first;
uint64 startPos = (partNumber - 1) * multi_part_copy_part_size_;
uint64 endPos = startPos + kS3MultiPartCopyPartSize - 1;
uint64 startPos =
(partNumber - 1) *
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD];
uint64 endPos =
startPos +
multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] - 1;
if (endPos >= file_length) {
endPos = file_length - 1;
}
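
As a quick check of the startPos/endPos arithmetic above: with the default 50MB upload chunk size, a hypothetical 120MB object splits into three parts, the last one clamped to the end of the object. A worked sketch (the file length is illustrative):

// Sketch only: reproduce the part-range computation from MultiPartCopy for a
// hypothetical 120MB object with the default 50MB upload chunk size.
#include <cstdint>

constexpr uint64_t kChunk = 50ULL * 1024 * 1024;        // 52428800
constexpr uint64_t kFileLength = 120ULL * 1024 * 1024;  // 125829120
// num_parts = ceil(120MB / 50MB) = 3
//   part 1: startPos = 0,         endPos = 52428799
//   part 2: startPos = 52428800,  endPos = 104857599
//   part 3: startPos = 104857600, endPos clamped to file_length - 1 = 125829119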


@ -52,6 +52,10 @@ class S3FileSystem : public FileSystem {
Status NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) override;
Status NewRandomAccessFile(const string& fname,
std::unique_ptr<RandomAccessFile>* result,
bool use_multi_part_download);
Status NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
@ -101,8 +105,12 @@ class S3FileSystem : public FileSystem {
std::shared_ptr<Aws::S3::S3Client> s3_client_;
// Returns the member transfer manager, initializing as-needed.
std::shared_ptr<Aws::Transfer::TransferManager> GetTransferManager();
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager_;
std::shared_ptr<Aws::Transfer::TransferManager> GetTransferManager(
const Aws::Transfer::TransferDirection& direction);
void InitializeTransferManagers();
std::map<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager> >
transfer_managers_;
// Returns the member executor for transfer manager, initializing as-needed.
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> GetExecutor();
@ -132,8 +140,10 @@ class S3FileSystem : public FileSystem {
// Lock held when checking for s3_client_ and transfer_manager_ initialization
mutex initialization_lock_;
// size to split objects during multipart copy
uint64 multi_part_copy_part_size_;
// size to split objects during multipart upload/download/copy
std::map<Aws::Transfer::TransferDirection, uint64> multi_part_chunk_size_;
bool use_multi_part_download_;
};
/// S3 implementation of a file system with retry on failures.
@ -147,6 +157,16 @@ class RetryingS3FileSystem : public RetryingFileSystem<S3FileSystem> {
)) {}
};
// AWS streams destroy the buffer (buf) they are passed, so we define an
// IOStream that does not take ownership of the buffer, letting the calling
// function control its lifecycle.
class TFS3UnderlyingStream : public Aws::IOStream {
public:
using Base = Aws::IOStream;
TFS3UnderlyingStream(std::streambuf* buf) : Base(buf) {}
virtual ~TFS3UnderlyingStream() = default;
};
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_
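
The reason TFS3UnderlyingStream exists is that the default AWS response streams delete the buffer they wrap, while ReadS3TransferManager() above must write directly into the caller's scratch buffer. A minimal sketch of the pattern, using only SDK types already included by this change (the buffer size and variable names are illustrative):

// Sketch only: hand the TransferManager a stream over a caller-owned buffer.
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>

void DownloadIntoCallerBuffer() {
  char scratch[1024];  // caller-owned; must outlive the download
  auto create_stream_fn = [&]() {
    return Aws::New<tensorflow::TFS3UnderlyingStream>(
        "S3ReadStream",
        Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
            "S3ReadStream", reinterpret_cast<unsigned char*>(scratch),
            sizeof(scratch)));
  };
  // create_stream_fn would then be passed to TransferManager::DownloadFile(),
  // as ReadS3TransferManager() does; the bytes land in scratch once
  // WaitUntilFinished() returns.
  (void)create_stream_fn;
}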


@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/platform/s3/s3_file_system.h"
#include <time.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
@ -62,6 +64,96 @@ class S3FileSystemTest : public ::testing::Test {
return Status::OK();
}
Status ReadAllInChunks(const string& fname, string* content,
bool use_multi_part_download = true) {
std::unique_ptr<RandomAccessFile> reader;
TF_RETURN_IF_ERROR(
s3fs.NewRandomAccessFile(fname, &reader, use_multi_part_download));
uint64 file_size = 0;
TF_RETURN_IF_ERROR(s3fs.GetFileSize(fname, &file_size));
content->resize(file_size);
uint64 buffer_size = 16 * 1024 * 1024;
std::size_t part_count = (std::max)(
static_cast<size_t>((file_size + buffer_size - 1) / buffer_size),
static_cast<std::size_t>(1));
VLOG(1) << "buffersize:" << buffer_size << " file_size:" << file_size
<< " part_count=" << part_count;
std::unique_ptr<char[]> buffer{new char[buffer_size]};
std::stringstream ss;
uint64 offset = 0;
uint64 result_size = 0;
using namespace std::chrono;
auto start = high_resolution_clock::now();
for (int i = 0; i < part_count; i++) {
StringPiece result;
offset = i * buffer_size;
TF_RETURN_IF_ERROR(
reader->Read(offset, buffer_size, &result, buffer.get()));
if (result.size() != 0) {
ss.write(result.data(), result.size());
result_size += result.size();
}
if (result_size == file_size) {
break;
}
if (result.size() != buffer_size) {
VLOG(1) << "Result size and buffer size did not match";
if (result.empty()) {
return errors::OutOfRange("eof");
} else {
return errors::DataLoss("truncated record at ", offset);
}
}
}
if (file_size != result_size) {
return errors::DataLoss("expected ", file_size, " got ", result_size,
" bytes");
}
auto stop = high_resolution_clock::now();
duration<double> time_taken = duration_cast<duration<double>>(stop - start);
VLOG(1) << "Time Taken"
<< " : " << time_taken.count() << "seconds";
memcpy((char*)(content->data()), ss.str().data(),
static_cast<size_t>(file_size));
return Status::OK();
}
Status ReadLargeFile() {
// const string fname = TmpDir("train-00001-of-01024");
auto large_file_name = getenv("LARGE_DOWNLOAD_FILE_NAME");
const string fname = TmpDir(large_file_name);
string content_xfer;
string content_s3client;
// Read using Chunked Transfer Manager
VLOG(1) << "Using transfer manager";
TF_RETURN_IF_ERROR(ReadAllInChunks(fname, &content_xfer));
VLOG(1) << "Without transfer manager";
// Read using old S3 API and see if the contents match with TransferManager
TF_RETURN_IF_ERROR(ReadAllInChunks(fname, &content_s3client, false));
if (content_xfer == content_s3client) {
return Status::OK();
} else {
VLOG(1) << "ReadLargeFile contents DO NOT match";
return Status(error::OUT_OF_RANGE, "ReadLargeFile contents DO NOT match");
}
}
S3FileSystem s3fs;
};
@ -236,5 +328,9 @@ TEST_F(S3FileSystemTest, HasAtomicMove) {
EXPECT_EQ(has_atomic_move, false);
}
TEST_F(S3FileSystemTest, NewRandomAccessBigFile) {
TF_EXPECT_OK(ReadLargeFile());
}
} // namespace
} // namespace tensorflow


@ -9,11 +9,11 @@ def repo():
third_party_http_archive(
name = "aws",
urls = [
"https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.7.266.tar.gz",
"https://github.com/aws/aws-sdk-cpp/archive/1.7.266.tar.gz",
"https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
"https://github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
],
sha256 = "39fd8a2999260d2b8fcbc8187f1ed5299972c2b8bd14adb7850fd674fea67fb7",
strip_prefix = "aws-sdk-cpp-1.7.266",
sha256 = "758174f9788fed6cc1e266bcecb20bf738bd5ef1c3d646131c9ed15c2d6c5720",
strip_prefix = "aws-sdk-cpp-1.7.336",
build_file = "//third_party/aws:BUILD.bazel",
)