Merge pull request #39791 from rahul003:multipartdownload
PiperOrigin-RevId: 313436323 Change-Id: I0c082f8b74cdaedbb9e03d7998e738239e9a9a5f
Commit: 2902e2b24a
@@ -39,7 +39,9 @@ namespace data {
 constexpr char kCurrentFileIndex[] = "current_file_index";
 constexpr char kOffset[] = "offset";
 constexpr char kGcsFsPrefix[] = "gs://";
+constexpr char kS3FsPrefix[] = "s3://";
 constexpr int64 kCloudTpuBlockSize = 127LL << 20;  // 127MB.
+constexpr int64 kS3BlockSize = kCloudTpuBlockSize;

 bool is_cloud_tpu_gcs_fs() {
 #if defined(PLATFORM_CLOUD_TPU) && defined(TPU_GCS_FS)
@@ -237,12 +239,14 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
               errors::InvalidArgument("`filenames` must be a scalar or a vector."));

   bool is_gcs_fs = true;
+  bool is_s3_fs = true;
   std::vector<string> filenames;
   filenames.reserve(filenames_tensor->NumElements());
   for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
     VLOG(2) << "Reading file: " << filenames_tensor->flat<tstring>()(i);
     filenames.push_back(filenames_tensor->flat<tstring>()(i));
     is_gcs_fs &= absl::StartsWith(filenames[i], kGcsFsPrefix);
+    is_s3_fs &= absl::StartsWith(filenames[i], kS3FsPrefix);
   }

   tstring compression_type;
@@ -264,6 +268,13 @@ void TFRecordDatasetOp::MakeDataset(OpKernelContext* ctx,
     buffer_size = kCloudTpuBlockSize;
   }

+  if (is_s3_fs && buffer_size < kS3BlockSize) {
+    VLOG(2) << "User buffer size is too small for reading "
+            << "TFRecords stored in S3. Overriding " << buffer_size
+            << " to the minimum recommended buffer_size = " << kS3BlockSize;
+    buffer_size = kS3BlockSize;
+  }
+
   *output =
       new Dataset(ctx, std::move(filenames), compression_type, buffer_size);
 }
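The hunk above pins the TFRecord reader's buffer to at least the 127 MB S3 block size whenever every input path is an s3:// URI. A minimal standalone sketch of that policy (the helper name is hypothetical and not part of the patch):

#include <cstdint>

// kS3BlockSize mirrors the 127 MB constant introduced above.
constexpr int64_t kS3BlockSize = 127LL << 20;

int64_t EffectiveTFRecordBufferSize(int64_t requested, bool is_s3_fs) {
  // Small buffers mean one ranged S3 request per refill, so reads from S3
  // are clamped up to the recommended minimum.
  if (is_s3_fs && requested < kS3BlockSize) return kS3BlockSize;
  return requested;
}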
@@ -20,6 +20,7 @@ limitations under the License.
 #include <aws/core/utils/StringUtils.h>
 #include <aws/core/utils/logging/AWSLogging.h>
 #include <aws/core/utils/logging/LogSystemInterface.h>
+#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
 #include <aws/s3/S3Client.h>
 #include <aws/s3/S3Errors.h>
 #include <aws/s3/model/AbortMultipartUploadRequest.h>
@@ -58,10 +59,16 @@ static const char* kS3TempFileTemplate = "/tmp/s3_filesystem_XXXXXX";
 static const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
 static const size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
 static const int64 kS3TimeoutMsec = 300000;  // 5 min
-static const uint64 kS3MultiPartCopyPartSize = 50 * 1024 * 1024;  // 50 MB
+static const uint64 kS3MultiPartUploadChunkSize = 50 * 1024 * 1024;   // 50 MB
+static const uint64 kS3MultiPartDownloadChunkSize = 2 * 1024 * 1024;  // 2 MB
 static const int kS3GetChildrenMaxKeys = 100;
-static const int kExecutorPoolSize = 5;
-static const int kUploadRetries = 5;
+
+// Multiple threads are now used within a single download, and several
+// downloads and uploads can run in parallel, so the executor pool is larger.
+static const int kExecutorPoolSize = 25;
+static const int kUploadRetries = 3;
+static const int kDownloadRetries = 3;
 static const char* kExecutorTag = "TransferManagerExecutor";

 Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
@@ -223,9 +230,16 @@ static Status CreateStatusFromAwsError(

 class S3RandomAccessFile : public RandomAccessFile {
  public:
-  S3RandomAccessFile(const string& bucket, const string& object,
-                     std::shared_ptr<Aws::S3::S3Client> s3_client)
-      : bucket_(bucket), object_(object), s3_client_(s3_client) {}
+  S3RandomAccessFile(
+      const string& bucket, const string& object,
+      const bool use_multi_part_download,
+      std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager,
+      std::shared_ptr<Aws::S3::S3Client> s3_client)
+      : bucket_(bucket),
+        object_(object),
+        use_multi_part_download_(use_multi_part_download),
+        transfer_manager_(transfer_manager),
+        s3_client_(s3_client) {}

   Status Name(StringPiece* result) const override {
     return errors::Unimplemented("S3RandomAccessFile does not support Name()");
@@ -235,6 +249,66 @@ class S3RandomAccessFile : public RandomAccessFile {
               char* scratch) const override {
     VLOG(1) << "ReadFilefromS3 s3://" << bucket_ << "/" << object_ << " from "
             << offset << " for n:" << n;
+    if (use_multi_part_download_) {
+      return ReadS3TransferManager(offset, n, result, scratch);
+    } else {
+      return ReadS3Client(offset, n, result, scratch);
+    }
+  }
+
+  Status ReadS3TransferManager(uint64 offset, size_t n, StringPiece* result,
+                               char* scratch) const {
+    VLOG(3) << "Using TransferManager";
+
+    // Lambda that creates the destination stream: downloaded bytes are written
+    // straight into the caller-provided scratch buffer.
+    auto create_stream_fn = [&]() {
+      return Aws::New<TFS3UnderlyingStream>(
+          "S3ReadStream",
+          Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
+              "S3ReadStream", reinterpret_cast<unsigned char*>(scratch), n));
+    };
+
+    VLOG(3) << "Created stream to read with transferManager";
+
+    std::shared_ptr<Aws::Transfer::TransferHandle> handle =
+        transfer_manager_.get()->DownloadFile(bucket_.c_str(), object_.c_str(),
+                                              offset, n, create_stream_fn);
+    handle->WaitUntilFinished();
+
+    // TODO: revisit the retry handling here.
+    int retries = 0;
+
+    while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
+           handle->GetLastError().GetResponseCode() !=
+               Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
+           retries++ < kDownloadRetries) {
+      // Only the failed parts are downloaded again.
+      VLOG(1) << "Retrying read of s3://" << bucket_ << "/" << object_
+              << " after failure. Current retry count:" << retries;
+      transfer_manager_.get()->RetryDownload(handle);
+      handle->WaitUntilFinished();
+    }
+
+    if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED) {
+      auto error = handle->GetLastError();
+      if (error.GetResponseCode() ==
+          Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE) {
+        // Expected when the end of the file is reached.
+        n = 0;
+        *result = StringPiece(scratch, n);
+        return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
+      }
+      return CreateStatusFromAwsError(error);
+    } else {
+      n = handle->GetBytesTotalSize();
+      *result = StringPiece(scratch, handle->GetBytesTransferred());
+      return Status::OK();
+    }
+  }
+
+  Status ReadS3Client(uint64 offset, size_t n, StringPiece* result,
+                      char* scratch) const {
+    VLOG(3) << "ReadFile using S3Client s3://" << bucket_ << "/" << object_;
+
     Aws::S3::Model::GetObjectRequest getObjectRequest;
     getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
     string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1);
@@ -242,6 +316,7 @@ class S3RandomAccessFile : public RandomAccessFile {
     getObjectRequest.SetResponseStreamFactory([]() {
       return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag);
     });
+
     auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest);
     if (!getObjectOutcome.IsSuccess()) {
       auto error = getObjectOutcome.GetError();
@@ -252,18 +327,21 @@ class S3RandomAccessFile : public RandomAccessFile {
         return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
       }
       return CreateStatusFromAwsError(error);
-    }
-    n = getObjectOutcome.GetResult().GetContentLength();
-    getObjectOutcome.GetResult().GetBody().read(scratch, n);
-
-    *result = StringPiece(scratch, n);
-    return Status::OK();
+    } else {
+      n = getObjectOutcome.GetResult().GetContentLength();
+      getObjectOutcome.GetResult().GetBody().read(scratch, n);
+      *result = StringPiece(scratch, n);
+      return Status::OK();
+    }
   }

  private:
   string bucket_;
   string object_;
   std::shared_ptr<Aws::S3::S3Client> s3_client_;
+  std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager_;
+  bool use_multi_part_download_;
 };

 class S3WritableFile : public WritableFile {
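For orientation, a caller-side sketch (not part of the patch; the bucket, key, and function name are hypothetical) of how either download path above is exercised: bytes land in the caller-owned scratch buffer, *result views them, and reading past the end of the object yields OUT_OF_RANGE together with the bytes that were actually read.

#include <memory>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"

tensorflow::Status ReadHead(tensorflow::Env* env) {
  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(
      env->NewRandomAccessFile("s3://my-bucket/data.tfrecord", &file));

  std::vector<char> scratch(1 << 20);  // 1 MB scratch owned by the caller.
  tensorflow::StringPiece chunk;
  tensorflow::Status s =
      file->Read(/*offset=*/0, scratch.size(), &chunk, scratch.data());
  if (tensorflow::errors::IsOutOfRange(s)) {
    // The object is shorter than 1 MB; `chunk` still holds the bytes read.
    s = tensorflow::Status::OK();
  }
  return s;
}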
@@ -375,16 +453,53 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
 S3FileSystem::S3FileSystem()
     : s3_client_(nullptr, ShutdownClient),
       initialization_lock_(),
-      transfer_manager_(nullptr, ShutdownTransferManager),
       executor_(nullptr, ShutdownExecutor) {
-  const char* part_size_str = getenv("S3_MULTI_PART_COPY_PART_SIZE");
-  multi_part_copy_part_size_ = kS3MultiPartCopyPartSize;
+  const char* part_size_str = getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE");
+  multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] =
+      kS3MultiPartUploadChunkSize;
   if (part_size_str) {
     uint64 part_size_num;
     if (strings::safe_strtou64(part_size_str, &part_size_num)) {
-      multi_part_copy_part_size_ = part_size_num;
+      multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] =
+          part_size_num;
     }
   }
+
+  // Different TensorFlow APIs call the download path with different buffer
+  // sizes; download performance depends on both that size and this chunk size.
+  part_size_str = getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE");
+  multi_part_chunk_size_[Aws::Transfer::TransferDirection::DOWNLOAD] =
+      kS3MultiPartDownloadChunkSize;
+  if (part_size_str) {
+    uint64 part_size_num;
+    if (strings::safe_strtou64(part_size_str, &part_size_num)) {
+      multi_part_chunk_size_[Aws::Transfer::TransferDirection::DOWNLOAD] =
+          part_size_num;
+    }
+  }
+
+  use_multi_part_download_ = true;
+  const char* disable_transfer_mgr = getenv("S3_DISABLE_MULTI_PART_DOWNLOAD");
+  if (disable_transfer_mgr) {
+    if (disable_transfer_mgr[0] == '1') {
+      use_multi_part_download_ = false;
+    }
+  }
+
+  auto upload_pair = std::pair<Aws::Transfer::TransferDirection,
+                               std::shared_ptr<Aws::Transfer::TransferManager>>(
+      Aws::Transfer::TransferDirection::UPLOAD,
+      std::shared_ptr<Aws::Transfer::TransferManager>(nullptr,
+                                                      ShutdownTransferManager));
+  auto download_pair =
+      std::pair<Aws::Transfer::TransferDirection,
+                std::shared_ptr<Aws::Transfer::TransferManager>>(
+          Aws::Transfer::TransferDirection::DOWNLOAD,
+          std::shared_ptr<Aws::Transfer::TransferManager>(
+              nullptr, ShutdownTransferManager));
+
+  this->transfer_managers_.insert(upload_pair);
+  this->transfer_managers_.insert(download_pair);
 }

 S3FileSystem::~S3FileSystem() {}
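These chunk sizes and the multipart-download switch are read from the environment once, when the S3FileSystem is constructed, so they have to be set before the first S3 path is touched. A hedged sketch of tuning them from C++ (the values are illustrative only; exporting the same variables from the shell before launching the process is equivalent):

#include <cstdlib>

void ConfigureS3TransferChunks() {
  // 8 MB download parts and 64 MB upload parts, illustrative values only.
  setenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE", "8388608", /*overwrite=*/1);
  setenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE", "67108864", /*overwrite=*/1);
  // Set to "1" to fall back to single-request GetObject reads.
  setenv("S3_DISABLE_MULTI_PART_DOWNLOAD", "0", /*overwrite=*/1);
}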
@@ -424,20 +539,22 @@ std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
 }

 std::shared_ptr<Aws::Transfer::TransferManager>
-S3FileSystem::GetTransferManager() {
+S3FileSystem::GetTransferManager(
+    const Aws::Transfer::TransferDirection& direction) {
   std::shared_ptr<Aws::S3::S3Client> s3_client = this->GetS3Client();
   std::lock_guard<mutex> lock(this->initialization_lock_);
-  if (this->transfer_manager_.get() == nullptr) {
+  if (this->transfer_managers_[direction].get() == nullptr) {
     Aws::Transfer::TransferManagerConfiguration config(
         this->GetExecutor().get());
     config.s3Client = s3_client;
-    config.bufferSize = this->multi_part_copy_part_size_;
-    // must be larger than pool size * multi_part_copy_part_size
+    config.bufferSize = this->multi_part_chunk_size_[direction];
+    // Must be larger than pool size * multi-part chunk size.
     config.transferBufferMaxHeapSize =
-        (kExecutorPoolSize + 1) * this->multi_part_copy_part_size_;
-    this->transfer_manager_ = Aws::Transfer::TransferManager::Create(config);
+        (kExecutorPoolSize + 1) * this->multi_part_chunk_size_[direction];
+    this->transfer_managers_[direction] =
+        Aws::Transfer::TransferManager::Create(config);
   }
-  return this->transfer_manager_;
+  return this->transfer_managers_[direction];
 }

 std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor>
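With the defaults in this patch, the per-direction buffer budget differs a lot between uploads and downloads. A small worked computation (not part of the patch) of transferBufferMaxHeapSize = (pool size + 1) * chunk size:

#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t pool_size = 25;                       // kExecutorPoolSize
  const uint64_t upload_chunk = 50ULL * 1024 * 1024;   // kS3MultiPartUploadChunkSize
  const uint64_t download_chunk = 2ULL * 1024 * 1024;  // kS3MultiPartDownloadChunkSize

  std::printf("upload heap budget:   %llu MiB\n",
              (unsigned long long)(((pool_size + 1) * upload_chunk) >> 20));    // 1300 MiB
  std::printf("download heap budget: %llu MiB\n",
              (unsigned long long)(((pool_size + 1) * download_chunk) >> 20));  // 52 MiB
  return 0;
}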
@@ -452,9 +569,21 @@ S3FileSystem::GetExecutor() {

 Status S3FileSystem::NewRandomAccessFile(
     const string& fname, std::unique_ptr<RandomAccessFile>* result) {
+  return NewRandomAccessFile(fname, result, true);
+}
+
+Status S3FileSystem::NewRandomAccessFile(
+    const string& fname, std::unique_ptr<RandomAccessFile>* result,
+    bool use_multi_part_download) {
   string bucket, object;
   TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
-  result->reset(new S3RandomAccessFile(bucket, object, this->GetS3Client()));
+
+  // Check whether an override was requested for this file (used for testing).
+  bool use_mpd = this->use_multi_part_download_ && use_multi_part_download;
+  result->reset(new S3RandomAccessFile(
+      bucket, object, use_mpd,
+      this->GetTransferManager(Aws::Transfer::TransferDirection::DOWNLOAD),
+      this->GetS3Client()));
   return Status::OK();
 }

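A short usage sketch (not part of the patch; the bucket and key are hypothetical) of the new overload, which lets a caller, typically a test, force the plain S3Client path for a single file without touching the process-wide S3_DISABLE_MULTI_PART_DOWNLOAD setting:

#include <memory>

#include "tensorflow/core/platform/s3/s3_file_system.h"

tensorflow::Status OpenWithoutTransferManager(
    tensorflow::S3FileSystem* s3fs,
    std::unique_ptr<tensorflow::RandomAccessFile>* reader) {
  // Passing false here bypasses the TransferManager download for this file.
  return s3fs->NewRandomAccessFile("s3://my-bucket/part-00000", reader,
                                   /*use_multi_part_download=*/false);
}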
@@ -462,8 +591,11 @@ Status S3FileSystem::NewWritableFile(const string& fname,
                                      std::unique_ptr<WritableFile>* result) {
   string bucket, object;
   TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
-  result->reset(new S3WritableFile(bucket, object, this->GetTransferManager(),
-                                   this->GetS3Client()));
+  result->reset(new S3WritableFile(
+      bucket, object,
+      this->GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD),
+      this->GetS3Client()));
+
   return Status::OK();
 }

@@ -478,8 +610,10 @@ Status S3FileSystem::NewAppendableFile(const string& fname,

   string bucket, object;
   TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
-  result->reset(new S3WritableFile(bucket, object, this->GetTransferManager(),
-                                   this->GetS3Client()));
+  result->reset(new S3WritableFile(
+      bucket, object,
+      this->GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD),
+      this->GetS3Client()));

   while (true) {
     status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk,
@@ -773,10 +907,13 @@ Status S3FileSystem::CopyFile(const Aws::String& source_bucket,
   TF_RETURN_IF_ERROR(
       this->GetFileSize(string(source_full_path.c_str()), &file_length));
   int num_parts;
-  if (file_length <= multi_part_copy_part_size_) {
+  if (file_length <=
+      multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]) {
     num_parts = 1;
   } else {
-    num_parts = ceil((float)file_length / multi_part_copy_part_size_);
+    num_parts =
+        ceil((float)file_length /
+             multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]);
   }

   if (num_parts == 1) {
@@ -786,7 +923,8 @@ Status S3FileSystem::CopyFile(const Aws::String& source_bucket,
         "MultiPartCopy with number of parts more than 10000 is not supported. "
         "Your object ",
         source, " required ", num_parts,
-        " as multi_part_copy_part_size is set to ", multi_part_copy_part_size_,
+        " as multi_part_copy_part_size is set to ",
+        multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD],
         ". You can control this part size using the environment variable ",
         "S3_MULTI_PART_COPY_PART_SIZE to increase it.");
     return tensorflow::errors::Unimplemented(message);
@@ -831,7 +969,9 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,

   Aws::String uploadID = multipartUploadOutcome.GetResult().GetUploadId();
   VLOG(1) << "Copying from " << source << " in " << num_parts
-          << " parts of size " << multi_part_copy_part_size_ << " each";
+          << " parts of size "
+          << multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD]
+          << " each";
   Aws::S3::Model::CompletedMultipartUpload completedMPURequest;

   // passed to each callback keyed by partNumber
@@ -859,8 +999,12 @@ Status S3FileSystem::MultiPartCopy(const Aws::String& source,
     for (std::map<int, PartState>::iterator it = incompletePartStates.begin();
          it != incompletePartStates.end(); it++) {
       int partNumber = it->first;
-      uint64 startPos = (partNumber - 1) * multi_part_copy_part_size_;
-      uint64 endPos = startPos + kS3MultiPartCopyPartSize - 1;
+      uint64 startPos =
+          (partNumber - 1) *
+          multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD];
+      uint64 endPos =
+          startPos +
+          multi_part_chunk_size_[Aws::Transfer::TransferDirection::UPLOAD] - 1;
       if (endPos >= file_length) {
         endPos = file_length - 1;
       }
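To make the part arithmetic above concrete, a worked example (not part of the patch) for a 120 MiB object with the default 50 MiB upload chunk; the integer ceiling division below is equivalent to the float ceil() used in CopyFile.

#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t chunk = 50ULL * 1024 * 1024;         // upload chunk size
  const uint64_t file_length = 120ULL * 1024 * 1024;  // object size
  const int num_parts =
      static_cast<int>((file_length + chunk - 1) / chunk);  // 3 parts
  for (int part = 1; part <= num_parts; ++part) {
    uint64_t start = (part - 1) * chunk;
    uint64_t end = start + chunk - 1;
    if (end >= file_length) end = file_length - 1;  // last part is shorter
    std::printf("part %d: bytes=%llu-%llu\n", part,
                (unsigned long long)start, (unsigned long long)end);
  }
  return 0;
}

This prints byte ranges 0-52428799, 52428800-104857599, and 104857600-125829119, matching the ranges the retry loop above reissues for incomplete parts.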
@@ -52,6 +52,10 @@ class S3FileSystem : public FileSystem {
   Status NewRandomAccessFile(
       const string& fname, std::unique_ptr<RandomAccessFile>* result) override;

+  Status NewRandomAccessFile(const string& fname,
+                             std::unique_ptr<RandomAccessFile>* result,
+                             bool use_multi_part_download);
+
   Status NewWritableFile(const string& fname,
                          std::unique_ptr<WritableFile>* result) override;

@@ -101,8 +105,12 @@ class S3FileSystem : public FileSystem {
   std::shared_ptr<Aws::S3::S3Client> s3_client_;

   // Returns the member transfer manager, initializing as-needed.
-  std::shared_ptr<Aws::Transfer::TransferManager> GetTransferManager();
-  std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager_;
+  std::shared_ptr<Aws::Transfer::TransferManager> GetTransferManager(
+      const Aws::Transfer::TransferDirection& direction);
+  void InitializeTransferManagers();
+  std::map<Aws::Transfer::TransferDirection,
+           std::shared_ptr<Aws::Transfer::TransferManager> >
+      transfer_managers_;

   // Returns the member executor for transfer manager, initializing as-needed.
   std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> GetExecutor();
@@ -132,8 +140,10 @@ class S3FileSystem : public FileSystem {
   // Lock held when checking for s3_client_ and transfer_manager_ initialization
   mutex initialization_lock_;

-  // size to split objects during multipart copy
-  uint64 multi_part_copy_part_size_;
+  // size to split objects during multipart upload/download/copy
+  std::map<Aws::Transfer::TransferDirection, uint64> multi_part_chunk_size_;
+
+  bool use_multi_part_download_;
 };

 /// S3 implementation of a file system with retry on failures.
@@ -147,6 +157,16 @@ class RetryingS3FileSystem : public RetryingFileSystem<S3FileSystem> {
   )) {}
 };

+// AWS streams destroy the buffer (buf) that is passed to them, so this
+// IOStream subclass retains the buffer so that the calling function can
+// control its lifecycle.
+class TFS3UnderlyingStream : public Aws::IOStream {
+ public:
+  using Base = Aws::IOStream;
+  TFS3UnderlyingStream(std::streambuf* buf) : Base(buf) {}
+  virtual ~TFS3UnderlyingStream() = default;
+};
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_
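A sketch (not part of the patch; the helper name and allocation tag are illustrative) of how TFS3UnderlyingStream pairs with a PreallocatedStreamBuf so that a download writes straight into caller-owned scratch memory, as ReadS3TransferManager does above:

#include <cstdint>

#include <aws/core/utils/memory/AWSMemory.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>

#include "tensorflow/core/platform/s3/s3_file_system.h"

Aws::IOStream* NewScratchBackedStream(unsigned char* scratch, uint64_t n) {
  auto* buf = Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
      "S3ReadStream", scratch, n);
  // TFS3UnderlyingStream never frees the buffer it wraps, so `scratch`
  // stays under the caller's control for the lifetime of the read.
  return Aws::New<tensorflow::TFS3UnderlyingStream>("S3ReadStream", buf);
}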
@@ -15,6 +15,8 @@ limitations under the License.

 #include "tensorflow/core/platform/s3/s3_file_system.h"

+#include <time.h>
+
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/path.h"
@@ -62,6 +64,96 @@ class S3FileSystemTest : public ::testing::Test {
     return Status::OK();
   }

+  Status ReadAllInChunks(const string& fname, string* content,
+                         bool use_multi_part_download = true) {
+    std::unique_ptr<RandomAccessFile> reader;
+
+    TF_RETURN_IF_ERROR(
+        s3fs.NewRandomAccessFile(fname, &reader, use_multi_part_download));
+
+    uint64 file_size = 0;
+    TF_RETURN_IF_ERROR(s3fs.GetFileSize(fname, &file_size));
+
+    content->resize(file_size);
+
+    uint64 buffer_size = 16 * 1024 * 1024;
+
+    std::size_t part_count = (std::max)(
+        static_cast<size_t>((file_size + buffer_size - 1) / buffer_size),
+        static_cast<std::size_t>(1));
+    VLOG(1) << "buffer_size:" << buffer_size << " file_size:" << file_size
+            << " part_count=" << part_count;
+    std::unique_ptr<char[]> buffer{new char[buffer_size]};
+    std::stringstream ss;
+
+    int offset = 0;
+    int result_size = 0;
+
+    using namespace std::chrono;
+    auto start = high_resolution_clock::now();
+
+    for (int i = 0; i < part_count; i++) {
+      StringPiece result;
+      offset = i * buffer_size;
+      TF_RETURN_IF_ERROR(
+          reader->Read(offset, buffer_size, &result, buffer.get()));
+
+      if (result.size() != 0) {
+        ss.write(result.data(), result.size());
+        result_size += result.size();
+      }
+      if (result_size == file_size) {
+        break;
+      }
+      if (result.size() != buffer_size) {
+        VLOG(1) << "Result size and buffer size did not match";
+        if (result.empty()) {
+          return errors::OutOfRange("eof");
+        } else {
+          return errors::DataLoss("truncated record at ", offset);
+        }
+      }
+    }
+
+    if (file_size != result_size) {
+      return errors::DataLoss("expected ", file_size, " got ", result_size,
+                              " bytes");
+    }
+
+    auto stop = high_resolution_clock::now();
+    duration<double> time_taken = duration_cast<duration<double>>(stop - start);
+    VLOG(1) << "Time taken: " << time_taken.count() << " seconds";
+
+    memcpy((char*)(content->data()), ss.str().data(),
+           static_cast<size_t>(file_size));
+
+    return Status::OK();
+  }
+
+  Status ReadLargeFile() {
+    // const string fname = TmpDir("train-00001-of-01024");
+    auto large_file_name = getenv("LARGE_DOWNLOAD_FILE_NAME");
+    const string fname = TmpDir(large_file_name);
+    string content_xfer;
+    string content_s3client;
+
+    // Read using the chunked TransferManager path.
+    VLOG(1) << "Using transfer manager";
+    TF_RETURN_IF_ERROR(ReadAllInChunks(fname, &content_xfer));
+
+    VLOG(1) << "Without transfer manager";
+    // Read using the plain S3Client API and check that the contents match.
+    TF_RETURN_IF_ERROR(ReadAllInChunks(fname, &content_s3client, false));
+
+    if (content_xfer == content_s3client) {
+      return Status::OK();
+    } else {
+      VLOG(1) << "ReadLargeFile contents DO NOT match";
+      return Status(error::OUT_OF_RANGE, "ReadLargeFile contents DO NOT match");
+    }
+  }
+
   S3FileSystem s3fs;
 };

@@ -236,5 +328,9 @@ TEST_F(S3FileSystemTest, HasAtomicMove) {
   EXPECT_EQ(has_atomic_move, false);
 }

+TEST_F(S3FileSystemTest, NewRandomAccessBigFile) {
+  TF_EXPECT_OK(ReadLargeFile());
+}
+
 }  // namespace
 }  // namespace tensorflow
third_party/aws/workspace.bzl
@@ -9,11 +9,11 @@ def repo():
     third_party_http_archive(
         name = "aws",
         urls = [
-            "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.7.266.tar.gz",
-            "https://github.com/aws/aws-sdk-cpp/archive/1.7.266.tar.gz",
+            "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
+            "https://github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
         ],
-        sha256 = "39fd8a2999260d2b8fcbc8187f1ed5299972c2b8bd14adb7850fd674fea67fb7",
-        strip_prefix = "aws-sdk-cpp-1.7.266",
+        sha256 = "758174f9788fed6cc1e266bcecb20bf738bd5ef1c3d646131c9ed15c2d6c5720",
+        strip_prefix = "aws-sdk-cpp-1.7.336",
         build_file = "//third_party/aws:BUILD.bazel",
     )