Merge pull request #41379 from vnvo2409:s3-crypto
PiperOrigin-RevId: 321221674 Change-Id: I07bd31133050767601cfad19e6246a2c863fb038
This commit is contained in:
commit
0fd1fcdf0b
@ -25,8 +25,21 @@ cc_library(
|
||||
"//tensorflow:windows": get_win_copts(),
|
||||
}),
|
||||
deps = [
|
||||
":aws_crypto",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"@aws",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "aws_crypto",
|
||||
srcs = ["aws_crypto.cc"],
|
||||
hdrs = ["aws_crypto.h"],
|
||||
deps = [
|
||||
"@aws",
|
||||
"@boringssl//:crypto",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
133
tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc
Normal file
133
tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc
Normal file
@ -0,0 +1,133 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
|
||||
|
||||
#include <aws/core/utils/crypto/HashResult.h>
|
||||
#include <aws/s3/S3Client.h>
|
||||
#include <openssl/hmac.h>
|
||||
#include <openssl/rand.h>
|
||||
#include <openssl/sha.h>
|
||||
|
||||
namespace tf_s3_filesystem {
|
||||
|
||||
class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
|
||||
public:
|
||||
AWSSha256HMACOpenSSLImpl() {}
|
||||
|
||||
virtual ~AWSSha256HMACOpenSSLImpl() = default;
|
||||
|
||||
Aws::Utils::Crypto::HashResult Calculate(
|
||||
const Aws::Utils::ByteBuffer& toSign,
|
||||
const Aws::Utils::ByteBuffer& secret) override {
|
||||
unsigned int length = SHA256_DIGEST_LENGTH;
|
||||
Aws::Utils::ByteBuffer digest(length);
|
||||
memset(digest.GetUnderlyingData(), 0, length);
|
||||
|
||||
HMAC_CTX ctx;
|
||||
HMAC_CTX_init(&ctx);
|
||||
|
||||
HMAC_Init_ex(&ctx, secret.GetUnderlyingData(),
|
||||
static_cast<int>(secret.GetLength()), EVP_sha256(), NULL);
|
||||
HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength());
|
||||
HMAC_Final(&ctx, digest.GetUnderlyingData(), &length);
|
||||
HMAC_CTX_cleanup(&ctx);
|
||||
|
||||
return Aws::Utils::Crypto::HashResult(std::move(digest));
|
||||
}
|
||||
};
|
||||
|
||||
class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
|
||||
public:
|
||||
AWSSha256OpenSSLImpl() {}
|
||||
|
||||
virtual ~AWSSha256OpenSSLImpl() = default;
|
||||
|
||||
Aws::Utils::Crypto::HashResult Calculate(const Aws::String& str) override {
|
||||
SHA256_CTX sha256;
|
||||
SHA256_Init(&sha256);
|
||||
SHA256_Update(&sha256, str.data(), str.size());
|
||||
|
||||
Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
|
||||
SHA256_Final(hash.GetUnderlyingData(), &sha256);
|
||||
|
||||
return Aws::Utils::Crypto::HashResult(std::move(hash));
|
||||
}
|
||||
|
||||
Aws::Utils::Crypto::HashResult Calculate(Aws::IStream& stream) override {
|
||||
SHA256_CTX sha256;
|
||||
SHA256_Init(&sha256);
|
||||
|
||||
auto currentPos = stream.tellg();
|
||||
if (currentPos == std::streampos(std::streamoff(-1))) {
|
||||
currentPos = 0;
|
||||
stream.clear();
|
||||
}
|
||||
|
||||
stream.seekg(0, stream.beg);
|
||||
|
||||
char streamBuffer
|
||||
[Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE];
|
||||
while (stream.good()) {
|
||||
stream.read(streamBuffer,
|
||||
Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE);
|
||||
auto bytesRead = stream.gcount();
|
||||
|
||||
if (bytesRead > 0) {
|
||||
SHA256_Update(&sha256, streamBuffer, static_cast<size_t>(bytesRead));
|
||||
}
|
||||
}
|
||||
|
||||
stream.clear();
|
||||
stream.seekg(currentPos, stream.beg);
|
||||
|
||||
Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
|
||||
SHA256_Final(hash.GetUnderlyingData(), &sha256);
|
||||
|
||||
return Aws::Utils::Crypto::HashResult(std::move(hash));
|
||||
}
|
||||
};
|
||||
|
||||
class AWSSecureRandomBytesImpl : public Aws::Utils::Crypto::SecureRandomBytes {
|
||||
public:
|
||||
AWSSecureRandomBytesImpl() {}
|
||||
virtual ~AWSSecureRandomBytesImpl() = default;
|
||||
void GetBytes(unsigned char* buffer, size_t bufferSize) override {
|
||||
assert(buffer);
|
||||
int success = RAND_bytes(buffer, static_cast<int>(bufferSize));
|
||||
if (success != 1) {
|
||||
m_failure = true;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool m_failure;
|
||||
};
|
||||
|
||||
std::shared_ptr<Aws::Utils::Crypto::Hash>
|
||||
AWSSHA256Factory::CreateImplementation() const {
|
||||
return Aws::MakeShared<AWSSha256OpenSSLImpl>(AWSCryptoAllocationTag);
|
||||
}
|
||||
|
||||
std::shared_ptr<Aws::Utils::Crypto::HMAC>
|
||||
AWSSHA256HmacFactory::CreateImplementation() const {
|
||||
return Aws::MakeShared<AWSSha256HMACOpenSSLImpl>(AWSCryptoAllocationTag);
|
||||
}
|
||||
|
||||
std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes>
|
||||
AWSSecureRandomFactory::CreateImplementation() const {
|
||||
return Aws::MakeShared<AWSSecureRandomBytesImpl>(AWSCryptoAllocationTag);
|
||||
}
|
||||
|
||||
} // namespace tf_s3_filesystem
|
47
tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h
Normal file
47
tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
|
||||
|
||||
#include <aws/core/Aws.h>
|
||||
#include <aws/core/utils/crypto/Factories.h>
|
||||
#include <aws/core/utils/crypto/HMAC.h>
|
||||
#include <aws/core/utils/crypto/Hash.h>
|
||||
#include <aws/core/utils/crypto/SecureRandom.h>
|
||||
|
||||
namespace tf_s3_filesystem {
|
||||
constexpr char AWSCryptoAllocationTag[] = "AWSCryptoAllocation";
|
||||
|
||||
class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory {
|
||||
public:
|
||||
std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
|
||||
const override;
|
||||
};
|
||||
|
||||
class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
|
||||
public:
|
||||
std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
|
||||
const override;
|
||||
};
|
||||
|
||||
class AWSSecureRandomFactory : public Aws::Utils::Crypto::SecureRandomFactory {
|
||||
public:
|
||||
std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes> CreateImplementation()
|
||||
const override;
|
||||
};
|
||||
|
||||
} // namespace tf_s3_filesystem
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
|
@ -52,6 +52,14 @@ static void ParseS3Path(const Aws::String& fname, bool object_empty_ok,
|
||||
}
|
||||
}
|
||||
|
||||
static void ShutdownClient(Aws::S3::S3Client* s3_client) {
|
||||
if (s3_client != nullptr) {
|
||||
delete s3_client;
|
||||
Aws::SDKOptions options;
|
||||
Aws::ShutdownAPI(options);
|
||||
}
|
||||
}
|
||||
|
||||
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
@ -79,6 +87,19 @@ namespace tf_read_only_memory_region {
|
||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_s3_filesystem {
|
||||
S3File::S3File()
|
||||
: s3_client(nullptr, ShutdownClient),
|
||||
executor(nullptr),
|
||||
initialization_lock() {}
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||
filesystem->plugin_filesystem = new S3File();
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void Cleanup(TF_Filesystem* filesystem) {
|
||||
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
|
||||
delete s3_file;
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
|
@ -17,8 +17,22 @@ limitations under the License.
|
||||
|
||||
#include <aws/core/Aws.h>
|
||||
#include <aws/core/utils/StringUtils.h>
|
||||
#include <aws/core/utils/threading/Executor.h>
|
||||
#include <aws/s3/S3Client.h>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tf_s3_filesystem {
|
||||
typedef struct S3File {
|
||||
std::shared_ptr<Aws::S3::S3Client> s3_client;
|
||||
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> executor;
|
||||
absl::Mutex initialization_lock;
|
||||
S3File();
|
||||
} S3File;
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
||||
void Cleanup(TF_Filesystem* filesystem);
|
||||
} // namespace tf_s3_filesystem
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_
|
||||
|
Loading…
x
Reference in New Issue
Block a user