Merge pull request #41379 from vnvo2409:s3-crypto

PiperOrigin-RevId: 321221674
Change-Id: I07bd31133050767601cfad19e6246a2c863fb038
This commit is contained in:
TensorFlower Gardener 2020-07-14 13:24:34 -07:00
commit 0fd1fcdf0b
5 changed files with 228 additions and 0 deletions

View File

@ -25,8 +25,21 @@ cc_library(
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":aws_crypto",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@aws",
"@com_google_absl//absl/synchronization",
],
)
# Library target built from aws_crypto.cc / aws_crypto.h; it depends on the
# AWS SDK ("@aws") and on BoringSSL's crypto library.
cc_library(
name = "aws_crypto",
srcs = ["aws_crypto.cc"],
hdrs = ["aws_crypto.h"],
deps = [
"@aws",
"@boringssl//:crypto",
],
# Force the linker to keep every object file of this library.
alwayslink = 1,
)

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include <aws/core/utils/crypto/HashResult.h>
#include <aws/s3/S3Client.h>
#include <openssl/hmac.h>
#include <openssl/rand.h>
#include <openssl/sha.h>
namespace tf_s3_filesystem {
// HMAC-SHA256 implementation of the AWS SDK `HMAC` interface, backed by the
// OpenSSL-compatible HMAC API.
class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
 public:
  AWSSha256HMACOpenSSLImpl() {}
  virtual ~AWSSha256HMACOpenSSLImpl() = default;

  // Computes the HMAC-SHA256 of `toSign`, keyed with `secret`.
  //
  // Returns a successful `HashResult` owning the SHA256_DIGEST_LENGTH-byte
  // digest. If any step of the HMAC computation reports an error, a
  // default-constructed (unsuccessful) `HashResult` is returned instead of a
  // zeroed or partially written digest.
  Aws::Utils::Crypto::HashResult Calculate(
      const Aws::Utils::ByteBuffer& toSign,
      const Aws::Utils::ByteBuffer& secret) override {
    unsigned int length = SHA256_DIGEST_LENGTH;
    Aws::Utils::ByteBuffer digest(length);
    memset(digest.GetUnderlyingData(), 0, length);

    HMAC_CTX ctx;
    HMAC_CTX_init(&ctx);
    // Short-circuit on the first failing step; each call returns 1 on success.
    const bool ok =
        HMAC_Init_ex(&ctx, secret.GetUnderlyingData(),
                     static_cast<int>(secret.GetLength()), EVP_sha256(),
                     NULL) == 1 &&
        HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength()) ==
            1 &&
        HMAC_Final(&ctx, digest.GetUnderlyingData(), &length) == 1;
    // The context must be released on both the success and failure paths.
    HMAC_CTX_cleanup(&ctx);

    if (!ok) {
      return Aws::Utils::Crypto::HashResult();
    }
    return Aws::Utils::Crypto::HashResult(std::move(digest));
  }
};
// SHA-256 implementation of the AWS SDK `Hash` interface, backed by the
// OpenSSL-compatible SHA256_* routines.
class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
 public:
  AWSSha256OpenSSLImpl() {}
  virtual ~AWSSha256OpenSSLImpl() = default;

  // Hashes the bytes of `str` and returns the digest.
  Aws::Utils::Crypto::HashResult Calculate(const Aws::String& str) override {
    Aws::Utils::ByteBuffer digest(SHA256_DIGEST_LENGTH);

    SHA256_CTX ctx;
    SHA256_Init(&ctx);
    SHA256_Update(&ctx, str.data(), str.size());
    SHA256_Final(digest.GetUnderlyingData(), &ctx);

    return Aws::Utils::Crypto::HashResult(std::move(digest));
  }

  // Hashes the whole content of `stream`, reading from its beginning, and
  // returns the digest. Afterwards the stream state is cleared and the read
  // position is moved back to where it was on entry (or to offset 0 when the
  // entry position could not be determined).
  Aws::Utils::Crypto::HashResult Calculate(Aws::IStream& stream) override {
    SHA256_CTX ctx;
    SHA256_Init(&ctx);

    // Remember where the caller left the stream so it can be restored.
    auto restorePos = stream.tellg();
    const bool positionUnknown =
        (restorePos == std::streampos(std::streamoff(-1)));
    if (positionUnknown) {
      restorePos = 0;
      stream.clear();
    }
    stream.seekg(0, stream.beg);

    // Feed the stream to the hash one fixed-size chunk at a time.
    char chunk[Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE];
    while (stream.good()) {
      stream.read(chunk, sizeof(chunk));
      const auto got = stream.gcount();
      if (got <= 0) {
        continue;
      }
      SHA256_Update(&ctx, chunk, static_cast<size_t>(got));
    }

    // Reaching end-of-stream sets eof/fail bits; reset them before seeking.
    stream.clear();
    stream.seekg(restorePos, stream.beg);

    Aws::Utils::ByteBuffer digest(SHA256_DIGEST_LENGTH);
    SHA256_Final(digest.GetUnderlyingData(), &ctx);
    return Aws::Utils::Crypto::HashResult(std::move(digest));
  }
};
// Secure random byte source for the AWS SDK, backed by `RAND_bytes`.
class AWSSecureRandomBytesImpl : public Aws::Utils::Crypto::SecureRandomBytes {
 public:
  AWSSecureRandomBytesImpl() {}
  virtual ~AWSSecureRandomBytesImpl() = default;

  // Fills `buffer` with `bufferSize` random bytes.
  //
  // A zero-length request is a no-op. If `RAND_bytes` reports an error, the
  // failure flag inherited from `SecureRandomBytes` is set so that callers
  // testing the object through the base class observe the failure.
  void GetBytes(unsigned char* buffer, size_t bufferSize) override {
    if (bufferSize == 0) {
      return;
    }
    assert(buffer);
    int success = RAND_bytes(buffer, static_cast<int>(bufferSize));
    if (success != 1) {
      // Deliberately the base-class member: declaring another `m_failure`
      // here would shadow it and hide the error from the base class.
      m_failure = true;
    }
  }
};
std::shared_ptr<Aws::Utils::Crypto::Hash>
AWSSHA256Factory::CreateImplementation() const {
return Aws::MakeShared<AWSSha256OpenSSLImpl>(AWSCryptoAllocationTag);
}
std::shared_ptr<Aws::Utils::Crypto::HMAC>
AWSSHA256HmacFactory::CreateImplementation() const {
return Aws::MakeShared<AWSSha256HMACOpenSSLImpl>(AWSCryptoAllocationTag);
}
std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes>
AWSSecureRandomFactory::CreateImplementation() const {
return Aws::MakeShared<AWSSecureRandomBytesImpl>(AWSCryptoAllocationTag);
}
} // namespace tf_s3_filesystem

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
#include <aws/core/Aws.h>
#include <aws/core/utils/crypto/Factories.h>
#include <aws/core/utils/crypto/HMAC.h>
#include <aws/core/utils/crypto/Hash.h>
#include <aws/core/utils/crypto/SecureRandom.h>
namespace tf_s3_filesystem {
// Allocation tag passed to `Aws::MakeShared` when the factories below create
// their crypto implementation objects.
constexpr char AWSCryptoAllocationTag[] = "AWSCryptoAllocation";
// Hash factory whose `CreateImplementation` hands out SHA-256 hash objects.
class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory {
public:
// Returns a newly created hash implementation owned by the returned pointer.
std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
const override;
};
// HMAC factory whose `CreateImplementation` hands out HMAC-SHA256 objects.
class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
public:
// Returns a newly created HMAC implementation owned by the returned pointer.
std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
const override;
};
// Factory whose `CreateImplementation` hands out secure random byte sources.
class AWSSecureRandomFactory : public Aws::Utils::Crypto::SecureRandomFactory {
public:
// Returns a newly created random byte source owned by the returned pointer.
std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes> CreateImplementation()
const override;
};
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_

View File

@ -52,6 +52,14 @@ static void ParseS3Path(const Aws::String& fname, bool object_empty_ok,
}
}
// Deleter for the S3 client: destroys `s3_client` and then shuts the AWS SDK
// down. Does nothing when `s3_client` is null.
static void ShutdownClient(Aws::S3::S3Client* s3_client) {
  if (s3_client == nullptr) {
    return;
  }
  delete s3_client;
  Aws::SDKOptions options;
  Aws::ShutdownAPI(options);
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -79,6 +87,19 @@ namespace tf_read_only_memory_region {
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_s3_filesystem {
S3File::S3File()
: s3_client(nullptr, ShutdownClient),
executor(nullptr),
initialization_lock() {}
// Allocates the plugin state for `filesystem` and reports `TF_OK` in `status`.
// The state is released by `Cleanup`.
void Init(TF_Filesystem* filesystem, TF_Status* status) {
  auto* s3_file = new S3File();
  filesystem->plugin_filesystem = s3_file;
  TF_SetStatus(status, TF_OK, "");
}
// Destroys the `S3File` stored in `filesystem->plugin_filesystem`.
void Cleanup(TF_Filesystem* filesystem) {
  delete static_cast<S3File*>(filesystem->plugin_filesystem);
}
// TODO(vnvo2409): Implement later

View File

@ -17,8 +17,22 @@ limitations under the License.
#include <aws/core/Aws.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/threading/Executor.h>
#include <aws/s3/S3Client.h>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
namespace tf_s3_filesystem {
// Per-filesystem plugin state for the S3 filesystem.
struct S3File {
  // S3 client shared by the operations of this filesystem.
  std::shared_ptr<Aws::S3::S3Client> s3_client;
  // Thread pool executor; NOTE(review): presumably used for transfers,
  // confirm against the code that creates it.
  std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> executor;
  // NOTE(review): presumably guards lazy creation of the members above;
  // confirm against the code that takes this lock.
  absl::Mutex initialization_lock;

  S3File();
};
// Creates a new `S3File`, stores it in `filesystem->plugin_filesystem` and
// sets `status` to `TF_OK`.
void Init(TF_Filesystem* filesystem, TF_Status* status);
// Deletes the `S3File` stored in `filesystem->plugin_filesystem`.
void Cleanup(TF_Filesystem* filesystem);
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_