diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD
index 0f32456b5c8..a2108d06cbb 100644
--- a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD
@@ -26,6 +26,7 @@ cc_library(
     }),
     deps = [
         ":aws_crypto",
+        ":aws_logging",
         "//tensorflow/c:logging",
         "//tensorflow/c:tf_status",
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
@@ -46,6 +47,18 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "aws_logging",
+    srcs = ["aws_logging.cc"],
+    hdrs = ["aws_logging.h"],
+    deps = [
+        "//tensorflow/c:logging",
+        "@aws",
+        "@com_google_absl//absl/synchronization",
+    ],
+    alwayslink = 1,
+)
+
 tf_cc_test(
     name = "s3_filesystem_test",
     srcs = [
diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc
new file mode 100644
index 00000000000..353b733fd25
--- /dev/null
+++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc
@@ -0,0 +1,183 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h"
+
+#include <aws/core/Aws.h>
+#include <aws/core/utils/logging/AWSLogging.h>
+#include <aws/core/utils/logging/LogSystemInterface.h>
+
+#include <cstdarg>
+#include <cstdio>
+#include <cstdlib>
+#include <map>
+#include <sstream>
+#include <string>
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/c/logging.h"
+
+static constexpr char kAWSLoggingTag[] = "AWSLogging";
+
+// Textual values accepted in the AWS_LOG_LEVEL environment variable.
+static const std::map<std::string, Aws::Utils::Logging::LogLevel>
+    log_levels_string_to_aws = {
+        {"off", Aws::Utils::Logging::LogLevel::Off},
+        {"fatal", Aws::Utils::Logging::LogLevel::Fatal},
+        {"error", Aws::Utils::Logging::LogLevel::Error},
+        {"warn", Aws::Utils::Logging::LogLevel::Warn},
+        {"info", Aws::Utils::Logging::LogLevel::Info},
+        {"debug", Aws::Utils::Logging::LogLevel::Debug},
+        {"trace", Aws::Utils::Logging::LogLevel::Trace}};
+
+// Numeric AWS_LOG_LEVEL values follow the TensorFlow log levels
+// (0 = INFO, 1 = WARNING, 2 = ERROR, 3 = FATAL).
+static const std::map<int, Aws::Utils::Logging::LogLevel>
+    log_levels_tf_to_aws = {{0, Aws::Utils::Logging::LogLevel::Info},
+                            {1, Aws::Utils::Logging::LogLevel::Warn},
+                            {2, Aws::Utils::Logging::LogLevel::Error},
+                            {3, Aws::Utils::Logging::LogLevel::Fatal}};
+
+namespace tf_s3_filesystem {
+
+AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level)
+    : log_level_(log_level) {}
+
+// Forwards `message` to TF_Log at the TensorFlow severity that corresponds to
+// `log_level`. TF_Log takes a printf-style format, so the message is passed
+// as a "%s" argument; text coming from the SDK may itself contain '%'.
+void AWSLogSystem::LogMessage(Aws::Utils::Logging::LogLevel log_level,
+                              const std::string& message) {
+  if (message == "Initializing Curl library") return;
+  switch (log_level) {
+    case Aws::Utils::Logging::LogLevel::Info:
+      TF_Log(TF_INFO, "%s", message.c_str());
+      break;
+    case Aws::Utils::Logging::LogLevel::Warn:
+      TF_Log(TF_WARNING, "%s", message.c_str());
+      break;
+    case Aws::Utils::Logging::LogLevel::Error:
+      TF_Log(TF_ERROR, "%s", message.c_str());
+      break;
+    case Aws::Utils::Logging::LogLevel::Fatal:
+      TF_Log(TF_FATAL, "%s", message.c_str());
+      break;
+    default:
+      // this will match for DEBUG, TRACE
+      TF_Log(TF_INFO, "%s", message.c_str());
+      break;
+  }
+}
+
+// printf-style entry point of the log interface; `tag` is not used.
+void AWSLogSystem::Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
+                       const char* format, ...) {
+  va_list args;
+  va_start(args, format);
+  // Measure the formatted length first so that long messages are not
+  // truncated. vsnprintf consumes its va_list, hence the copy.
+  va_list args_copy;
+  va_copy(args_copy, args);
+  const int length = vsnprintf(nullptr, 0, format, args_copy);
+  va_end(args_copy);
+  if (length < 0) {
+    // Formatting failed: log the raw format string instead of dropping it.
+    va_end(args);
+    LogMessage(log_level, format);
+    return;
+  }
+  std::string message(static_cast<std::size_t>(length), '\0');
+  vsnprintf(&message[0], message.size() + 1, format, args);
+  va_end(args);
+  LogMessage(log_level, message);
+}
+
+// Stream entry point of the log interface; `tag` is not used.
+void AWSLogSystem::LogStream(Aws::Utils::Logging::LogLevel log_level,
+                             const char* tag,
+                             const Aws::OStringStream& message_stream) {
+  LogMessage(log_level, message_stream.rdbuf()->str().c_str());
+}
+
+// Messages go straight to TF_Log, so there is nothing to flush.
+void AWSLogSystem::Flush() { return; }
+
+static Aws::Utils::Logging::LogLevel TfLogLevelToAwsLogLevel(int level) {
+  // Converts TF Log Levels INFO, WARNING, ERROR and FATAL to the AWS enum
+  // values for the levels
+  const auto it = log_levels_tf_to_aws.find(level);
+  if (it != log_levels_tf_to_aws.end()) return it->second;
+  // default to fatal
+  return Aws::Utils::Logging::LogLevel::Fatal;
+}
+
+static Aws::Utils::Logging::LogLevel ParseAwsLogLevelFromEnv() {
+  // defaults to FATAL log level for the AWS SDK
+  // this is because many normal tensorflow operations are logged as errors in
+  // the AWS SDK such as checking if a file exists can log an error in AWS SDK
+  // if the file does not actually exist. Another such case is when reading a
+  // file till the end, TensorFlow expects to see an InvalidRange exception at
+  // the end, but this would be an error in the AWS SDK. This confuses users,
+  // hence the default setting.
+  Aws::Utils::Logging::LogLevel log_level =
+      Aws::Utils::Logging::LogLevel::Fatal;
+
+  const char* aws_env_var_val = getenv("AWS_LOG_LEVEL");
+  if (aws_env_var_val != nullptr) {
+    const std::string maybe_integer_str(aws_env_var_val);
+    std::istringstream ss(maybe_integer_str);
+    int level = 0;
+    ss >> level;
+    if (ss.fail()) {
+      // wasn't a number
+      // expecting a string
+      const auto it = log_levels_string_to_aws.find(maybe_integer_str);
+      if (it != log_levels_string_to_aws.end()) {
+        log_level = it->second;
+      }
+    } else {
+      // backwards compatibility
+      // valid number, but this number follows the standard TensorFlow log
+      // levels need to convert this to AWS SDK logging level number
+      log_level = TfLogLevelToAwsLogLevel(level);
+    }
+  }
+  return log_level;
+}
+
+static bool initialized = false;
+ABSL_CONST_INIT static absl::Mutex s3_logging_mutex(absl::kConstInit);
+// Installs an AWSLogSystem as the SDK logger. Calls made while the logger is
+// already installed do nothing; `initialized` is guarded by the mutex above.
+void AWSLogSystem::InitializeAWSLogging() {
+  absl::MutexLock l(&s3_logging_mutex);
+  if (!initialized) {
+    Aws::Utils::Logging::InitializeAWSLogging(Aws::MakeShared<AWSLogSystem>(
+        kAWSLoggingTag, ParseAwsLogLevelFromEnv()));
+    initialized = true;
+    return;
+  }
+}
+
+// Shuts the SDK logger down if InitializeAWSLogging() installed it.
+void AWSLogSystem::ShutdownAWSLogging() {
+  absl::MutexLock l(&s3_logging_mutex);
+  if (initialized) {
+    Aws::Utils::Logging::ShutdownAWSLogging();
+    initialized = false;
+    return;
+  }
+}
+
+}  // namespace tf_s3_filesystem
diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h
new file mode 100644
index 00000000000..afecd7e5e62
--- /dev/null
+++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h
@@ -0,0 +1,69 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
+#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
+
+#include <aws/core/utils/logging/LogLevel.h>
+#include <aws/core/utils/logging/LogSystemInterface.h>
+
+#include <atomic>
+#include <string>
+
+namespace tf_s3_filesystem {
+
+// Routes AWS SDK log output through TF_Log.
+class AWSLogSystem : public Aws::Utils::Logging::LogSystemInterface {
+ public:
+  // Installs an AWSLogSystem as the SDK's logger. The level is read from the
+  // AWS_LOG_LEVEL environment variable and defaults to fatal.
+  static void InitializeAWSLogging();
+  // Removes the logger installed by InitializeAWSLogging().
+  static void ShutdownAWSLogging();
+
+  explicit AWSLogSystem(Aws::Utils::Logging::LogLevel log_level);
+  virtual ~AWSLogSystem() = default;
+
+  // Gets the currently configured log level.
+  Aws::Utils::Logging::LogLevel GetLogLevel(void) const override {
+    return log_level_;
+  }
+
+  // Set a new log level. This has the immediate effect of changing the log.
+  void SetLogLevel(Aws::Utils::Logging::LogLevel log_level) {
+    log_level_.store(log_level);
+  }
+
+  // Formats a printf-style message and forwards it to TF_Log. Prefer
+  // LogStream, which cannot suffer from format/argument mismatches.
+  void Log(Aws::Utils::Logging::LogLevel log_level, const char* tag,
+           const char* format, ...) override;
+
+  // Forwards the contents of the stream to TF_Log.
+  void LogStream(Aws::Utils::Logging::LogLevel log_level, const char* tag,
+                 const Aws::OStringStream& messageStream) override;
+
+  // Required by the interface; a no-op here.
+  void Flush() override;
+
+ private:
+  // Emits `message` through TF_Log at the severity matching `log_level`.
+  void LogMessage(Aws::Utils::Logging::LogLevel log_level,
+                  const std::string& message);
+  std::atomic<Aws::Utils::Logging::LogLevel> log_level_;
+};
+
+}  // namespace tf_s3_filesystem
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_
diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc
index 1a61ab30a7c..9ff07633f2a 100644
--- a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
+#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h"
 #include "tensorflow/c/logging.h"
 #include "tensorflow/c/tf_status.h"
 
@@ -187,6 +188,8 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
   absl::MutexLock l(&s3_file->initialization_lock);
 
   if (s3_file->s3_client.get() == nullptr) {
+    tf_s3_filesystem::AWSLogSystem::InitializeAWSLogging();
+
     Aws::SDKOptions options;
     options.cryptoOptions.sha256Factory_create_fn = []() {
       return Aws::MakeShared(
@@ -251,6 +254,7 @@ static void ShutdownClient(Aws::S3::S3Client* s3_client) {
     delete s3_client;
     Aws::SDKOptions options;
     Aws::ShutdownAPI(options);
+    tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging();
   }
 }
 