diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD index 4d4b7a1e6b0..358d5202614 100644 --- a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD @@ -25,8 +25,21 @@ cc_library( "//tensorflow:windows": get_win_copts(), }), deps = [ + ":aws_crypto", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@aws", + "@com_google_absl//absl/synchronization", ], ) + +cc_library( + name = "aws_crypto", + srcs = ["aws_crypto.cc"], + hdrs = ["aws_crypto.h"], + deps = [ + "@aws", + "@boringssl//:crypto", + ], + alwayslink = 1, +) diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc b/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc new file mode 100644 index 00000000000..2e15ac176e3 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h" + +#include +#include +#include +#include +#include + +namespace tf_s3_filesystem { + +class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC { + public: + AWSSha256HMACOpenSSLImpl() {} + + virtual ~AWSSha256HMACOpenSSLImpl() = default; + + Aws::Utils::Crypto::HashResult Calculate( + const Aws::Utils::ByteBuffer& toSign, + const Aws::Utils::ByteBuffer& secret) override { + unsigned int length = SHA256_DIGEST_LENGTH; + Aws::Utils::ByteBuffer digest(length); + memset(digest.GetUnderlyingData(), 0, length); + + HMAC_CTX ctx; + HMAC_CTX_init(&ctx); + + HMAC_Init_ex(&ctx, secret.GetUnderlyingData(), + static_cast(secret.GetLength()), EVP_sha256(), NULL); + HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength()); + HMAC_Final(&ctx, digest.GetUnderlyingData(), &length); + HMAC_CTX_cleanup(&ctx); + + return Aws::Utils::Crypto::HashResult(std::move(digest)); + } +}; + +class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash { + public: + AWSSha256OpenSSLImpl() {} + + virtual ~AWSSha256OpenSSLImpl() = default; + + Aws::Utils::Crypto::HashResult Calculate(const Aws::String& str) override { + SHA256_CTX sha256; + SHA256_Init(&sha256); + SHA256_Update(&sha256, str.data(), str.size()); + + Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH); + SHA256_Final(hash.GetUnderlyingData(), &sha256); + + return Aws::Utils::Crypto::HashResult(std::move(hash)); + } + + Aws::Utils::Crypto::HashResult Calculate(Aws::IStream& stream) override { + SHA256_CTX sha256; + SHA256_Init(&sha256); + + auto currentPos = stream.tellg(); + if (currentPos == std::streampos(std::streamoff(-1))) { + currentPos = 0; + stream.clear(); + } + + stream.seekg(0, stream.beg); + + char streamBuffer + [Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE]; + while (stream.good()) { + stream.read(streamBuffer, + Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE); + auto bytesRead = stream.gcount(); + + if (bytesRead > 0) { + SHA256_Update(&sha256, streamBuffer, static_cast(bytesRead)); + } + } + + stream.clear(); + stream.seekg(currentPos, stream.beg); + + Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH); + SHA256_Final(hash.GetUnderlyingData(), &sha256); + + return Aws::Utils::Crypto::HashResult(std::move(hash)); + } +}; + +class AWSSecureRandomBytesImpl : public Aws::Utils::Crypto::SecureRandomBytes { + public: + AWSSecureRandomBytesImpl() {} + virtual ~AWSSecureRandomBytesImpl() = default; + void GetBytes(unsigned char* buffer, size_t bufferSize) override { + assert(buffer); + int success = RAND_bytes(buffer, static_cast(bufferSize)); + if (success != 1) { + m_failure = true; + } + } + + private: + bool m_failure; +}; + +std::shared_ptr +AWSSHA256Factory::CreateImplementation() const { + return Aws::MakeShared(AWSCryptoAllocationTag); +} + +std::shared_ptr +AWSSHA256HmacFactory::CreateImplementation() const { + return Aws::MakeShared(AWSCryptoAllocationTag); +} + +std::shared_ptr +AWSSecureRandomFactory::CreateImplementation() const { + return Aws::MakeShared(AWSCryptoAllocationTag); +} + +} // namespace tf_s3_filesystem diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h b/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h new file mode 100644 index 00000000000..a70bf060fc7 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_ + +#include +#include +#include +#include +#include + +namespace tf_s3_filesystem { +constexpr char AWSCryptoAllocationTag[] = "AWSCryptoAllocation"; + +class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory { + public: + std::shared_ptr CreateImplementation() + const override; +}; + +class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory { + public: + std::shared_ptr CreateImplementation() + const override; +}; + +class AWSSecureRandomFactory : public Aws::Utils::Crypto::SecureRandomFactory { + public: + std::shared_ptr CreateImplementation() + const override; +}; + +} // namespace tf_s3_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc index 45350565500..f6ec1361335 100644 --- a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc @@ -52,6 +52,14 @@ static void ParseS3Path(const Aws::String& fname, bool object_empty_ok, } } +static void ShutdownClient(Aws::S3::S3Client* s3_client) { + if (s3_client != nullptr) { + delete s3_client; + Aws::SDKOptions options; + Aws::ShutdownAPI(options); + } +} + // SECTION 1. Implementation for `TF_RandomAccessFile` // ---------------------------------------------------------------------------- namespace tf_random_access_file { @@ -79,6 +87,19 @@ namespace tf_read_only_memory_region { // SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem // ---------------------------------------------------------------------------- namespace tf_s3_filesystem { +S3File::S3File() + : s3_client(nullptr, ShutdownClient), + executor(nullptr), + initialization_lock() {} +void Init(TF_Filesystem* filesystem, TF_Status* status) { + filesystem->plugin_filesystem = new S3File(); + TF_SetStatus(status, TF_OK, ""); +} + +void Cleanup(TF_Filesystem* filesystem) { + auto s3_file = static_cast(filesystem->plugin_filesystem); + delete s3_file; +} // TODO(vnvo2409): Implement later diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h index 62e2a7e0c06..9086b5d00f4 100644 --- a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h +++ b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h @@ -17,8 +17,22 @@ limitations under the License. #include #include +#include +#include +#include "absl/synchronization/mutex.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/tf_status.h" +namespace tf_s3_filesystem { +typedef struct S3File { + std::shared_ptr s3_client; + std::shared_ptr executor; + absl::Mutex initialization_lock; + S3File(); +} S3File; +void Init(TF_Filesystem* filesystem, TF_Status* status); +void Cleanup(TF_Filesystem* filesystem); +} // namespace tf_s3_filesystem + #endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_