From ed65e69560a8e2d58f7571fc2dcac269b20c260d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 11 May 2016 21:03:48 -0800 Subject: [PATCH] File system implementation for Google Cloud Storage. This code implements a file system for file paths starting with gs:// using the HTTP API to Google Cloud Storage. No authentication is implemented yet, so only GCS objects with public access can be used. Change: 122126085 --- configure | 31 ++ jsoncpp.BUILD | 34 ++ tensorflow/BUILD | 1 + tensorflow/core/BUILD | 5 +- tensorflow/core/platform/cloud/BUILD | 79 +++ .../core/platform/cloud/gcs_file_system.cc | 516 ++++++++++++++++++ .../core/platform/cloud/gcs_file_system.h | 73 +++ .../platform/cloud/gcs_file_system_test.cc | 363 ++++++++++++ .../core/platform/cloud/http_request.cc | 410 ++++++++++++++ tensorflow/core/platform/cloud/http_request.h | 156 ++++++ .../core/platform/cloud/http_request_test.cc | 323 +++++++++++ .../core/platform/default/build_config.bzl | 7 + .../ci_build/install/install_deb_packages.sh | 1 + tensorflow/workspace.bzl | 12 + 14 files changed, 2009 insertions(+), 2 deletions(-) create mode 100644 jsoncpp.BUILD create mode 100644 tensorflow/core/platform/cloud/BUILD create mode 100644 tensorflow/core/platform/cloud/gcs_file_system.cc create mode 100644 tensorflow/core/platform/cloud/gcs_file_system.h create mode 100644 tensorflow/core/platform/cloud/gcs_file_system_test.cc create mode 100644 tensorflow/core/platform/cloud/http_request.cc create mode 100644 tensorflow/core/platform/cloud/http_request.h create mode 100644 tensorflow/core/platform/cloud/http_request_test.cc diff --git a/configure b/configure index 3f9a53b573b..5d6397da57d 100755 --- a/configure +++ b/configure @@ -35,6 +35,37 @@ while true; do # Retry done +while [ "$TF_NEED_GCP" == "" ]; do + read -p "Do you wish to build TensorFlow with "\ +"Google Cloud Platform support? 
[y/N] " INPUT + case $INPUT in + [Yy]* ) echo "Google Cloud Platform support will be enabled for "\ +"TensorFlow"; TF_NEED_GCP=1;; + [Nn]* ) echo "No Google Cloud Platform support will be enabled for "\ +"TensorFlow"; TF_NEED_GCP=0;; + "" ) echo "No Google Cloud Platform support will be enabled for "\ +"TensorFlow"; TF_NEED_GCP=0;; + * ) echo "Invalid selection: " $INPUT;; + esac +done + +if [ "$TF_NEED_GCP" == "1" ]; then + + ## Verify that libcurl header files are available. + # Only check Linux, since on MacOS the header files are installed with XCode. + if [[ $(uname -a) =~ Linux ]] && [[ ! -f "/usr/include/curl/curl.h" ]]; then + echo "ERROR: It appears that the development version of libcurl is not "\ +"available. Please install the libcurl3-dev package." + exit 1 + fi + + # Update Bazel build configuration. + perl -pi -e "s,WITH_GCP_SUPPORT = (False|True),WITH_GCP_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl +else + # Update Bazel build configuration. + perl -pi -e "s,WITH_GCP_SUPPORT = (False|True),WITH_GCP_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl +fi + ## Find swig path if [ -z "$SWIG_PATH" ]; then SWIG_PATH=`type -p swig 2> /dev/null` diff --git a/jsoncpp.BUILD b/jsoncpp.BUILD new file mode 100644 index 00000000000..2bb2e19a67f --- /dev/null +++ b/jsoncpp.BUILD @@ -0,0 +1,34 @@ +licenses(["notice"]) # MIT + +JSON_HEADERS = [ + "include/json/assertions.h", + "include/json/autolink.h", + "include/json/config.h", + "include/json/features.h", + "include/json/forwards.h", + "include/json/json.h", + "src/lib_json/json_batchallocator.h", + "include/json/reader.h", + "include/json/value.h", + "include/json/writer.h", +] + +JSON_SOURCES = [ + "src/lib_json/json_reader.cpp", + "src/lib_json/json_value.cpp", + "src/lib_json/json_writer.cpp", + "src/lib_json/json_tool.h", +] + +INLINE_SOURCES = [ + "src/lib_json/json_valueiterator.inl", +] + +cc_library( + name = "jsoncpp", + srcs = JSON_SOURCES, + hdrs = 
JSON_HEADERS, + includes = ["include"], + textual_hdrs = INLINE_SOURCES, + visibility = ["//visibility:public"], +) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 040be96ffa3..d50f9870094 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -91,6 +91,7 @@ filegroup( "//tensorflow/core/distributed_runtime/rpc:all_files", "//tensorflow/core/kernels:all_files", "//tensorflow/core/ops/compat:all_files", + "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", "//tensorflow/core/util/ctc:all_files", "//tensorflow/examples/android:all_files", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 73f4e91fa8b..044d732f30a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -67,6 +67,7 @@ load( "tf_proto_library", "tf_proto_library_cc", "tf_additional_lib_srcs", + "tf_additional_lib_deps", "tf_additional_stream_executor_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", @@ -995,9 +996,9 @@ tf_cuda_library( ":lib_internal", ":proto_text", ":protos_all_cc", - "//tensorflow/core/kernels:required", "//third_party/eigen3", - ], + "//tensorflow/core/kernels:required", + ] + tf_additional_lib_deps(), alwayslink = 1, ) diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD new file mode 100644 index 00000000000..95da1373869 --- /dev/null +++ b/tensorflow/core/platform/cloud/BUILD @@ -0,0 +1,79 @@ +# Description: +# Cloud file system implementation. 
+ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +cc_library( + name = "gcs_file_system", + srcs = [ + "gcs_file_system.cc", + ], + hdrs = [ + "gcs_file_system.h", + ], + linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 + visibility = ["//visibility:public"], + deps = [ + "@jsoncpp_git//:jsoncpp", + ":http_request", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib_internal", + ], + alwayslink = 1, +) + +cc_library( + name = "http_request", + srcs = [ + "http_request.cc", + ], + hdrs = [ + "http_request.h", + ], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "gcs_file_system_test", + size = "small", + deps = [ + ":gcs_file_system", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "http_request_test", + size = "small", + deps = [ + ":http_request", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc new file mode 100644 index 00000000000..ba9418f06f4 --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -0,0 +1,516 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include +#include +#include +#include +#include +#include +#include +#include "include/json/json.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +namespace { + +constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/"; +constexpr char kGcsUploadUriBase[] = + "https://www.googleapis.com/upload/storage/v1/"; +constexpr char kStorageHost[] = "storage.googleapis.com"; +constexpr size_t kBufferSize = 1024 * 1024; // In bytes. + +Status GetTmpFilename(string* filename) { + if (!filename) { + return errors::Internal("'filename' cannot be nullptr."); + } + char buffer[] = "/tmp/gcs_filesystem_XXXXXX"; + int fd = mkstemp(buffer); + if (fd < 0) { + return errors::Internal("Failed to create a temporary file."); + } + close(fd); + *filename = buffer; + return Status::OK(); +} + +/// No-op auth provider, which will only work for public objects. 
+class EmptyAuthProvider : public AuthProvider { + public: + Status GetToken(string* token) const override { + *token = ""; + return Status::OK(); + } +}; + +Status GetAuthToken(const AuthProvider* provider, string* token) { + if (!provider) { + return errors::Internal("Auth provider is required."); + } + return provider->GetToken(token); +} + +/// \brief Splits a GCS path to a bucket and an object. +/// +/// For example, "gs://bucket-name/path/to/file.txt" gets split into +/// "bucket-name" and "path/to/file.txt". +Status ParseGcsPath(const string& fname, string* bucket, string* object) { + if (!bucket || !object) { + return errors::Internal("bucket and object cannot be null."); + } + StringPiece matched_bucket, matched_object; + if (!strings::Scanner(fname) + .OneLiteral("gs://") + .RestartCapture() + .ScanEscapedUntil('/') + .OneLiteral("/") + .GetResult(&matched_object, &matched_bucket)) { + return errors::InvalidArgument("Couldn't parse GCS path: " + fname); + } + // 'matched_bucket' contains a trailing slash, exclude it. + *bucket = string(matched_bucket.data(), matched_bucket.size() - 1); + *object = string(matched_object.data(), matched_object.size()); + return Status::OK(); +} + +/// GCS-based implementation of a random access file. 
+class GcsRandomAccessFile : public RandomAccessFile { + public: + GcsRandomAccessFile(const string& bucket, const string& object, + AuthProvider* auth_provider, + HttpRequest::Factory* http_request_factory) + : bucket_(bucket), + object_(object), + auth_provider_(auth_provider), + http_request_factory_(std::move(http_request_factory)) {} + + Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + string auth_token; + TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_, &auth_token)); + + std::unique_ptr request(http_request_factory_->Create()); + TF_RETURN_IF_ERROR(request->Init()); + TF_RETURN_IF_ERROR(request->SetUri( + strings::StrCat("https://", bucket_, ".", kStorageHost, "/", object_))); + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1)); + TF_RETURN_IF_ERROR(request->SetResultBuffer(scratch, n, result)); + TF_RETURN_IF_ERROR(request->Send()); + + if (result->size() < n) { + // This is not an error per se. The RandomAccessFile interface expects + // that Read returns OutOfRange if fewer bytes were read than requested. + return errors::OutOfRange(strings::StrCat("EOF reached, ", result->size(), + " bytes were read out of ", n, + " bytes requested.")); + } + return Status::OK(); + } + + private: + string bucket_; + string object_; + AuthProvider* auth_provider_; + HttpRequest::Factory* http_request_factory_; +}; + +/// \brief GCS-based implementation of a writeable file. +/// +/// Since GCS objects are immutable, this implementation writes to a local +/// tmp file and copies it to GCS on flush/close. 
+class GcsWritableFile : public WritableFile { + public: + GcsWritableFile(const string& bucket, const string& object, + AuthProvider* auth_provider, + HttpRequest::Factory* http_request_factory) + : bucket_(bucket), + object_(object), + auth_provider_(auth_provider), + http_request_factory_(std::move(http_request_factory)) { + if (GetTmpFilename(&tmp_content_filename_).ok()) { + outfile_.open(tmp_content_filename_, + std::ofstream::binary | std::ofstream::app); + } + } + + /// \brief Constructs the writable file in append mode. + /// + /// tmp_content_filename should contain a path of an existing temporary file + /// with the content to be appended. The class takes onwnership of the + /// specified tmp file and deletes it on close. + GcsWritableFile(const string& bucket, const string& object, + AuthProvider* auth_provider, + const string& tmp_content_filename, + HttpRequest::Factory* http_request_factory) + : bucket_(bucket), + object_(object), + auth_provider_(auth_provider), + http_request_factory_(std::move(http_request_factory)) { + tmp_content_filename_ = tmp_content_filename; + outfile_.open(tmp_content_filename_, + std::ofstream::binary | std::ofstream::app); + } + + ~GcsWritableFile() { Close(); } + + Status Append(const StringPiece& data) override { + TF_RETURN_IF_ERROR(CheckWritable()); + outfile_ << data; + return Status::OK(); + } + + Status Close() override { + if (outfile_.is_open()) { + TF_RETURN_IF_ERROR(Sync()); + outfile_.close(); + std::remove(tmp_content_filename_.c_str()); + } + return Status::OK(); + } + + Status Flush() override { return Sync(); } + + /// Copies the current version of the file to GCS. 
+ Status Sync() override { + TF_RETURN_IF_ERROR(CheckWritable()); + outfile_.flush(); + string auth_token; + TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_, &auth_token)); + + std::unique_ptr request(http_request_factory_->Create()); + TF_RETURN_IF_ERROR(request->Init()); + TF_RETURN_IF_ERROR( + request->SetUri(strings::StrCat(kGcsUploadUriBase, "b/", bucket_, + "/o?uploadType=media&name=", object_))); + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR(request->SetPostRequest(tmp_content_filename_)); + TF_RETURN_IF_ERROR(request->Send()); + return Status::OK(); + } + + private: + Status CheckWritable() const { + if (!outfile_.is_open()) { + return errors::FailedPrecondition( + "The underlying tmp file is not writable."); + } + return Status::OK(); + } + + string bucket_; + string object_; + AuthProvider* auth_provider_; + string tmp_content_filename_; + std::ofstream outfile_; + HttpRequest::Factory* http_request_factory_; +}; + +class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { + public: + GcsReadOnlyMemoryRegion(std::unique_ptr data, uint64 length) + : data_(std::move(data)), length_(length) {} + const void* data() override { return reinterpret_cast(data_.get()); } + uint64 length() override { return length_; } + + private: + std::unique_ptr data_; + uint64 length_; +}; +} // namespace + +GcsFileSystem::GcsFileSystem() + : auth_provider_(new EmptyAuthProvider()), + http_request_factory_(new HttpRequest::Factory()) {} + +GcsFileSystem::GcsFileSystem( + std::unique_ptr auth_provider, + std::unique_ptr http_request_factory) + : auth_provider_(std::move(auth_provider)), + http_request_factory_(std::move(http_request_factory)) {} + +Status GcsFileSystem::NewRandomAccessFile(const string& fname, + RandomAccessFile** result) { + string bucket, object; + TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object)); + *result = new GcsRandomAccessFile(bucket, object, auth_provider_.get(), + http_request_factory_.get()); + return 
Status::OK(); +} + +Status GcsFileSystem::NewWritableFile(const string& fname, + WritableFile** result) { + string bucket, object; + TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object)); + *result = new GcsWritableFile(bucket, object, auth_provider_.get(), + http_request_factory_.get()); + return Status::OK(); +} + +// Reads the file from GCS in chunks and stores it in a tmp file, +// which is then passed to GcsWritableFile. +Status GcsFileSystem::NewAppendableFile(const string& fname, + WritableFile** result) { + RandomAccessFile* reader_ptr; + TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader_ptr)); + std::unique_ptr reader(reader_ptr); + std::unique_ptr buffer(new char[kBufferSize]); + Status status; + uint64 offset = 0; + StringPiece read_chunk; + + // Read the file from GCS in chunks and save it to a tmp file. + string old_content_filename; + TF_RETURN_IF_ERROR(GetTmpFilename(&old_content_filename)); + std::ofstream old_content(old_content_filename, std::ofstream::binary); + while (true) { + status = reader->Read(offset, kBufferSize, &read_chunk, buffer.get()); + if (status.ok()) { + old_content << read_chunk; + offset += kBufferSize; + } else if (status.code() == error::OUT_OF_RANGE) { + // Expected, this means we reached EOF. + old_content << read_chunk; + break; + } else { + return status; + } + } + old_content.close(); + + // Create a writable file and pass the old content to it. 
+ string bucket, object; + TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object)); + *result = + new GcsWritableFile(bucket, object, auth_provider_.get(), + old_content_filename, http_request_factory_.get()); + return Status::OK(); +} + +Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile( + const string& fname, ReadOnlyMemoryRegion** result) { + uint64 size; + TF_RETURN_IF_ERROR(GetFileSize(fname, &size)); + std::unique_ptr data(new char[size]); + + RandomAccessFile* file; + TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &file)); + std::unique_ptr file_ptr(file); + + StringPiece piece; + TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get())); + + *result = new GcsReadOnlyMemoryRegion(std::move(data), size); + return Status::OK(); +} + +bool GcsFileSystem::FileExists(const string& fname) { + string bucket, object_prefix; + if (!ParseGcsPath(fname, &bucket, &object_prefix).ok()) { + LOG(ERROR) << "Could not parse GCS file name " << fname; + return false; + } + + string auth_token; + if (!GetAuthToken(auth_provider_.get(), &auth_token).ok()) { + LOG(ERROR) << "Could not get an auth token."; + return false; + } + + std::unique_ptr request(http_request_factory_->Create()); + if (!request->Init().ok()) { + LOG(ERROR) << "Could not initialize the HTTP request."; + return false; + } + request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/", + object_prefix, "?fields=size")); + request->AddAuthBearerHeader(auth_token); + return request->Send().ok(); +} + +Status GcsFileSystem::GetChildren(const string& dirname, + std::vector* result) { + if (!result) { + return errors::InvalidArgument("'result' cannot be null"); + } + string sanitized_dirname = dirname; + if (!dirname.empty() && dirname.back() != '/') { + sanitized_dirname += "/"; + } + string bucket, object_prefix; + TF_RETURN_IF_ERROR(ParseGcsPath(sanitized_dirname, &bucket, &object_prefix)); + + string auth_token; + TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token)); + + 
std::unique_ptr scratch(new char[kBufferSize]); + StringPiece response_piece; + std::unique_ptr request(http_request_factory_->Create()); + TF_RETURN_IF_ERROR(request->Init()); + TF_RETURN_IF_ERROR( + request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o?prefix=", + object_prefix, "&fields=items"))); + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + // TODO(surkov): Implement pagination using maxResults and pageToken + // instead, so that all items can be read regardless of their count. + // Currently one item takes about 1KB in the response, so with a 1MB + // buffer size this will read fewer than 1000 objects. + TF_RETURN_IF_ERROR( + request->SetResultBuffer(scratch.get(), kBufferSize, &response_piece)); + TF_RETURN_IF_ERROR(request->Send()); + std::stringstream response_stream; + response_stream << response_piece; + Json::Value root; + Json::Reader reader; + if (!reader.parse(response_stream.str(), root)) { + return errors::Internal("Couldn't parse JSON response from GCS."); + } + const auto items = root.get("items", Json::Value::null); + if (items == Json::Value::null) { + // Empty results. 
+ return Status::OK(); + } + if (!items.isArray()) { + return errors::Internal("Expected an array 'items' in the GCS response."); + } + for (size_t i = 0; i < items.size(); i++) { + const auto item = items.get(i, Json::Value::null); + if (!item.isObject()) { + return errors::Internal( + "Unexpected JSON format: 'items' should be a list of objects."); + } + const auto name = item.get("name", Json::Value::null); + if (name == Json::Value::null || !name.isString()) { + return errors::Internal( + "Unexpected JSON format: 'items.name' is missing or not a string."); + } + result->push_back( + strings::StrCat("gs://", bucket, "/", name.asString().c_str())); + } + return Status::OK(); +} + +Status GcsFileSystem::DeleteFile(const string& fname) { + string bucket, object; + TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object)); + + string auth_token; + TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token)); + + std::unique_ptr request(http_request_factory_->Create()); + TF_RETURN_IF_ERROR(request->Init()); + TF_RETURN_IF_ERROR(request->SetUri( + strings::StrCat(kGcsUriBase, "b/", bucket, "/o/", object))); + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR(request->SetDeleteRequest()); + TF_RETURN_IF_ERROR(request->Send()); + return Status::OK(); +} + +// Does nothing, because directories are not entities in GCS. +Status GcsFileSystem::CreateDir(const string& dirname) { return Status::OK(); } + +// Checks that the directory is empty (i.e no objects with this prefix exist). +// If it is, does nothing, because directories are not entities in GCS. 
+Status GcsFileSystem::DeleteDir(const string& dirname) { + string sanitized_dirname = dirname; + if (!dirname.empty() && dirname.back() != '/') { + sanitized_dirname += "/"; + } + std::vector children; + TF_RETURN_IF_ERROR(GetChildren(sanitized_dirname, &children)); + if (!children.empty()) { + return errors::InvalidArgument("Cannot delete a non-empty directory."); + } + return Status::OK(); +} + +Status GcsFileSystem::GetFileSize(const string& fname, uint64* file_size) { + string bucket, object_prefix; + TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object_prefix)); + + string auth_token; + TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token)); + + std::unique_ptr scratch(new char[kBufferSize]); + StringPiece response_piece; + + std::unique_ptr request(http_request_factory_->Create()); + TF_RETURN_IF_ERROR(request->Init()); + TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + kGcsUriBase, "b/", bucket, "/o/", object_prefix, "?fields=size"))); + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR( + request->SetResultBuffer(scratch.get(), kBufferSize, &response_piece)); + TF_RETURN_IF_ERROR(request->Send()); + std::stringstream response_stream; + response_stream << response_piece; + + Json::Value root; + Json::Reader reader; + if (!reader.parse(response_stream.str(), root)) { + return errors::Internal("Couldn't parse JSON response from GCS."); + } + const auto size = root.get("size", Json::Value::null); + if (size == Json::Value::null) { + return errors::Internal("'size' was expected in the JSON response."); + } + if (size.isNumeric()) { + *file_size = size.asUInt64(); + } else if (size.isString()) { + if (!strings::safe_strtou64(size.asString().c_str(), file_size)) { + return errors::Internal("'size' couldn't be parsed as a nubmer."); + } + } else { + return errors::Internal("'size' is not a number in the JSON response."); + } + return Status::OK(); +} + +// Uses a GCS API command to copy the object and then 
deletes the old one. +Status GcsFileSystem::RenameFile(const string& src, const string& target) { + string src_bucket, src_object, target_bucket, target_object; + TF_RETURN_IF_ERROR(ParseGcsPath(src, &src_bucket, &src_object)); + TF_RETURN_IF_ERROR(ParseGcsPath(target, &target_bucket, &target_object)); + + string auth_token; + TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token)); + + std::unique_ptr request(http_request_factory_->Create()); + TF_RETURN_IF_ERROR(request->Init()); + TF_RETURN_IF_ERROR(request->SetUri( + strings::StrCat(kGcsUriBase, "b/", src_bucket, "/o/", src_object, + "/rewriteTo/b/", target_bucket, "/o/", target_object))); + TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); + TF_RETURN_IF_ERROR(request->SetPostRequest()); + TF_RETURN_IF_ERROR(request->Send()); + + TF_RETURN_IF_ERROR(DeleteFile(src)); + return Status::OK(); +} + +REGISTER_FILE_SYSTEM("gs", GcsFileSystem); + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h new file mode 100644 index 00000000000..47c22173de5 --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -0,0 +1,73 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_ + +#include +#include +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/cloud/http_request.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +/// Interface for a provider of HTTP auth bearer tokens. +class AuthProvider { + public: + virtual ~AuthProvider() {} + virtual Status GetToken(string* t) const = 0; +}; + +/// Google Cloud Storage implementation of a file system. +class GcsFileSystem : public FileSystem { + public: + GcsFileSystem(); + GcsFileSystem(std::unique_ptr auth_provider, + std::unique_ptr http_request_factory); + + Status NewRandomAccessFile(const string& fname, + RandomAccessFile** result) override; + + Status NewWritableFile(const string& fname, WritableFile** result) override; + + Status NewAppendableFile(const string& fname, WritableFile** result) override; + + Status NewReadOnlyMemoryRegionFromFile( + const string& fname, ReadOnlyMemoryRegion** result) override; + + bool FileExists(const string& fname) override; + + Status GetChildren(const string& dir, std::vector* result) override; + + Status DeleteFile(const string& fname) override; + + Status CreateDir(const string& dirname) override; + + Status DeleteDir(const string& dirname) override; + + Status GetFileSize(const string& fname, uint64* file_size) override; + + Status RenameFile(const string& src, const string& target) override; + + private: + std::unique_ptr auth_provider_; + std::unique_ptr http_request_factory_; + TF_DISALLOW_COPY_AND_ASSIGN(GcsFileSystem); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_ diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc new file mode 100644 index 00000000000..151aebd87c5 --- /dev/null +++ 
b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -0,0 +1,363 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class FakeHttpRequest : public HttpRequest { + public: + FakeHttpRequest(const string& request, const string& response) + : FakeHttpRequest(request, response, Status::OK()) {} + + FakeHttpRequest(const string& request, const string& response, + Status response_status) + : expected_request_(request), + response_(response), + response_status_(response_status) {} + + Status Init() override { return Status::OK(); } + Status SetUri(const string& uri) override { + actual_request_ += "Uri: " + uri + "\n"; + return Status::OK(); + } + Status SetRange(uint64 start, uint64 end) override { + actual_request_ += strings::StrCat("Range: ", start, "-", end, "\n"); + return Status::OK(); + } + Status AddAuthBearerHeader(const string& auth_token) override { + actual_request_ += "Auth Token: " + auth_token + "\n"; + return Status::OK(); + } + Status SetDeleteRequest() override { + actual_request_ += "Delete: yes\n"; + return Status::OK(); + } + Status SetPostRequest(const string& body_filepath) override { + std::ifstream stream(body_filepath); 
+ string content((std::istreambuf_iterator(stream)), + std::istreambuf_iterator()); + actual_request_ += "Post body: " + content + "\n"; + return Status::OK(); + } + Status SetPostRequest() override { + actual_request_ += "Post: yes\n"; + return Status::OK(); + } + Status SetResultBuffer(char* scratch, size_t size, + StringPiece* result) override { + scratch_ = scratch; + size_ = size; + result_ = result; + return Status::OK(); + } + Status Send() override { + EXPECT_EQ(expected_request_, actual_request_) << "Unexpected HTTP request."; + if (scratch_ && result_) { + auto actual_size = std::min(response_.size(), size_); + memcpy(scratch_, response_.c_str(), actual_size); + *result_ = StringPiece(scratch_, actual_size); + } + return response_status_; + } + + private: + char* scratch_ = nullptr; + size_t size_ = 0; + StringPiece* result_ = nullptr; + string expected_request_; + string actual_request_; + string response_; + Status response_status_; +}; + +class FakeHttpRequestFactory : public HttpRequest::Factory { + public: + FakeHttpRequestFactory(const std::vector* requests) + : requests_(requests) {} + + ~FakeHttpRequestFactory() { + EXPECT_EQ(current_index_, requests_->size()) + << "Not all expected requests were made."; + } + + HttpRequest* Create() override { + EXPECT_LT(current_index_, requests_->size()) + << "Too many calls of HttpRequest factory."; + return (*requests_)[current_index_++]; + } + + private: + const std::vector* requests_; + int current_index_ = 0; +}; + +class FakeAuthProvider : public AuthProvider { + public: + Status GetToken(string* token) const override { + *token = "fake_token"; + return Status::OK(); + } +}; + +TEST(GcsFileSystemTest, NewRandomAccessFile) { + std::vector requests( + {new FakeHttpRequest( + "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 0-5\n", + "012345"), + new FakeHttpRequest( + "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Auth Token: 
fake_token\n" + "Range: 6-11\n", + "6789")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + RandomAccessFile* file_ptr; + TF_EXPECT_OK( + fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file_ptr)); + std::unique_ptr file(file_ptr); + + char scratch[6]; + StringPiece result; + + // Read the first chunk. + TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch)); + EXPECT_EQ("012345", result); + + // Read the second chunk. + EXPECT_EQ( + errors::Code::OUT_OF_RANGE, + file->Read(sizeof(scratch), sizeof(scratch), &result, scratch).code()); + EXPECT_EQ("6789", result); +} + +TEST(GcsFileSystemTest, NewWritableFile) { + std::vector requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" + "uploadType=media&name=path/writeable.txt\n" + "Auth Token: fake_token\n" + "Post body: content1,content2\n", + "")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + WritableFile* file_ptr; + TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file_ptr)); + std::unique_ptr file(file_ptr); + + TF_EXPECT_OK(file->Append("content1,")); + TF_EXPECT_OK(file->Append("content2")); + TF_EXPECT_OK(file->Close()); +} + +TEST(GcsFileSystemTest, NewAppendableFile) { + std::vector requests( + {new FakeHttpRequest( + "Uri: https://bucket.storage.googleapis.com/path/appendable.txt\n" + "Auth Token: fake_token\n" + "Range: 0-1048575\n", + "content1,"), + new FakeHttpRequest( + "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" 
+ "uploadType=media&name=path/appendable.txt\n" + "Auth Token: fake_token\n" + "Post body: content1,content2\n", + "")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + WritableFile* file_ptr; + TF_EXPECT_OK( + fs.NewAppendableFile("gs://bucket/path/appendable.txt", &file_ptr)); + std::unique_ptr file(file_ptr); + + TF_EXPECT_OK(file->Append("content2")); + TF_EXPECT_OK(file->Close()); +} + +TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { + const string content = "file content"; + std::vector requests( + {new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "random_access.txt?fields=size\n" + "Auth Token: fake_token\n", + strings::StrCat("{\"size\": \"", content.size(), "\"}")), + new FakeHttpRequest( + strings::StrCat( + "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 0-", + content.size() - 1, "\n"), + content)}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + ReadOnlyMemoryRegion* region_ptr; + TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile( + "gs://bucket/random_access.txt", ®ion_ptr)); + std::unique_ptr region(region_ptr); + + EXPECT_EQ(content, StringPiece(reinterpret_cast(region->data()), + region->length())); +} + +TEST(GcsFileSystemTest, FileExists) { + std::vector requests( + {new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "path/file1.txt?fields=size\n" + "Auth Token: fake_token\n", + "{\"size\": \"100\"}"), + new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "path/file2.txt?fields=size\n" + "Auth Token: fake_token\n", + "", errors::NotFound("404"))}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + EXPECT_TRUE(fs.FileExists("gs://bucket/path/file1.txt")); + 
EXPECT_FALSE(fs.FileExists("gs://bucket/path/file2.txt")); +} + +TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { + std::vector requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "prefix=path/&fields=items\n" + "Auth Token: fake_token\n", + "{\"items\": [ " + " { \"name\": \"path/file1.txt\" }," + " { \"name\": \"path/subpath/file2.txt\" }," + " { \"name\": \"path/file3.txt\" }]}")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + std::vector children; + TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); + + EXPECT_EQ(3, children.size()); + EXPECT_EQ("gs://bucket/path/file1.txt", children[0]); + EXPECT_EQ("gs://bucket/path/subpath/file2.txt", children[1]); + EXPECT_EQ("gs://bucket/path/file3.txt", children[2]); +} + +TEST(GcsFileSystemTest, GetChildren_Empty) { + std::vector requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "prefix=path/&fields=items\n" + "Auth Token: fake_token\n", + "{}")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + std::vector children; + TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); + + EXPECT_EQ(0, children.size()); +} + +TEST(GcsFileSystemTest, DeleteFile) { + std::vector requests( + {new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b" + "/bucket/o/path/file1.txt\n" + "Auth Token: fake_token\n" + "Delete: yes\n", + "")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + TF_EXPECT_OK(fs.DeleteFile("gs://bucket/path/file1.txt")); +} + +TEST(GcsFileSystemTest, DeleteDir_Empty) { + std::vector requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" 
+ "prefix=path/&fields=items\n" + "Auth Token: fake_token\n", + "{}")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/")); +} + +TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { + std::vector requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "prefix=path/&fields=items\n" + "Auth Token: fake_token\n", + "{\"items\": [ " + " { \"name\": \"path/file1.txt\" }]}")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + EXPECT_FALSE(fs.DeleteDir("gs://bucket/path/").ok()); +} + +TEST(GcsFileSystemTest, GetFileSize) { + std::vector requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "file.txt?fields=size\n" + "Auth Token: fake_token\n", + strings::StrCat("{\"size\": \"1010\"}"))}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + uint64 size; + TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size)); + EXPECT_EQ(1010, size); +} + +TEST(GcsFileSystemTest, RenameFile) { + std::vector requests( + {new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/src.txt" + "/rewriteTo/b/bucket/o/dst.txt\n" + "Auth Token: fake_token\n" + "Post: yes\n", + ""), + new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/src.txt\n" + "Auth Token: fake_token\n" + "Delete: yes\n", + "")}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests))); + + TF_EXPECT_OK(fs.RenameFile("gs://bucket/src.txt", "gs://bucket/dst.txt")); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/http_request.cc b/tensorflow/core/platform/cloud/http_request.cc new file mode 100644 index 00000000000..38f132e723a --- /dev/null 
+++ b/tensorflow/core/platform/cloud/http_request.cc
@@ -0,0 +1,442 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/http_request.h"
+#include
+#include
+#include
+#include
+#include
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Windows is not currently supported.
+constexpr char kCurlLibLinux[] = "libcurl.so.3";
+constexpr char kCurlLibMac[] = "/usr/lib/libcurl.3.dylib";
+
+constexpr char kCertsPath[] = "/etc/ssl/certs";
+
+// Set to 1 to enable verbose debug output from curl.
+constexpr uint64 kVerboseOutput = 0;
+
+/// An implementation that dynamically loads libcurl and forwards calls to it.
+class LibCurlProxy : public LibCurl {
+ public:
+  ~LibCurlProxy() {
+    if (dll_handle_) {
+      dlclose(dll_handle_);
+    }
+  }
+
+  /// Binds the libcurl entry points on first use; later calls are no-ops.
+  /// Tries the current binary first, then the Linux and Mac library names.
+  Status MaybeLoadDll() override {
+    if (dll_handle_) {
+      return Status::OK();
+    }
+    // This may have been linked statically; if curl_easy_init is in the
+    // current binary, no need to search for a dynamic version.
+    dll_handle_ = load_dll(nullptr);
+    if (!dll_handle_) {
+      dll_handle_ = load_dll(kCurlLibLinux);
+    }
+    if (!dll_handle_) {
+      dll_handle_ = load_dll(kCurlLibMac);
+    }
+    if (!dll_handle_) {
+      return errors::FailedPrecondition(strings::StrCat(
+          "Could not initialize the libcurl library. Please make sure that "
+          "libcurl is installed in the OS or statically linked to the "
+          "TensorFlow binary."));
+    }
+    // No other libcurl function may be used if global initialization failed,
+    // so drop the handle and report the failure instead of ignoring it.
+    if (curl_global_init_(CURL_GLOBAL_ALL) != CURLE_OK) {
+      dlclose(dll_handle_);
+      dll_handle_ = nullptr;
+      return errors::FailedPrecondition(
+          "Could not initialize the libcurl library: curl_global_init "
+          "failed.");
+    }
+    return Status::OK();
+  }
+
+  CURL* curl_easy_init() override {
+    CHECK(dll_handle_);
+    return curl_easy_init_();
+  }
+
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            uint64 param) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            const char* param) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            void* param) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            size_t (*param)(void*, size_t, size_t,
+                                            FILE*)) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            size_t (*param)(const void*, size_t, size_t,
+                                            void*)) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+
+  CURLcode curl_easy_perform(CURL* curl) override {
+    CHECK(dll_handle_);
+    return curl_easy_perform_(curl);
+  }
+
+  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+                             uint64* value) override {
+    CHECK(dll_handle_);
+    return curl_easy_getinfo_(curl, info, value);
+  }
+  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+                             double* value) override {
+    CHECK(dll_handle_);
+    return curl_easy_getinfo_(curl, info, value);
+  }
+  void curl_easy_cleanup(CURL* curl) override {
+    CHECK(dll_handle_);
+    return curl_easy_cleanup_(curl);
+  }
+
+  curl_slist* curl_slist_append(curl_slist* list, const char* str) override {
+    CHECK(dll_handle_);
+    return curl_slist_append_(list, str);
+  }
+
+  void curl_slist_free_all(curl_slist* list) override {
+    CHECK(dll_handle_);
+    return curl_slist_free_all_(list);
+  }
+
+ private:
+  // Loads the dynamic library and binds the required methods.
+  // Returns the library handle in case of success or nullptr otherwise.
+  // 'name' can be nullptr.
+  void* load_dll(const char* name) {
+    void* handle = dlopen(name, RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE);
+    if (!handle) {
+      return nullptr;
+    }
+
+#define BIND_CURL_FUNC(function) \
+  *reinterpret_cast(&(function##_)) = dlsym(handle, #function)
+
+    BIND_CURL_FUNC(curl_global_init);
+    BIND_CURL_FUNC(curl_easy_init);
+    BIND_CURL_FUNC(curl_easy_setopt);
+    BIND_CURL_FUNC(curl_easy_perform);
+    BIND_CURL_FUNC(curl_easy_getinfo);
+    BIND_CURL_FUNC(curl_slist_append);
+    BIND_CURL_FUNC(curl_slist_free_all);
+    BIND_CURL_FUNC(curl_easy_cleanup);
+
+#undef BIND_CURL_FUNC
+
+    // Every entry point must resolve: a partially bound library would crash
+    // on the first call made through a null function pointer.
+    if (!curl_global_init_ || !curl_easy_init_ || !curl_easy_setopt_ ||
+        !curl_easy_perform_ || !curl_easy_getinfo_ || !curl_easy_cleanup_ ||
+        !curl_slist_append_ || !curl_slist_free_all_) {
+      dlerror();  // Clear dlerror before attempting to open libraries.
+      dlclose(handle);
+      return nullptr;
+    }
+    return handle;
+  }
+
+  void* dll_handle_ = nullptr;
+  CURLcode (*curl_global_init_)(int64) = nullptr;
+  CURL* (*curl_easy_init_)(void) = nullptr;
+  CURLcode (*curl_easy_setopt_)(CURL*, CURLoption, ...) = nullptr;
+  CURLcode (*curl_easy_perform_)(CURL* curl) = nullptr;
+  CURLcode (*curl_easy_getinfo_)(CURL* curl, CURLINFO info, ...) = nullptr;
+  void (*curl_easy_cleanup_)(CURL* curl) = nullptr;
+  curl_slist* (*curl_slist_append_)(curl_slist* list,
+                                    const char* str) = nullptr;
+  void (*curl_slist_free_all_)(curl_slist* list) = nullptr;
+};
+}  // namespace
+
+HttpRequest::HttpRequest()
+    : HttpRequest(std::unique_ptr(new LibCurlProxy)) {}
+
+HttpRequest::HttpRequest(std::unique_ptr libcurl)
+    : libcurl_(std::move(libcurl)),
+      default_response_buffer_(new char[CURL_MAX_WRITE_SIZE]) {}
+
+HttpRequest::~HttpRequest() {
+  if (curl_headers_) {
+    libcurl_->curl_slist_free_all(curl_headers_);
+  }
+  if (post_body_) {
+    fclose(post_body_);
+  }
+  if (curl_) {
+    libcurl_->curl_easy_cleanup(curl_);
+  }
+}
+
+Status HttpRequest::Init() {
+  if (!libcurl_) {
+    return errors::Internal("libcurl proxy cannot be nullptr.");
+  }
+  TF_RETURN_IF_ERROR(libcurl_->MaybeLoadDll());
+  curl_ = libcurl_->curl_easy_init();
+  if (!curl_) {
+    return errors::Internal("Couldn't initialize a curl session.");
+  }
+
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_VERBOSE, kVerboseOutput);
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_CAPATH, kCertsPath);
+
+  // If response buffer is not set, libcurl will print results to stdout,
+  // so we always set it.
+  is_initialized_ = true;
+  auto s = SetResultBuffer(default_response_buffer_.get(), CURL_MAX_WRITE_SIZE,
+                           &default_response_string_piece_);
+  if (!s.ok()) {
+    is_initialized_ = false;
+    return s;
+  }
+  return Status::OK();
+}
+
+Status HttpRequest::SetUri(const string& uri) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  is_uri_set_ = true;
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_URL, uri.c_str());
+  return Status::OK();
+}
+
+Status HttpRequest::SetRange(uint64 start, uint64 end) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_RANGE,
+                             strings::StrCat(start, "-", end).c_str());
+  return Status::OK();
+}
+
+Status HttpRequest::AddAuthBearerHeader(const string& auth_token) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  if (!auth_token.empty()) {
+    curl_headers_ = libcurl_->curl_slist_append(
+        curl_headers_,
+        strings::StrCat("Authorization: Bearer ", auth_token).c_str());
+  }
+  return Status::OK();
+}
+
+Status HttpRequest::SetDeleteRequest() {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  TF_RETURN_IF_ERROR(CheckMethodNotSet());
+  is_method_set_ = true;
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_CUSTOMREQUEST, "DELETE");
+  return Status::OK();
+}
+
+Status HttpRequest::SetPostRequest(const string& body_filepath) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  TF_RETURN_IF_ERROR(CheckMethodNotSet());
+  if (post_body_) {
+    fclose(post_body_);
+  }
+  post_body_ = fopen(body_filepath.c_str(), "r");
+  if (!post_body_) {
+    return errors::InvalidArgument("Couldn't open the specified file: " +
+                                   body_filepath);
+  }
+  // Measure the body for the Content-Length header. fseek() and ftell() can
+  // fail (for example on a non-seekable file); ftell() then returns -1, which
+  // must not be sent as the length.
+  const auto size =
+      fseek(post_body_, 0, SEEK_END) == 0 ? ftell(post_body_) : -1;
+  if (size < 0 || fseek(post_body_, 0, SEEK_SET) != 0) {
+    fclose(post_body_);
+    post_body_ = nullptr;
+    return errors::InvalidArgument(
+        "Couldn't determine the size of the specified file: " + body_filepath);
+  }
+  // Mark the method as set only once the body is usable, so that a failed
+  // call does not block a later SetPostRequest() or SetDeleteRequest().
+  is_method_set_ = true;
+
+  curl_headers_ = libcurl_->curl_slist_append(
+      curl_headers_, strings::StrCat("Content-Length: ", size).c_str());
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_POST, 1);
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_READDATA,
+                             reinterpret_cast(post_body_));
+  return Status::OK();
+}
+
+Status HttpRequest::SetPostRequest() {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  TF_RETURN_IF_ERROR(CheckMethodNotSet());
+  is_method_set_ = true;
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_POST, 1);
+  curl_headers_ =
+      libcurl_->curl_slist_append(curl_headers_, "Content-Length: 0");
+  return Status::OK();
+}
+
+Status HttpRequest::SetResultBuffer(char* scratch, size_t size,
+                                    StringPiece* result) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  if (!scratch) {
+    return errors::InvalidArgument("scratch cannot be null");
+  }
+  if (!result) {
+    return errors::InvalidArgument("result cannot be null");
+  }
+  if (size <= 0) {
+    return errors::InvalidArgument("buffer size should be positive");
+  }
+
+  response_buffer_ = scratch;
+  response_buffer_size_ = size;
+  response_string_piece_ = result;
+  response_buffer_written_ = 0;
+
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA,
+                             reinterpret_cast(this));
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION,
+                             &HttpRequest::WriteCallback);
+  return Status::OK();
+}
+
+size_t HttpRequest::WriteCallback(const void* ptr, size_t size, size_t nmemb,
+                                  void* this_object) {
+  CHECK(ptr);
+  auto that = reinterpret_cast(this_object);
+  CHECK(that->response_buffer_);
+  CHECK(that->response_buffer_size_ >= that->response_buffer_written_);
+  const size_t bytes_to_copy =
+      std::min(size * nmemb,
+               that->response_buffer_size_ - that->response_buffer_written_);
+  memcpy(that->response_buffer_ + that->response_buffer_written_, ptr,
+         bytes_to_copy);
+  that->response_buffer_written_ += bytes_to_copy;
+  return bytes_to_copy;
+}
+
+Status HttpRequest::Send() {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  is_sent_ = true;
+  if (!is_uri_set_) {
+    return errors::FailedPrecondition("URI has not been set.");
+  }
+  if (curl_headers_) {
+    libcurl_->curl_easy_setopt(curl_, CURLOPT_HTTPHEADER, curl_headers_);
+  }
+
+  // Start with an empty message: libcurl does not necessarily write to the
+  // buffer on failure, and it is appended to the error text below.
+  char error_buffer[CURL_ERROR_SIZE] = {0};
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_ERRORBUFFER, error_buffer);
+
+  const auto curl_result = libcurl_->curl_easy_perform(curl_);
+
+  // error_buffer is a local, but curl_ lives until the destructor runs, so
+  // detach the buffer before it goes out of scope.
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_ERRORBUFFER,
+                             static_cast<void*>(nullptr));
+
+  // Zero-initialized so that a failed or partial write by getinfo cannot
+  // leave an indeterminate value in the switch below.
+  uint64 response_code = 0;
+  libcurl_->curl_easy_getinfo(curl_, CURLINFO_RESPONSE_CODE, &response_code);
+
+  if (curl_result != CURLE_OK) {
+    return errors::Internal(string("curl error: ") + error_buffer);
+  }
+  switch (response_code) {
+    case 200:  // OK
+    case 204:  // No Content
+    case 206:  // Partial Content
+      if (response_buffer_ && response_string_piece_) {
+        // Report what WriteCallback actually stored; this can never exceed
+        // the size of the caller's buffer.
+        *response_string_piece_ =
+            StringPiece(response_buffer_, response_buffer_written_);
+      }
+      return Status::OK();
+    case 401:
+      return errors::PermissionDenied(
+          "Not authorized to access the given HTTP resource.");
+    case 404:
+      return errors::NotFound("The requested URL was not found.");
+    case 416:  // Requested Range Not Satisfiable
+      if (response_string_piece_) {
+        *response_string_piece_ = StringPiece();
+      }
+      return Status::OK();
+    default:
+      return errors::Internal(
+          strings::StrCat("Unexpected HTTP response code ", response_code));
+  }
+}
+
+Status HttpRequest::CheckInitialized() const {
+  if (!is_initialized_) {
+    return errors::FailedPrecondition("The object has not been initialized.");
+  }
+  return Status::OK();
+}
+
+Status HttpRequest::CheckMethodNotSet() const {
+  if (is_method_set_) {
+    return errors::FailedPrecondition("HTTP method has been already set.");
+  }
+  return Status::OK();
+}
+
+Status HttpRequest::CheckNotSent() const {
+  if (is_sent_) {
+    return errors::FailedPrecondition("The request has already been sent.");
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/http_request.h
b/tensorflow/core/platform/cloud/http_request.h
new file mode 100644
index 00000000000..19aed67e6a4
--- /dev/null
+++ b/tensorflow/core/platform/cloud/http_request.h
@@ -0,0 +1,185 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
+#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
+
+#include
+#include
+#include
+#include
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class LibCurl;  // libcurl interface as a class, for dependency injection.
+
+/// \brief A basic HTTP client based on the libcurl library.
+///
+/// The usage pattern for the class reflects the one of the libcurl library:
+/// create a request object, set request parameters and call Send().
+///
+/// For example:
+///   HttpRequest request;
+///   request.Init();
+///   request.SetUri("http://www.google.com");
+///   request.SetResultBuffer(scratch, 1000, &result);
+///   request.Send();
+///
+/// Init() must succeed first: until then every setter and Send() return
+/// FailedPrecondition. At most one of SetDeleteRequest() and SetPostRequest()
+/// may be called on a request.
+class HttpRequest {
+ public:
+  /// \brief Creates HttpRequest objects.
+  ///
+  /// Create() is virtual so that a subclass can return a different
+  /// HttpRequest implementation, e.g. a fake in tests.
+  class Factory {
+   public:
+    virtual ~Factory() {}
+    /// Returns a new heap-allocated HttpRequest without calling Init() on it.
+    virtual HttpRequest* Create() { return new HttpRequest(); }
+  };
+
+  /// Creates a request that uses the default LibCurl implementation.
+  HttpRequest();
+  /// Creates a request that uses the given LibCurl implementation.
+  explicit HttpRequest(std::unique_ptr libcurl);
+  virtual ~HttpRequest();
+
+  /// \brief Loads libcurl if necessary and creates the curl session.
+  ///
+  /// Also installs a default result buffer (see SetResultBuffer()).
+  virtual Status Init();
+
+  /// Sets the request URI.
+  virtual Status SetUri(const string& uri);
+
+  /// \brief Sets the Range header.
+  ///
+  /// Used for random seeks, for example "0-999" returns the first 1000 bytes
+  /// (note that the right border is included).
+  virtual Status SetRange(uint64 start, uint64 end);
+
+  /// Sets the 'Authorization' header to the value of 'Bearer ' + auth_token.
+  /// An empty auth_token adds no header.
+  virtual Status AddAuthBearerHeader(const string& auth_token);
+
+  /// Makes the request a DELETE request.
+  virtual Status SetDeleteRequest();
+
+  /// \brief Makes the request a POST request.
+  ///
+  /// The request body will be taken from the specified file.
+  virtual Status SetPostRequest(const string& body_filepath);
+
+  /// Makes the request a POST request with an empty body.
+  virtual Status SetPostRequest();
+
+  /// \brief Specifies the buffer for receiving the response body.
+  ///
+  /// The interface is made similar to RandomAccessFile::Read.
+  /// At most 'size' bytes are stored in 'scratch'. If this method is never
+  /// called, Init() has already installed an internal buffer of
+  /// CURL_MAX_WRITE_SIZE bytes.
+  virtual Status SetResultBuffer(char* scratch, size_t size,
+                                 StringPiece* result);
+
+  /// \brief Sends the formed request.
+  ///
+  /// If the result buffer was defined, the response will be written there.
+  /// The object is not designed to be re-used after Send() is executed.
+  ///
+  /// Returns OK for HTTP 200, 204 and 206, and for 416 with an empty result;
+  /// PermissionDenied for 401; NotFound for 404; FailedPrecondition if no URI
+  /// was set; Internal for other response codes and for libcurl errors.
+  virtual Status Send();
+
+ private:
+  /// A callback in the form which can be accepted by libcurl.
+  /// 'userdata' is the HttpRequest whose result buffer receives the data; the
+  /// return value is the number of bytes that were stored.
+  static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb,
+                              void* userdata);
+  Status CheckInitialized() const;
+  Status CheckMethodNotSet() const;
+  Status CheckNotSent() const;
+
+  std::unique_ptr libcurl_;
+  FILE* post_body_ = nullptr;
+  char* response_buffer_ = nullptr;
+  size_t response_buffer_size_ = 0;
+  size_t response_buffer_written_ = 0;
+  StringPiece* response_string_piece_ = nullptr;
+  CURL* curl_ = nullptr;
+  curl_slist* curl_headers_ = nullptr;
+
+  // Result buffer used when the caller does not supply one.
+  std::unique_ptr default_response_buffer_;
+  StringPiece default_response_string_piece_;
+
+  // Members to enforce the usage flow.
+  bool is_initialized_ = false;
+  bool is_uri_set_ = false;
+  bool is_method_set_ = false;
+  bool is_sent_ = false;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(HttpRequest);
+};
+
+/// \brief A proxy to the libcurl C interface as a dependency injection measure.
+///
+/// This class is meant as a very thin wrapper for the libcurl C library.
+/// Each method corresponds to the libcurl function of the same name; the
+/// curl_easy_setopt() and curl_easy_getinfo() overloads stand in for the
+/// variadic argument of the corresponding C functions.
+class LibCurl {
+ public:
+  virtual ~LibCurl() {}
+  /// Lazy initialization of the dynamic libcurl library.
+  virtual Status MaybeLoadDll() = 0;
+
+  virtual CURL* curl_easy_init() = 0;
+  virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                                    uint64 param) = 0;
+  virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                                    const char* param) = 0;
+  virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                                    void* param) = 0;
+  virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                                    size_t (*param)(void*, size_t, size_t,
+                                                    FILE*)) = 0;
+  virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                                    size_t (*param)(const void*, size_t, size_t,
+                                                    void*)) = 0;
+  virtual CURLcode curl_easy_perform(CURL* curl) = 0;
+  virtual CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+                                     uint64* value) = 0;
+  virtual CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+                                     double* value) = 0;
+  virtual void curl_easy_cleanup(CURL* curl) = 0;
+  virtual curl_slist* curl_slist_append(curl_slist* list, const char* str) = 0;
+  virtual void curl_slist_free_all(curl_slist* list) = 0;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
diff --git a/tensorflow/core/platform/cloud/http_request_test.cc b/tensorflow/core/platform/cloud/http_request_test.cc
new file mode 100644
index 00000000000..247514c9da4
--- /dev/null
+++ b/tensorflow/core/platform/cloud/http_request_test.cc
@@ -0,0 +1,323 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/platform/cloud/http_request.h" +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// A fake proxy that pretends to be libcurl. +class FakeLibCurl : public LibCurl { + public: + FakeLibCurl(const string& response_content, uint64 response_code) + : response_content(response_content), response_code(response_code) {} + Status MaybeLoadDll() override { return Status::OK(); } + CURL* curl_easy_init() override { + is_initialized = true; + // The result just needs to be non-null. + return reinterpret_cast(this); + } + CURLcode curl_easy_setopt(CURL* curl, CURLoption option, + uint64 param) override { + switch (option) { + case CURLOPT_POST: + is_post = param; + break; + default: + break; + } + return CURLE_OK; + } + CURLcode curl_easy_setopt(CURL* curl, CURLoption option, + const char* param) override { + return curl_easy_setopt(curl, option, + reinterpret_cast(const_cast(param))); + } + CURLcode curl_easy_setopt(CURL* curl, CURLoption option, + void* param) override { + switch (option) { + case CURLOPT_URL: + url = reinterpret_cast(param); + break; + case CURLOPT_RANGE: + range = reinterpret_cast(param); + break; + case CURLOPT_CUSTOMREQUEST: + custom_request = reinterpret_cast(param); + break; + case CURLOPT_HTTPHEADER: + headers = reinterpret_cast*>(param); + break; + case CURLOPT_ERRORBUFFER: + error_buffer = reinterpret_cast(param); + break; + case CURLOPT_WRITEDATA: + write_data = reinterpret_cast(param); + break; + case CURLOPT_READDATA: + read_data = reinterpret_cast(param); + break; + default: + break; + } + return CURLE_OK; + } + CURLcode curl_easy_setopt(CURL* curl, CURLoption option, + size_t (*param)(void*, size_t, size_t, + FILE*)) override { + EXPECT_EQ(param, &fread) << "Expected the standard fread() 
function."; + return CURLE_OK; + } + CURLcode curl_easy_setopt(CURL* curl, CURLoption option, + size_t (*param)(const void*, size_t, size_t, + void*)) override { + switch (option) { + case CURLOPT_WRITEFUNCTION: + write_callback = param; + break; + default: + break; + } + return CURLE_OK; + } + CURLcode curl_easy_perform(CURL* curl) override { + if (read_data) { + char buffer[100]; + int bytes_read; + posted_content = ""; + do { + bytes_read = fread(buffer, 1, 100, read_data); + posted_content = + strings::StrCat(posted_content, StringPiece(buffer, bytes_read)); + } while (bytes_read > 0); + } + if (write_data) { + write_callback(response_content.c_str(), 1, response_content.size(), + write_data); + } + return CURLE_OK; + } + CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info, + uint64* value) override { + switch (info) { + case CURLINFO_RESPONSE_CODE: + *value = response_code; + break; + default: + break; + } + return CURLE_OK; + } + CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info, + double* value) override { + switch (info) { + case CURLINFO_SIZE_DOWNLOAD: + *value = response_content.size(); + break; + default: + break; + } + return CURLE_OK; + } + void curl_easy_cleanup(CURL* curl) override { is_cleaned_up = true; } + curl_slist* curl_slist_append(curl_slist* list, const char* str) override { + std::vector* v = list ? reinterpret_cast*>(list) + : new std::vector(); + v->push_back(str); + return reinterpret_cast(v); + } + void curl_slist_free_all(curl_slist* list) override { + delete reinterpret_cast*>(list); + } + + // Variables defining the behavior of this fake. + string response_content; + uint64 response_code; + + // Internal variables to store the libcurl state. 
+ string url; + string range; + string custom_request; + char* error_buffer = nullptr; + bool is_initialized = false; + bool is_cleaned_up = false; + std::vector* headers = nullptr; + FILE* read_data = nullptr; + bool is_post = false; + void* write_data = nullptr; + size_t (*write_callback)(const void* ptr, size_t size, size_t nmemb, + void* userdata) = nullptr; + // Outcome of performing the request. + string posted_content; +}; + +TEST(HttpRequestTest, GetRequest) { + FakeLibCurl* libcurl = new FakeLibCurl("get response", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + TF_EXPECT_OK(http_request.Init()); + + char scratch[100] = "random original scratch content"; + StringPiece result = "random original string piece"; + + TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com")); + TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer")); + TF_EXPECT_OK(http_request.SetRange(100, 199)); + TF_EXPECT_OK(http_request.SetResultBuffer(scratch, 100, &result)); + TF_EXPECT_OK(http_request.Send()); + + EXPECT_EQ("get response", result); + + // Check interactions with libcurl. 
+ EXPECT_TRUE(libcurl->is_initialized); + EXPECT_EQ("http://www.testuri.com", libcurl->url); + EXPECT_EQ("100-199", libcurl->range); + EXPECT_EQ("", libcurl->custom_request); + EXPECT_EQ(1, libcurl->headers->size()); + EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl->headers)[0]); + EXPECT_FALSE(libcurl->is_post); +} + +TEST(HttpRequestTest, PostRequest_WithBody) { + FakeLibCurl* libcurl = new FakeLibCurl("", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + TF_EXPECT_OK(http_request.Init()); + + auto content_filename = io::JoinPath(testing::TmpDir(), "content"); + std::ofstream content(content_filename, std::ofstream::binary); + content << "post body content"; + content.close(); + + TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com")); + TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer")); + TF_EXPECT_OK(http_request.SetPostRequest(content_filename)); + TF_EXPECT_OK(http_request.Send()); + + // Check interactions with libcurl. + EXPECT_TRUE(libcurl->is_initialized); + EXPECT_EQ("http://www.testuri.com", libcurl->url); + EXPECT_EQ("", libcurl->custom_request); + EXPECT_EQ(2, libcurl->headers->size()); + EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl->headers)[0]); + EXPECT_EQ("Content-Length: 17", (*libcurl->headers)[1]); + EXPECT_TRUE(libcurl->is_post); + EXPECT_EQ("post body content", libcurl->posted_content); + + std::remove(content_filename.c_str()); +} + +TEST(HttpRequestTest, PostRequest_WithoutBody) { + FakeLibCurl* libcurl = new FakeLibCurl("", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + TF_EXPECT_OK(http_request.Init()); + + TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com")); + TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer")); + TF_EXPECT_OK(http_request.SetPostRequest()); + TF_EXPECT_OK(http_request.Send()); + + // Check interactions with libcurl. 
+ EXPECT_TRUE(libcurl->is_initialized); + EXPECT_EQ("http://www.testuri.com", libcurl->url); + EXPECT_EQ("", libcurl->custom_request); + EXPECT_EQ(2, libcurl->headers->size()); + EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl->headers)[0]); + EXPECT_EQ("Content-Length: 0", (*libcurl->headers)[1]); + EXPECT_TRUE(libcurl->is_post); + EXPECT_EQ("", libcurl->posted_content); +} + +TEST(HttpRequestTest, DeleteRequest) { + FakeLibCurl* libcurl = new FakeLibCurl("", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + TF_EXPECT_OK(http_request.Init()); + + TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com")); + TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer")); + TF_EXPECT_OK(http_request.SetDeleteRequest()); + TF_EXPECT_OK(http_request.Send()); + + // Check interactions with libcurl. + EXPECT_TRUE(libcurl->is_initialized); + EXPECT_EQ("http://www.testuri.com", libcurl->url); + EXPECT_EQ("DELETE", libcurl->custom_request); + EXPECT_EQ(1, libcurl->headers->size()); + EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl->headers)[0]); + EXPECT_FALSE(libcurl->is_post); +} + +TEST(HttpRequestTest, WrongSequenceOfCalls_NoUri) { + FakeLibCurl* libcurl = new FakeLibCurl("", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + TF_EXPECT_OK(http_request.Init()); + + auto s = http_request.Send(); + ASSERT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(StringPiece(s.error_message()).contains("URI has not been set")); +} + +TEST(HttpRequestTest, WrongSequenceOfCalls_TwoSends) { + FakeLibCurl* libcurl = new FakeLibCurl("", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + TF_EXPECT_OK(http_request.Init()); + + http_request.SetUri("http://www.google.com"); + http_request.Send(); + auto s = http_request.Send(); + ASSERT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("The request has already been sent")); +} + +TEST(HttpRequestTest, 
WrongSequenceOfCalls_ReusingAfterSend) { + FakeLibCurl* libcurl = new FakeLibCurl("", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + TF_EXPECT_OK(http_request.Init()); + + http_request.SetUri("http://www.google.com"); + http_request.Send(); + auto s = http_request.SetUri("http://mail.google.com"); + ASSERT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("The request has already been sent")); +} + +TEST(HttpRequestTest, WrongSequenceOfCalls_SettingMethodTwice) { + FakeLibCurl* libcurl = new FakeLibCurl("", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + TF_EXPECT_OK(http_request.Init()); + + http_request.SetDeleteRequest(); + auto s = http_request.SetPostRequest(); + ASSERT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("HTTP method has been already set")); +} + +TEST(HttpRequestTest, WrongSequenceOfCalls_NotInitialized) { + FakeLibCurl* libcurl = new FakeLibCurl("", 200); + HttpRequest http_request((std::unique_ptr(libcurl))); + + auto s = http_request.SetPostRequest(); + ASSERT_TRUE(errors::IsFailedPrecondition(s)); + EXPECT_TRUE(StringPiece(s.error_message()) + .contains("The object has not been initialized")); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 6b3d85ded44..c20db730795 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -3,6 +3,9 @@ load("//google/protobuf:protobuf.bzl", "cc_proto_library") load("//google/protobuf:protobuf.bzl", "py_proto_library") +# configure may change the following line to True +WITH_GCP_SUPPORT = False + # Appends a suffix to a list of deps. 
def tf_deps(deps, suffix): tf_deps = [] @@ -91,3 +94,7 @@ def tf_additional_test_srcs(): def tf_kernel_tests_linkstatic(): return 0 + +def tf_additional_lib_deps(): + return (["//tensorflow/core/platform/cloud:gcs_file_system"] + if WITH_GCP_SUPPORT else []) diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh index 596f5b86e39..e3a841468b9 100755 --- a/tensorflow/tools/ci_build/install/install_deb_packages.sh +++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh @@ -32,6 +32,7 @@ apt-get install -y \ gfortran \ libatlas-base-dev \ libblas-dev \ + libcurl4-openssl-dev \ liblapack-dev \ libtool \ openjdk-8-jdk \ diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 81672eca6ee..8e89b217f76 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -96,3 +96,15 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): name = "grpc_lib", actual = "@grpc//:grpc++_unsecure", ) + + native.new_git_repository( + name = "jsoncpp_git", + remote = "https://github.com/open-source-parsers/jsoncpp.git", + commit = "11086dd6a7eba04289944367ca82cea71299ed70", + build_file = path_prefix + "jsoncpp.BUILD", + ) + + native.bind( + name = "jsoncpp", + actual = "@jsoncpp_git//:jsoncpp", + )