File system implementation for Google Cloud Storage.

This code implements a file system for file paths starting with gs:// using the HTTP API to Google Cloud Storage. No authentication is implemented yet, so only GCS objects with public access can be used.
Change: 122126085
This commit is contained in:
A. Unique TensorFlower 2016-05-11 21:03:48 -08:00 committed by TensorFlower Gardener
parent d03631a27a
commit ed65e69560
14 changed files with 2009 additions and 2 deletions

31
configure vendored
View File

@ -35,6 +35,37 @@ while true; do
# Retry
done
# Ask whether to build with Google Cloud Platform support.
# The loop is skipped entirely when TF_NEED_GCP is already non-empty.
# Answers starting with y/Y set TF_NEED_GCP=1; answers starting with n/N, or an
# empty answer, set it to 0. Anything else prints a message and asks again.
while [ "$TF_NEED_GCP" == "" ]; do
read -p "Do you wish to build TensorFlow with "\
"Google Cloud Platform support? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "Google Cloud Platform support will be enabled for "\
"TensorFlow"; TF_NEED_GCP=1;;
[Nn]* ) echo "No Google Cloud Platform support will be enabled for "\
"TensorFlow"; TF_NEED_GCP=0;;
"" ) echo "No Google Cloud Platform support will be enabled for "\
"TensorFlow"; TF_NEED_GCP=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
# Apply the GCP choice: rewrite the WITH_GCP_SUPPORT assignment in
# build_config.bzl to True or False. When enabling on Linux, first abort with
# exit code 1 if the libcurl header is not at /usr/include/curl/curl.h.
if [ "$TF_NEED_GCP" == "1" ]; then
## Verify that libcurl header files are available.
# Only check Linux, since on MacOS the header files are installed with XCode.
if [[ $(uname -a) =~ Linux ]] && [[ ! -f "/usr/include/curl/curl.h" ]]; then
echo "ERROR: It appears that the development version of libcurl is not "\
"available. Please install the libcurl3-dev package."
exit 1
fi
# Update Bazel build configuration.
# The pattern matches either current value, so the edit is safe to repeat.
perl -pi -e "s,WITH_GCP_SUPPORT = (False|True),WITH_GCP_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
else
# Update Bazel build configuration.
perl -pi -e "s,WITH_GCP_SUPPORT = (False|True),WITH_GCP_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
fi
## Find swig path
if [ -z "$SWIG_PATH" ]; then
SWIG_PATH=`type -p swig 2> /dev/null`

34
jsoncpp.BUILD Normal file
View File

@ -0,0 +1,34 @@
# Bazel build rules for the jsoncpp library.
licenses(["notice"]) # MIT
# Headers exposed to dependents of the library (passed as 'hdrs' below).
JSON_HEADERS = [
"include/json/assertions.h",
"include/json/autolink.h",
"include/json/config.h",
"include/json/features.h",
"include/json/forwards.h",
"include/json/json.h",
"src/lib_json/json_batchallocator.h",
"include/json/reader.h",
"include/json/value.h",
"include/json/writer.h",
]
# Implementation files (passed as 'srcs' below); json_tool.h is listed here
# rather than in JSON_HEADERS.
JSON_SOURCES = [
"src/lib_json/json_reader.cpp",
"src/lib_json/json_value.cpp",
"src/lib_json/json_writer.cpp",
"src/lib_json/json_tool.h",
]
# .inl files are passed as 'textual_hdrs' instead of 'srcs' or 'hdrs'.
INLINE_SOURCES = [
"src/lib_json/json_valueiterator.inl",
]
# The single public target; 'includes' adds the "include" directory to the
# include path of dependents.
cc_library(
name = "jsoncpp",
srcs = JSON_SOURCES,
hdrs = JSON_HEADERS,
includes = ["include"],
textual_hdrs = INLINE_SOURCES,
visibility = ["//visibility:public"],
)

View File

@ -91,6 +91,7 @@ filegroup(
"//tensorflow/core/distributed_runtime/rpc:all_files",
"//tensorflow/core/kernels:all_files",
"//tensorflow/core/ops/compat:all_files",
"//tensorflow/core/platform/cloud:all_files",
"//tensorflow/core/platform/default/build_config:all_files",
"//tensorflow/core/util/ctc:all_files",
"//tensorflow/examples/android:all_files",

View File

@ -67,6 +67,7 @@ load(
"tf_proto_library",
"tf_proto_library_cc",
"tf_additional_lib_srcs",
"tf_additional_lib_deps",
"tf_additional_stream_executor_srcs",
"tf_additional_test_deps",
"tf_additional_test_srcs",
@ -995,9 +996,9 @@ tf_cuda_library(
":lib_internal",
":proto_text",
":protos_all_cc",
"//tensorflow/core/kernels:required",
"//third_party/eigen3",
],
"//tensorflow/core/kernels:required",
] + tf_additional_lib_deps(),
alwayslink = 1,
)

View File

@ -0,0 +1,79 @@
# Description:
# Cloud file system implementation.
# Targets are private by default; individual rules widen visibility as needed.
package(
default_visibility = ["//visibility:private"],
)
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
# Every file in this package except METADATA and OWNERS; visible to all
# packages under //tensorflow.
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
# The gs:// file system. It depends on jsoncpp and on the HTTP request wrapper
# below, and is publicly visible.
cc_library(
name = "gcs_file_system",
srcs = [
"gcs_file_system.cc",
],
hdrs = [
"gcs_file_system.h",
],
linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
visibility = ["//visibility:public"],
deps = [
"@jsoncpp_git//:jsoncpp",
":http_request",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_internal",
],
alwayslink = 1,
)
# HTTP request helper used by gcs_file_system; package-private.
cc_library(
name = "http_request",
srcs = [
"http_request.cc",
],
hdrs = [
"http_request.h",
],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_internal",
],
)
# Test target for :gcs_file_system. No srcs are listed here.
# NOTE(review): presumably tf_cc_test derives the source file from the target
# name — confirm in tensorflow.bzl.
tf_cc_test(
name = "gcs_file_system_test",
size = "small",
deps = [
":gcs_file_system",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Test target for :http_request.
tf_cc_test(
name = "http_request_test",
size = "small",
deps = [
":http_request",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -0,0 +1,516 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/cloud/gcs_file_system.h"
#include <stdio.h>
#include <unistd.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <vector>
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
namespace {
constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/";
constexpr char kGcsUploadUriBase[] =
"https://www.googleapis.com/upload/storage/v1/";
constexpr char kStorageHost[] = "storage.googleapis.com";
constexpr size_t kBufferSize = 1024 * 1024; // In bytes.
/// Creates an empty, uniquely named file matching /tmp/gcs_filesystem_XXXXXX
/// and stores its path in *filename.
///
/// Returns an Internal error if 'filename' is null or the file cannot be
/// created.
Status GetTmpFilename(string* filename) {
  if (filename == nullptr) {
    return errors::Internal("'filename' cannot be nullptr.");
  }
  // mkstemp replaces the trailing X's in place with a unique suffix.
  char path_template[] = "/tmp/gcs_filesystem_XXXXXX";
  const int descriptor = mkstemp(path_template);
  if (descriptor < 0) {
    return errors::Internal("Failed to create a temporary file.");
  }
  // Only the name is handed back; the descriptor itself is not needed.
  close(descriptor);
  filename->assign(path_template);
  return Status::OK();
}
/// Auth provider that always yields an empty token, so only publicly
/// accessible objects can be used with it.
class EmptyAuthProvider : public AuthProvider {
 public:
  /// Sets *token to the empty string and reports success.
  Status GetToken(string* token) const override {
    token->clear();
    return Status::OK();
  }
};
/// Fetches a bearer token from 'provider' into *token.
///
/// Returns an Internal error when no provider is supplied; otherwise forwards
/// whatever status the provider reports.
Status GetAuthToken(const AuthProvider* provider, string* token) {
  if (provider != nullptr) {
    return provider->GetToken(token);
  }
  return errors::Internal("Auth provider is required.");
}
/// \brief Breaks a GCS path into its bucket and object parts.
///
/// For instance, "gs://bucket-name/path/to/file.txt" yields the bucket
/// "bucket-name" and the object "path/to/file.txt".
///
/// Returns Internal if an output pointer is null and InvalidArgument if the
/// path does not match gs://<bucket>/<object>.
Status ParseGcsPath(const string& fname, string* bucket, string* object) {
  if (bucket == nullptr || object == nullptr) {
    return errors::Internal("bucket and object cannot be null.");
  }
  StringPiece bucket_with_slash;
  StringPiece remainder;
  strings::Scanner scanner(fname);
  const bool parsed = scanner.OneLiteral("gs://")
                          .RestartCapture()
                          .ScanEscapedUntil('/')
                          .OneLiteral("/")
                          .GetResult(&remainder, &bucket_with_slash);
  if (!parsed) {
    return errors::InvalidArgument("Couldn't parse GCS path: " + fname);
  }
  // The captured bucket still carries the trailing '/', so drop one character.
  bucket->assign(bucket_with_slash.data(), bucket_with_slash.size() - 1);
  object->assign(remainder.data(), remainder.size());
  return Status::OK();
}
/// GCS-based implementation of a random access file.
///
/// Each Read() issues one ranged HTTP request against
/// https://<bucket>.storage.googleapis.com/<object>. The auth provider and the
/// request factory are stored as raw pointers and are never deleted here, so
/// they must outlive this object.
class GcsRandomAccessFile : public RandomAccessFile {
 public:
  GcsRandomAccessFile(const string& bucket, const string& object,
                      AuthProvider* auth_provider,
                      HttpRequest::Factory* http_request_factory)
      : bucket_(bucket),
        object_(object),
        auth_provider_(auth_provider),
        http_request_factory_(http_request_factory) {}

  /// Reads up to 'n' bytes starting at 'offset' into 'scratch' and points
  /// *result at the bytes received.
  ///
  /// Returns OutOfRange when fewer than 'n' bytes were received (*result still
  /// holds the partial data), or the first error reported while preparing or
  /// sending the request.
  Status Read(uint64 offset, size_t n, StringPiece* result,
              char* scratch) const override {
    if (n == 0) {
      // Nothing to fetch. This also keeps 'offset + n - 1' below from wrapping
      // around when offset is 0, which would request a huge byte range.
      *result = StringPiece();
      return Status::OK();
    }
    string auth_token;
    TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_, &auth_token));
    std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
    TF_RETURN_IF_ERROR(request->Init());
    TF_RETURN_IF_ERROR(request->SetUri(
        strings::StrCat("https://", bucket_, ".", kStorageHost, "/", object_)));
    TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
    // The range passed to the request is inclusive on both ends.
    TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1));
    TF_RETURN_IF_ERROR(request->SetResultBuffer(scratch, n, result));
    TF_RETURN_IF_ERROR(request->Send());
    if (result->size() < n) {
      // This is not an error per se. The RandomAccessFile interface expects
      // that Read returns OutOfRange if fewer bytes were read than requested.
      return errors::OutOfRange(strings::StrCat("EOF reached, ", result->size(),
                                                " bytes were read out of ", n,
                                                " bytes requested."));
    }
    return Status::OK();
  }

 private:
  string bucket_;
  string object_;
  AuthProvider* auth_provider_;                 // Not owned.
  HttpRequest::Factory* http_request_factory_;  // Not owned.
};
/// \brief GCS-based implementation of a writeable file.
///
/// Since GCS objects are immutable, this implementation writes to a local
/// tmp file and copies it to GCS on flush/close.
///
/// The auth provider and the request factory are stored as raw pointers and
/// are never deleted here, so they must outlive this object. The tmp file is
/// removed after a successful Close(), or at the latest in the destructor.
class GcsWritableFile : public WritableFile {
 public:
  /// Constructs an empty writable file backed by a fresh tmp file. If the tmp
  /// file cannot be created or opened, later operations return
  /// FailedPrecondition.
  GcsWritableFile(const string& bucket, const string& object,
                  AuthProvider* auth_provider,
                  HttpRequest::Factory* http_request_factory)
      : bucket_(bucket),
        object_(object),
        auth_provider_(auth_provider),
        http_request_factory_(http_request_factory) {
    if (GetTmpFilename(&tmp_content_filename_).ok()) {
      outfile_.open(tmp_content_filename_,
                    std::ofstream::binary | std::ofstream::app);
    }
  }

  /// \brief Constructs the writable file in append mode.
  ///
  /// tmp_content_filename should contain a path of an existing temporary file
  /// with the content to be appended. The class takes ownership of the
  /// specified tmp file and deletes it on close.
  GcsWritableFile(const string& bucket, const string& object,
                  AuthProvider* auth_provider,
                  const string& tmp_content_filename,
                  HttpRequest::Factory* http_request_factory)
      : bucket_(bucket),
        object_(object),
        auth_provider_(auth_provider),
        http_request_factory_(http_request_factory) {
    tmp_content_filename_ = tmp_content_filename;
    outfile_.open(tmp_content_filename_,
                  std::ofstream::binary | std::ofstream::app);
  }

  ~GcsWritableFile() {
    // Best-effort final upload; its status cannot be reported from here.
    Close();
    // If that upload failed, or the stream never opened, the tmp file is still
    // on disk. Remove it so it does not leak.
    DiscardTmpFile();
  }

  /// Appends 'data' to the local tmp file. Nothing is sent to GCS until
  /// Flush(), Sync() or Close().
  Status Append(const StringPiece& data) override {
    TF_RETURN_IF_ERROR(CheckWritable());
    outfile_ << data;
    if (!outfile_.good()) {
      return errors::Internal(
          "Could not append to the internal temporary file.");
    }
    return Status::OK();
  }

  /// Uploads the content and, on success, closes and deletes the tmp file.
  /// If the upload fails the file stays open, so Close() can be retried.
  /// Calling Close() on an already closed file is a no-op.
  Status Close() override {
    if (outfile_.is_open()) {
      TF_RETURN_IF_ERROR(Sync());
      DiscardTmpFile();
    }
    return Status::OK();
  }

  Status Flush() override { return Sync(); }

  /// Copies the current version of the file to GCS.
  Status Sync() override {
    TF_RETURN_IF_ERROR(CheckWritable());
    outfile_.flush();
    if (!outfile_.good()) {
      // Uploading now would send a truncated object.
      return errors::Internal(
          "Could not write to the internal temporary file.");
    }
    string auth_token;
    TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_, &auth_token));
    std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
    TF_RETURN_IF_ERROR(request->Init());
    TF_RETURN_IF_ERROR(
        request->SetUri(strings::StrCat(kGcsUploadUriBase, "b/", bucket_,
                                        "/o?uploadType=media&name=", object_)));
    TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
    TF_RETURN_IF_ERROR(request->SetPostRequest(tmp_content_filename_));
    TF_RETURN_IF_ERROR(request->Send());
    return Status::OK();
  }

 private:
  /// Returns FailedPrecondition unless the tmp file stream is open.
  Status CheckWritable() const {
    if (!outfile_.is_open()) {
      return errors::FailedPrecondition(
          "The underlying tmp file is not writable.");
    }
    return Status::OK();
  }

  /// Closes the stream if needed and deletes the tmp file. Safe to call more
  /// than once: the stored name is cleared after the first removal.
  void DiscardTmpFile() {
    if (outfile_.is_open()) {
      outfile_.close();
    }
    if (!tmp_content_filename_.empty()) {
      std::remove(tmp_content_filename_.c_str());
      tmp_content_filename_.clear();
    }
  }

  string bucket_;
  string object_;
  AuthProvider* auth_provider_;  // Not owned.
  string tmp_content_filename_;
  std::ofstream outfile_;
  HttpRequest::Factory* http_request_factory_;  // Not owned.
};
/// Read-only memory region backed by a heap buffer that this object owns.
class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
 public:
  /// Takes ownership of 'data'; 'length' is the value reported by length().
  GcsReadOnlyMemoryRegion(std::unique_ptr<char[]> data, uint64 length)
      : data_(std::move(data)), length_(length) {}

  /// Start of the owned buffer.
  const void* data() override {
    char* const begin = data_.get();
    return begin;
  }

  /// Length supplied at construction.
  uint64 length() override { return length_; }

 private:
  std::unique_ptr<char[]> data_;
  uint64 length_;
};
} // namespace
GcsFileSystem::GcsFileSystem()
: auth_provider_(new EmptyAuthProvider()),
http_request_factory_(new HttpRequest::Factory()) {}
GcsFileSystem::GcsFileSystem(
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory)
: auth_provider_(std::move(auth_provider)),
http_request_factory_(std::move(http_request_factory)) {}
/// Creates a reader for the object named by 'fname' and stores it in *result.
/// The caller owns the returned file. No request is sent until it is read.
Status GcsFileSystem::NewRandomAccessFile(const string& fname,
                                          RandomAccessFile** result) {
  string bucket_name;
  string object_name;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket_name, &object_name));
  // The new file borrows the provider and factory owned by this file system.
  AuthProvider* const provider = auth_provider_.get();
  HttpRequest::Factory* const factory = http_request_factory_.get();
  *result =
      new GcsRandomAccessFile(bucket_name, object_name, provider, factory);
  return Status::OK();
}
/// Creates an empty writable file for the object named by 'fname' and stores
/// it in *result. The caller owns the returned file.
Status GcsFileSystem::NewWritableFile(const string& fname,
                                      WritableFile** result) {
  string bucket_name;
  string object_name;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket_name, &object_name));
  // The new file borrows the provider and factory owned by this file system.
  AuthProvider* const provider = auth_provider_.get();
  HttpRequest::Factory* const factory = http_request_factory_.get();
  *result = new GcsWritableFile(bucket_name, object_name, provider, factory);
  return Status::OK();
}
// Reads the file from GCS in chunks and stores it in a tmp file,
// which is then passed to GcsWritableFile.
//
// On success *result owns the tmp file. On any failure the tmp file is
// removed before the error is returned, so nothing is left behind in /tmp.
Status GcsFileSystem::NewAppendableFile(const string& fname,
                                        WritableFile** result) {
  // Validate the path before creating any temporary state.
  string bucket, object;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object));
  RandomAccessFile* reader_ptr;
  TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader_ptr));
  std::unique_ptr<RandomAccessFile> reader(reader_ptr);
  std::unique_ptr<char[]> buffer(new char[kBufferSize]);

  // Read the file from GCS in chunks and save it to a tmp file.
  string old_content_filename;
  TF_RETURN_IF_ERROR(GetTmpFilename(&old_content_filename));
  std::ofstream old_content(old_content_filename, std::ofstream::binary);
  uint64 offset = 0;
  Status status;
  while (true) {
    StringPiece read_chunk;
    status = reader->Read(offset, kBufferSize, &read_chunk, buffer.get());
    const bool reached_eof = status.code() == error::OUT_OF_RANGE;
    if (!status.ok() && !reached_eof) {
      break;  // Genuine read failure; reported after cleanup below.
    }
    old_content << read_chunk;
    if (reached_eof) {
      // Expected: a short read marks the end of the object.
      status = Status::OK();
      break;
    }
    offset += kBufferSize;
  }
  old_content.close();
  if (status.ok() && old_content.fail()) {
    // Covers a failed open, a failed write and a failed close alike.
    status = errors::Internal(
        "Could not write the existing content to a temporary file.");
  }
  if (!status.ok()) {
    std::remove(old_content_filename.c_str());
    return status;
  }

  // Create a writable file and pass the old content to it.
  *result =
      new GcsWritableFile(bucket, object, auth_provider_.get(),
                          old_content_filename, http_request_factory_.get());
  return Status::OK();
}
/// Downloads the whole object into a heap buffer and wraps it in a read-only
/// memory region stored in *result. The caller owns the region.
///
/// Two requests are involved: one for the object size (via GetFileSize) and
/// one ranged read covering the entire object.
Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile(
    const string& fname, ReadOnlyMemoryRegion** result) {
  uint64 object_size;
  TF_RETURN_IF_ERROR(GetFileSize(fname, &object_size));
  std::unique_ptr<char[]> contents(new char[object_size]);

  RandomAccessFile* raw_reader;
  TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &raw_reader));
  std::unique_ptr<RandomAccessFile> reader(raw_reader);

  // The returned view is not needed; the bytes land directly in 'contents'.
  StringPiece ignored_view;
  TF_RETURN_IF_ERROR(
      reader->Read(0, object_size, &ignored_view, contents.get()));
  *result = new GcsReadOnlyMemoryRegion(std::move(contents), object_size);
  return Status::OK();
}
bool GcsFileSystem::FileExists(const string& fname) {
string bucket, object_prefix;
if (!ParseGcsPath(fname, &bucket, &object_prefix).ok()) {
LOG(ERROR) << "Could not parse GCS file name " << fname;
return false;
}
string auth_token;
if (!GetAuthToken(auth_provider_.get(), &auth_token).ok()) {
LOG(ERROR) << "Could not get an auth token.";
return false;
}
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
if (!request->Init().ok()) {
LOG(ERROR) << "Could not initialize the HTTP request.";
return false;
}
request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/",
object_prefix, "?fields=size"));
request->AddAuthBearerHeader(auth_token);
return request->Send().ok();
}
/// Lists objects whose names start with the directory prefix and appends
/// their full "gs://<bucket>/<name>" paths to *result.
///
/// A trailing '/' is added to 'dirname' when missing. A response without an
/// 'items' member is treated as an empty listing.
Status GcsFileSystem::GetChildren(const string& dirname,
                                  std::vector<string>* result) {
  if (result == nullptr) {
    return errors::InvalidArgument("'result' cannot be null");
  }
  string prefix_path = dirname;
  if (!dirname.empty() && dirname.back() != '/') {
    prefix_path += "/";
  }
  string bucket, object_prefix;
  TF_RETURN_IF_ERROR(ParseGcsPath(prefix_path, &bucket, &object_prefix));
  string auth_token;
  TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));

  std::unique_ptr<char[]> response_buffer(new char[kBufferSize]);
  StringPiece response_piece;
  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
  TF_RETURN_IF_ERROR(request->Init());
  const string listing_uri =
      strings::StrCat(kGcsUriBase, "b/", bucket, "/o?prefix=", object_prefix,
                      "&fields=items");
  TF_RETURN_IF_ERROR(request->SetUri(listing_uri));
  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
  // TODO(surkov): Implement pagination using maxResults and pageToken
  // instead, so that all items can be read regardless of their count.
  // Currently one item takes about 1KB in the response, so with a 1MB
  // buffer size this will read fewer than 1000 objects.
  TF_RETURN_IF_ERROR(request->SetResultBuffer(response_buffer.get(),
                                              kBufferSize, &response_piece));
  TF_RETURN_IF_ERROR(request->Send());

  // Copy the response bytes into a string for the JSON parser.
  const std::string response_text(response_piece.data(),
                                  response_piece.size());
  Json::Value root;
  Json::Reader reader;
  if (!reader.parse(response_text, root)) {
    return errors::Internal("Couldn't parse JSON response from GCS.");
  }
  const auto items = root.get("items", Json::Value::null);
  if (items == Json::Value::null) {
    // Empty results.
    return Status::OK();
  }
  if (!items.isArray()) {
    return errors::Internal("Expected an array 'items' in the GCS response.");
  }
  const size_t item_count = items.size();
  for (size_t index = 0; index != item_count; ++index) {
    const auto entry = items.get(index, Json::Value::null);
    if (!entry.isObject()) {
      return errors::Internal(
          "Unexpected JSON format: 'items' should be a list of objects.");
    }
    const auto entry_name = entry.get("name", Json::Value::null);
    if (entry_name == Json::Value::null || !entry_name.isString()) {
      return errors::Internal(
          "Unexpected JSON format: 'items.name' is missing or not a string.");
    }
    result->push_back(
        strings::StrCat("gs://", bucket, "/", entry_name.asString().c_str()));
  }
  return Status::OK();
}
/// Sends a delete request for the object named by 'fname' and returns the
/// first error hit while preparing or sending it.
Status GcsFileSystem::DeleteFile(const string& fname) {
  string bucket_name;
  string object_name;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket_name, &object_name));
  string bearer_token;
  TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &bearer_token));

  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
  TF_RETURN_IF_ERROR(request->Init());
  const string object_uri =
      strings::StrCat(kGcsUriBase, "b/", bucket_name, "/o/", object_name);
  TF_RETURN_IF_ERROR(request->SetUri(object_uri));
  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(bearer_token));
  TF_RETURN_IF_ERROR(request->SetDeleteRequest());
  return request->Send();
}
// Directories are not entities in GCS, so there is nothing to create; this
// always reports success and sends no request.
Status GcsFileSystem::CreateDir(const string& dirname) {
  return Status::OK();
}
// Succeeds only when no object exists under the directory prefix. Nothing is
// actually deleted, because directories are not entities in GCS.
Status GcsFileSystem::DeleteDir(const string& dirname) {
  string prefix_path = dirname;
  const bool needs_slash = !dirname.empty() && dirname.back() != '/';
  if (needs_slash) {
    prefix_path += "/";
  }
  std::vector<string> listed_objects;
  TF_RETURN_IF_ERROR(GetChildren(prefix_path, &listed_objects));
  if (listed_objects.empty()) {
    return Status::OK();
  }
  return errors::InvalidArgument("Cannot delete a non-empty directory.");
}
/// Looks up the size in bytes of the object named by 'fname' and stores it in
/// *file_size.
///
/// The 'size' member of the JSON response may be a number or a decimal
/// string. Returns InvalidArgument if 'file_size' is null, Internal if the
/// response cannot be interpreted, or the first error reported while
/// preparing or sending the request.
Status GcsFileSystem::GetFileSize(const string& fname, uint64* file_size) {
  if (!file_size) {
    return errors::InvalidArgument("'file_size' cannot be null");
  }
  string bucket, object_prefix;
  TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object_prefix));
  string auth_token;
  TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
  std::unique_ptr<char[]> scratch(new char[kBufferSize]);
  StringPiece response_piece;
  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
  TF_RETURN_IF_ERROR(request->Init());
  TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
      kGcsUriBase, "b/", bucket, "/o/", object_prefix, "?fields=size")));
  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
  TF_RETURN_IF_ERROR(
      request->SetResultBuffer(scratch.get(), kBufferSize, &response_piece));
  TF_RETURN_IF_ERROR(request->Send());

  // Copy the response bytes into a string for the JSON parser.
  const std::string response_text(response_piece.data(),
                                  response_piece.size());
  Json::Value root;
  Json::Reader reader;
  if (!reader.parse(response_text, root)) {
    return errors::Internal("Couldn't parse JSON response from GCS.");
  }
  const auto size = root.get("size", Json::Value::null);
  if (size == Json::Value::null) {
    return errors::Internal("'size' was expected in the JSON response.");
  }
  if (size.isNumeric()) {
    *file_size = size.asUInt64();
  } else if (size.isString()) {
    if (!strings::safe_strtou64(size.asString().c_str(), file_size)) {
      return errors::Internal("'size' couldn't be parsed as a number.");
    }
  } else {
    return errors::Internal("'size' is not a number in the JSON response.");
  }
  return Status::OK();
}
// Renames by issuing a single "rewriteTo" POST that copies the source object
// to the target, followed by a delete of the source. If the delete fails, its
// error is returned.
Status GcsFileSystem::RenameFile(const string& src, const string& target) {
  string from_bucket, from_object;
  TF_RETURN_IF_ERROR(ParseGcsPath(src, &from_bucket, &from_object));
  string to_bucket, to_object;
  TF_RETURN_IF_ERROR(ParseGcsPath(target, &to_bucket, &to_object));
  string bearer_token;
  TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &bearer_token));

  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
  TF_RETURN_IF_ERROR(request->Init());
  const string rewrite_uri =
      strings::StrCat(kGcsUriBase, "b/", from_bucket, "/o/", from_object,
                      "/rewriteTo/b/", to_bucket, "/o/", to_object);
  TF_RETURN_IF_ERROR(request->SetUri(rewrite_uri));
  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(bearer_token));
  TF_RETURN_IF_ERROR(request->SetPostRequest());
  TF_RETURN_IF_ERROR(request->Send());
  // The copy succeeded; drop the original to complete the rename.
  return DeleteFile(src);
}
REGISTER_FILE_SYSTEM("gs", GcsFileSystem);
} // namespace tensorflow

View File

@ -0,0 +1,73 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_
#define TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_
#include <string>
#include <vector>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cloud/http_request.h"
#include "tensorflow/core/platform/file_system.h"
namespace tensorflow {
/// Abstract source of HTTP auth bearer tokens.
class AuthProvider {
 public:
  virtual ~AuthProvider() = default;

  /// Writes the token into *t and returns the outcome of obtaining it.
  virtual Status GetToken(string* t) const = 0;
};
/// Google Cloud Storage implementation of a file system.
///
/// Registered for the "gs" scheme; paths have the form gs://<bucket>/<object>.
class GcsFileSystem : public FileSystem {
public:
/// Uses an auth provider that yields empty tokens (so only publicly
/// accessible objects work) and the default HTTP request factory.
GcsFileSystem();
/// Takes ownership of the supplied auth provider and HTTP request factory.
GcsFileSystem(std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory);
/// Creates a reader that issues one ranged HTTP request per Read() call.
Status NewRandomAccessFile(const string& fname,
RandomAccessFile** result) override;
/// Creates a writer that buffers content in a local tmp file and uploads it
/// on flush/sync/close.
Status NewWritableFile(const string& fname, WritableFile** result) override;
/// Downloads the existing object into a local tmp file first, then behaves
/// like NewWritableFile with that content as the starting point.
Status NewAppendableFile(const string& fname, WritableFile** result) override;
/// Downloads the whole object into memory and exposes it as a region.
Status NewReadOnlyMemoryRegionFromFile(
const string& fname, ReadOnlyMemoryRegion** result) override;
/// Returns true iff a metadata request for the object succeeds.
bool FileExists(const string& fname) override;
/// Lists objects under the directory prefix as full gs:// paths.
Status GetChildren(const string& dir, std::vector<string>* result) override;
Status DeleteFile(const string& fname) override;
/// No-op: directories are not entities in GCS.
Status CreateDir(const string& dirname) override;
/// Fails with InvalidArgument if any object exists under the prefix;
/// otherwise does nothing.
Status DeleteDir(const string& dirname) override;
Status GetFileSize(const string& fname, uint64* file_size) override;
/// Copies the object to the target via a rewrite request, then deletes the
/// source.
Status RenameFile(const string& src, const string& target) override;
private:
std::unique_ptr<AuthProvider> auth_provider_;
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
TF_DISALLOW_COPY_AND_ASSIGN(GcsFileSystem);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_

View File

@ -0,0 +1,363 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/cloud/gcs_file_system.h"
#include <fstream>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
/// Test double for HttpRequest.
///
/// Each setter appends one line to a textual transcript of the request.
/// Send() compares that transcript with the expected one, copies the canned
/// response into the result buffer (if one was set) and returns the canned
/// status.
class FakeHttpRequest : public HttpRequest {
 public:
  /// Fake whose Send() succeeds.
  FakeHttpRequest(const string& request, const string& response)
      : FakeHttpRequest(request, response, Status::OK()) {}

  /// Fake whose Send() returns 'response_status'.
  FakeHttpRequest(const string& request, const string& response,
                  Status response_status)
      : expected_request_(request),
        response_(response),
        response_status_(response_status) {}

  Status Init() override { return Status::OK(); }

  Status SetUri(const string& uri) override {
    actual_request_.append("Uri: ").append(uri).append("\n");
    return Status::OK();
  }

  Status SetRange(uint64 start, uint64 end) override {
    actual_request_.append(strings::StrCat("Range: ", start, "-", end, "\n"));
    return Status::OK();
  }

  Status AddAuthBearerHeader(const string& auth_token) override {
    actual_request_.append("Auth Token: ").append(auth_token).append("\n");
    return Status::OK();
  }

  Status SetDeleteRequest() override {
    actual_request_.append("Delete: yes\n");
    return Status::OK();
  }

  /// Records the full content of the file at 'body_filepath' as the POST body.
  Status SetPostRequest(const string& body_filepath) override {
    std::ifstream body_stream(body_filepath);
    const string body((std::istreambuf_iterator<char>(body_stream)),
                      std::istreambuf_iterator<char>());
    actual_request_.append("Post body: ").append(body).append("\n");
    return Status::OK();
  }

  Status SetPostRequest() override {
    actual_request_.append("Post: yes\n");
    return Status::OK();
  }

  Status SetResultBuffer(char* scratch, size_t size,
                         StringPiece* result) override {
    scratch_ = scratch;
    size_ = size;
    result_ = result;
    return Status::OK();
  }

  Status Send() override {
    EXPECT_EQ(expected_request_, actual_request_) << "Unexpected HTTP request.";
    if (scratch_ != nullptr && result_ != nullptr) {
      // Never copy more than the caller's buffer can hold.
      const size_t copied = response_.size() < size_ ? response_.size() : size_;
      memcpy(scratch_, response_.c_str(), copied);
      *result_ = StringPiece(scratch_, copied);
    }
    return response_status_;
  }

 private:
  char* scratch_ = nullptr;
  size_t size_ = 0;
  StringPiece* result_ = nullptr;
  string expected_request_;
  string actual_request_;
  string response_;
  Status response_status_;
};
/// Factory that hands out pre-built fake requests in order.
///
/// The vector is borrowed and must outlive the factory. Each Create() call
/// returns the next element; the destructor checks that every prepared
/// request was handed out.
class FakeHttpRequestFactory : public HttpRequest::Factory {
 public:
  explicit FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests)
      : requests_(requests) {}

  ~FakeHttpRequestFactory() {
    EXPECT_EQ(current_index_, requests_->size())
        << "Not all expected requests were made.";
  }

  HttpRequest* Create() override {
    EXPECT_LT(current_index_, requests_->size())
        << "Too many calls of HttpRequest factory.";
    return (*requests_)[current_index_++];
  }

 private:
  const std::vector<HttpRequest*>* requests_;
  // Same unsigned type as vector::size(), so the comparisons above do not mix
  // signed and unsigned operands.
  size_t current_index_ = 0;
};
/// Auth provider that always yields the fixed token "fake_token".
class FakeAuthProvider : public AuthProvider {
 public:
  Status GetToken(string* token) const override {
    token->assign("fake_token");
    return Status::OK();
  }
};
// Two consecutive 6-byte reads: the first is served in full, the second gets
// only 4 bytes back and must therefore report OUT_OF_RANGE while still
// exposing the partial data.
TEST(GcsFileSystemTest, NewRandomAccessFile) {
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
"Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
"Auth Token: fake_token\n"
"Range: 0-5\n",
"012345"),
new FakeHttpRequest(
"Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
"Auth Token: fake_token\n"
"Range: 6-11\n",
"6789")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)));
RandomAccessFile* file_ptr;
TF_EXPECT_OK(
fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file_ptr));
std::unique_ptr<RandomAccessFile> file(file_ptr);
char scratch[6];
StringPiece result;
// Read the first chunk.
TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
EXPECT_EQ("012345", result);
// Read the second chunk.
EXPECT_EQ(
errors::Code::OUT_OF_RANGE,
file->Read(sizeof(scratch), sizeof(scratch), &result, scratch).code());
EXPECT_EQ("6789", result);
}
// Two appends followed by Close() must produce exactly one upload request
// whose body is the concatenation of both appended pieces.
TEST(GcsFileSystemTest, NewWritableFile) {
std::vector<HttpRequest*> requests({new FakeHttpRequest(
"Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?"
"uploadType=media&name=path/writeable.txt\n"
"Auth Token: fake_token\n"
"Post body: content1,content2\n",
"")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)));
WritableFile* file_ptr;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file_ptr));
std::unique_ptr<WritableFile> file(file_ptr);
TF_EXPECT_OK(file->Append("content1,"));
TF_EXPECT_OK(file->Append("content2"));
TF_EXPECT_OK(file->Close());
}
// Opening for append first downloads the existing content with one 1 MiB
// ranged read (bytes 0-1048575); the short response ends the download. The
// upload on Close() must then carry the old content followed by the appended
// data.
TEST(GcsFileSystemTest, NewAppendableFile) {
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
"Uri: https://bucket.storage.googleapis.com/path/appendable.txt\n"
"Auth Token: fake_token\n"
"Range: 0-1048575\n",
"content1,"),
new FakeHttpRequest(
"Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?"
"uploadType=media&name=path/appendable.txt\n"
"Auth Token: fake_token\n"
"Post body: content1,content2\n",
"")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)));
WritableFile* file_ptr;
TF_EXPECT_OK(
fs.NewAppendableFile("gs://bucket/path/appendable.txt", &file_ptr));
std::unique_ptr<WritableFile> file(file_ptr);
TF_EXPECT_OK(file->Append("content2"));
TF_EXPECT_OK(file->Close());
}
// Expects two requests: a metadata lookup that returns the size as a JSON
// string, then a ranged read covering bytes 0..size-1. The resulting region
// must hold exactly the object content.
TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
const string content = "file content";
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
"Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
"random_access.txt?fields=size\n"
"Auth Token: fake_token\n",
strings::StrCat("{\"size\": \"", content.size(), "\"}")),
new FakeHttpRequest(
strings::StrCat(
"Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
"Auth Token: fake_token\n"
"Range: 0-",
content.size() - 1, "\n"),
content)});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)));
ReadOnlyMemoryRegion* region_ptr;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
"gs://bucket/random_access.txt", &region_ptr));
std::unique_ptr<ReadOnlyMemoryRegion> region(region_ptr);
EXPECT_EQ(content, StringPiece(reinterpret_cast<const char*>(region->data()),
region->length()));
}
// FileExists() is expected to be true when the size lookup succeeds and false
// when the lookup reports NotFound.
TEST(GcsFileSystemTest, FileExists) {
  std::vector<HttpRequest*> expected_requests;
  // Lookup that succeeds.
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
      "path/file1.txt?fields=size\n"
      "Auth Token: fake_token\n",
      "{\"size\": \"100\"}"));
  // Lookup that fails with NotFound.
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
      "path/file2.txt?fields=size\n"
      "Auth Token: fake_token\n",
      "", errors::NotFound("404")));
  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
                   std::unique_ptr<HttpRequest::Factory>(
                       new FakeHttpRequestFactory(&expected_requests)));

  EXPECT_TRUE(fs.FileExists("gs://bucket/path/file1.txt"));
  EXPECT_FALSE(fs.FileExists("gs://bucket/path/file2.txt"));
}
// GetChildren() over a listing with three items: every listed object name is
// expected back as a full gs:// path, in listing order (nested one included).
TEST(GcsFileSystemTest, GetChildren_ThreeFiles) {
  std::vector<HttpRequest*> expected_requests;
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
      "prefix=path/&fields=items\n"
      "Auth Token: fake_token\n",
      "{\"items\": [ "
      "  { \"name\": \"path/file1.txt\" },"
      "  { \"name\": \"path/subpath/file2.txt\" },"
      "  { \"name\": \"path/file3.txt\" }]}"));
  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
                   std::unique_ptr<HttpRequest::Factory>(
                       new FakeHttpRequestFactory(&expected_requests)));

  std::vector<string> listed;
  TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &listed));

  EXPECT_EQ(3, listed.size());
  EXPECT_EQ("gs://bucket/path/file1.txt", listed[0]);
  EXPECT_EQ("gs://bucket/path/subpath/file2.txt", listed[1]);
  EXPECT_EQ("gs://bucket/path/file3.txt", listed[2]);
}
// GetChildren() must succeed and return nothing when the listing response is
// an empty JSON object (no "items" key).
TEST(GcsFileSystemTest, GetChildren_Empty) {
  std::vector<HttpRequest*> expected_requests;
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
      "prefix=path/&fields=items\n"
      "Auth Token: fake_token\n",
      "{}"));
  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
                   std::unique_ptr<HttpRequest::Factory>(
                       new FakeHttpRequestFactory(&expected_requests)));

  std::vector<string> listed;
  TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &listed));
  EXPECT_EQ(0, listed.size());
}
// DeleteFile() is expected to issue a single DELETE request for the object.
TEST(GcsFileSystemTest, DeleteFile) {
  std::vector<HttpRequest*> expected_requests;
  expected_requests.push_back(
      new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b"
                          "/bucket/o/path/file1.txt\n"
                          "Auth Token: fake_token\n"
                          "Delete: yes\n",
                          ""));
  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
                   std::unique_ptr<HttpRequest::Factory>(
                       new FakeHttpRequestFactory(&expected_requests)));

  TF_EXPECT_OK(fs.DeleteFile("gs://bucket/path/file1.txt"));
}
// DeleteDir() on a prefix whose listing comes back empty must succeed.
TEST(GcsFileSystemTest, DeleteDir_Empty) {
  std::vector<HttpRequest*> expected_requests;
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
      "prefix=path/&fields=items\n"
      "Auth Token: fake_token\n",
      "{}"));
  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
                   std::unique_ptr<HttpRequest::Factory>(
                       new FakeHttpRequestFactory(&expected_requests)));

  TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
// DeleteDir() must report an error when the listing shows that the prefix
// still contains an object.
TEST(GcsFileSystemTest, DeleteDir_NonEmpty) {
  std::vector<HttpRequest*> expected_requests;
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
      "prefix=path/&fields=items\n"
      "Auth Token: fake_token\n",
      "{\"items\": [ "
      "  { \"name\": \"path/file1.txt\" }]}"));
  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
                   std::unique_ptr<HttpRequest::Factory>(
                       new FakeHttpRequestFactory(&expected_requests)));

  const bool deleted = fs.DeleteDir("gs://bucket/path/").ok();
  EXPECT_FALSE(deleted);
}
// GetFileSize() is expected to parse the "size" field of the metadata
// response, which arrives as a JSON string rather than a number.
TEST(GcsFileSystemTest, GetFileSize) {
  std::vector<HttpRequest*> expected_requests;
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
      "file.txt?fields=size\n"
      "Auth Token: fake_token\n",
      "{\"size\": \"1010\"}"));
  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
                   std::unique_ptr<HttpRequest::Factory>(
                       new FakeHttpRequestFactory(&expected_requests)));

  uint64 reported_size;
  TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &reported_size));
  EXPECT_EQ(1010, reported_size);
}
// RenameFile() is expected to rewrite the source object to the destination
// and then delete the source.
TEST(GcsFileSystemTest, RenameFile) {
  std::vector<HttpRequest*> expected_requests;
  // Step 1: server-side copy (rewriteTo) from src.txt to dst.txt.
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/src.txt"
      "/rewriteTo/b/bucket/o/dst.txt\n"
      "Auth Token: fake_token\n"
      "Post: yes\n",
      ""));
  // Step 2: removal of the original object.
  expected_requests.push_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/src.txt\n"
      "Auth Token: fake_token\n"
      "Delete: yes\n",
      ""));
  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
                   std::unique_ptr<HttpRequest::Factory>(
                       new FakeHttpRequestFactory(&expected_requests)));

  TF_EXPECT_OK(fs.RenameFile("gs://bucket/src.txt", "gs://bucket/dst.txt"));
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,410 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/cloud/http_request.h"
#include <dlfcn.h>
#include <stdio.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// Windows is not currently supported.
// Shared-library names handed to dlopen() when libcurl symbols are not already
// available in the running binary; the Linux name is tried before the Mac one.
constexpr char kCurlLibLinux[] = "libcurl.so.3";
constexpr char kCurlLibMac[] = "/usr/lib/libcurl.3.dylib";
// Certificate directory passed to libcurl through CURLOPT_CAPATH.
constexpr char kCertsPath[] = "/etc/ssl/certs";
// Set to 1 to enable verbose debug output from curl.
// Kept as uint64 so that it selects the integer overload of curl_easy_setopt.
constexpr uint64 kVerboseOutput = 0;
/// An implementation that dynamically loads libcurl and forwards calls to it.
///
/// MaybeLoadDll() must succeed before any of the forwarding methods is used;
/// every forwarding method CHECK-fails otherwise.
class LibCurlProxy : public LibCurl {
 public:
  ~LibCurlProxy() {
    if (dll_handle_) {
      dlclose(dll_handle_);
    }
  }

  /// Locates libcurl (in the current binary first, then as a shared library),
  /// binds the entry points used by this proxy and runs curl_global_init.
  ///
  /// Returns OK if the library is already loaded or was loaded successfully,
  /// and FailedPrecondition if no usable libcurl could be found or its global
  /// initialization failed.
  Status MaybeLoadDll() override {
    if (dll_handle_) {
      return Status::OK();
    }
    // This may have been linked statically; if curl_easy_init is in the
    // current binary, no need to search for a dynamic version.
    dll_handle_ = load_dll(nullptr);
    if (!dll_handle_) {
      dll_handle_ = load_dll(kCurlLibLinux);
    }
    if (!dll_handle_) {
      dll_handle_ = load_dll(kCurlLibMac);
    }
    if (!dll_handle_) {
      return errors::FailedPrecondition(strings::StrCat(
          "Could not initialize the libcurl library. Please make sure that "
          "libcurl is installed in the OS or statically linked to the "
          "TensorFlow binary."));
    }
    // A failed global initialization leaves libcurl unusable, so report it
    // instead of handing out a half-initialized library.
    const CURLcode init_result = curl_global_init_(CURL_GLOBAL_ALL);
    if (init_result != CURLE_OK) {
      dlclose(dll_handle_);
      dll_handle_ = nullptr;
      return errors::FailedPrecondition(
          strings::StrCat("curl_global_init failed with error code ",
                          static_cast<int>(init_result)));
    }
    return Status::OK();
  }

  CURL* curl_easy_init() override {
    CHECK(dll_handle_);
    return curl_easy_init_();
  }

  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            uint64 param) override {
    CHECK(dll_handle_);
    return curl_easy_setopt_(curl, option, param);
  }

  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            const char* param) override {
    CHECK(dll_handle_);
    return curl_easy_setopt_(curl, option, param);
  }

  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            void* param) override {
    CHECK(dll_handle_);
    return curl_easy_setopt_(curl, option, param);
  }

  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            size_t (*param)(void*, size_t, size_t,
                                            FILE*)) override {
    CHECK(dll_handle_);
    return curl_easy_setopt_(curl, option, param);
  }

  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            size_t (*param)(const void*, size_t, size_t,
                                            void*)) override {
    CHECK(dll_handle_);
    return curl_easy_setopt_(curl, option, param);
  }

  CURLcode curl_easy_perform(CURL* curl) override {
    CHECK(dll_handle_);
    return curl_easy_perform_(curl);
  }

  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
                             uint64* value) override {
    CHECK(dll_handle_);
    return curl_easy_getinfo_(curl, info, value);
  }

  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
                             double* value) override {
    CHECK(dll_handle_);
    return curl_easy_getinfo_(curl, info, value);
  }

  void curl_easy_cleanup(CURL* curl) override {
    CHECK(dll_handle_);
    return curl_easy_cleanup_(curl);
  }

  curl_slist* curl_slist_append(curl_slist* list, const char* str) override {
    CHECK(dll_handle_);
    return curl_slist_append_(list, str);
  }

  void curl_slist_free_all(curl_slist* list) override {
    CHECK(dll_handle_);
    return curl_slist_free_all_(list);
  }

 private:
  // Loads the dynamic library and binds the required methods.
  // Returns the library handle in case of success or nullptr otherwise.
  // 'name' can be nullptr, in which case the symbols are looked up in the
  // running program itself.
  void* load_dll(const char* name) {
    void* handle = dlopen(name, RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE);
    if (!handle) {
      return nullptr;
    }

#define BIND_CURL_FUNC(function) \
  *reinterpret_cast<void**>(&(function##_)) = dlsym(handle, #function)

    BIND_CURL_FUNC(curl_global_init);
    BIND_CURL_FUNC(curl_easy_init);
    BIND_CURL_FUNC(curl_easy_setopt);
    BIND_CURL_FUNC(curl_easy_perform);
    BIND_CURL_FUNC(curl_easy_getinfo);
    BIND_CURL_FUNC(curl_slist_append);
    BIND_CURL_FUNC(curl_slist_free_all);
    BIND_CURL_FUNC(curl_easy_cleanup);

#undef BIND_CURL_FUNC

    // Every entry point is called unconditionally by the forwarding methods,
    // so a library that lacks any of them must be rejected here rather than
    // crash later through a null function pointer.
    const bool all_bound =
        curl_global_init_ != nullptr && curl_easy_init_ != nullptr &&
        curl_easy_setopt_ != nullptr && curl_easy_perform_ != nullptr &&
        curl_easy_getinfo_ != nullptr && curl_slist_append_ != nullptr &&
        curl_slist_free_all_ != nullptr && curl_easy_cleanup_ != nullptr;
    if (!all_bound) {
      dlerror();  // Discard the pending dlsym error before the next attempt.
      dlclose(handle);
      return nullptr;
    }
    return handle;
  }

  // Handle returned by dlopen; nullptr until MaybeLoadDll() succeeds.
  void* dll_handle_ = nullptr;
  // Entry points resolved by load_dll().
  CURLcode (*curl_global_init_)(int64) = nullptr;
  CURL* (*curl_easy_init_)(void) = nullptr;
  CURLcode (*curl_easy_setopt_)(CURL*, CURLoption, ...) = nullptr;
  CURLcode (*curl_easy_perform_)(CURL* curl) = nullptr;
  CURLcode (*curl_easy_getinfo_)(CURL* curl, CURLINFO info, ...) = nullptr;
  void (*curl_easy_cleanup_)(CURL* curl) = nullptr;
  curl_slist* (*curl_slist_append_)(curl_slist* list,
                                    const char* str) = nullptr;
  void (*curl_slist_free_all_)(curl_slist* list) = nullptr;
};
} // namespace
// Default construction uses the real libcurl, loaded through LibCurlProxy.
HttpRequest::HttpRequest()
: HttpRequest(std::unique_ptr<LibCurl>(new LibCurlProxy)) {}
// Takes ownership of the given libcurl implementation and allocates the
// default response buffer (CURL_MAX_WRITE_SIZE bytes) that Init() installs.
HttpRequest::HttpRequest(std::unique_ptr<LibCurl> libcurl)
: libcurl_(std::move(libcurl)),
default_response_buffer_(new char[CURL_MAX_WRITE_SIZE]) {}
// Releases everything the request acquired: the header list, the POST body
// file and finally the curl session itself.
HttpRequest::~HttpRequest() {
  if (curl_headers_ != nullptr) {
    libcurl_->curl_slist_free_all(curl_headers_);
  }
  if (post_body_ != nullptr) {
    fclose(post_body_);
  }
  if (curl_ != nullptr) {
    libcurl_->curl_easy_cleanup(curl_);
  }
}
// Loads libcurl if necessary, opens a curl session and installs the default
// response buffer. Must succeed before any other method is used.
Status HttpRequest::Init() {
  if (libcurl_ == nullptr) {
    return errors::Internal("libcurl proxy cannot be nullptr.");
  }
  TF_RETURN_IF_ERROR(libcurl_->MaybeLoadDll());

  curl_ = libcurl_->curl_easy_init();
  if (curl_ == nullptr) {
    return errors::Internal("Couldn't initialize a curl session.");
  }

  libcurl_->curl_easy_setopt(curl_, CURLOPT_VERBOSE, kVerboseOutput);
  libcurl_->curl_easy_setopt(curl_, CURLOPT_CAPATH, kCertsPath);

  // If response buffer is not set, libcurl will print results to stdout,
  // so we always set it. SetResultBuffer() requires the initialized flag,
  // hence the flag is raised first and lowered again if the call fails.
  is_initialized_ = true;
  Status buffer_status =
      SetResultBuffer(default_response_buffer_.get(), CURL_MAX_WRITE_SIZE,
                      &default_response_string_piece_);
  is_initialized_ = buffer_status.ok();
  return buffer_status;
}
// Stores the target URI in the curl session and records that it was set.
// Fails if the request is uninitialized or has already been sent.
Status HttpRequest::SetUri(const string& uri) {
  TF_RETURN_IF_ERROR(CheckInitialized());
  TF_RETURN_IF_ERROR(CheckNotSent());
  const char* const uri_text = uri.c_str();
  libcurl_->curl_easy_setopt(curl_, CURLOPT_URL, uri_text);
  is_uri_set_ = true;
  return Status::OK();
}
// Requests only bytes [start, end] of the resource; both bounds are inclusive
// and are sent to curl in the "start-end" form.
Status HttpRequest::SetRange(uint64 start, uint64 end) {
  TF_RETURN_IF_ERROR(CheckInitialized());
  TF_RETURN_IF_ERROR(CheckNotSent());
  const string range_spec = strings::StrCat(start, "-", end);
  libcurl_->curl_easy_setopt(curl_, CURLOPT_RANGE, range_spec.c_str());
  return Status::OK();
}
// Appends an "Authorization: Bearer <token>" header to the request.
// An empty token is accepted and adds no header.
Status HttpRequest::AddAuthBearerHeader(const string& auth_token) {
  TF_RETURN_IF_ERROR(CheckInitialized());
  TF_RETURN_IF_ERROR(CheckNotSent());
  if (auth_token.empty()) {
    return Status::OK();
  }
  const string header = strings::StrCat("Authorization: Bearer ", auth_token);
  curl_headers_ = libcurl_->curl_slist_append(curl_headers_, header.c_str());
  return Status::OK();
}
// Turns the request into an HTTP DELETE. Fails if another method was already
// chosen, or if the request is uninitialized or already sent.
Status HttpRequest::SetDeleteRequest() {
  TF_RETURN_IF_ERROR(CheckInitialized());
  TF_RETURN_IF_ERROR(CheckNotSent());
  TF_RETURN_IF_ERROR(CheckMethodNotSet());
  libcurl_->curl_easy_setopt(curl_, CURLOPT_CUSTOMREQUEST, "DELETE");
  is_method_set_ = true;
  return Status::OK();
}
// Turns the request into an HTTP POST whose body is read from the given file.
//
// The file is opened here and stays open until the request is destroyed; its
// size is sent as the Content-Length header.
//
// Returns InvalidArgument if the file cannot be opened and Internal if its
// size cannot be determined. In both cases no method is recorded as set and
// no file handle is kept, so a different method may still be chosen.
Status HttpRequest::SetPostRequest(const string& body_filepath) {
  TF_RETURN_IF_ERROR(CheckInitialized());
  TF_RETURN_IF_ERROR(CheckNotSent());
  TF_RETURN_IF_ERROR(CheckMethodNotSet());
  if (post_body_) {
    fclose(post_body_);
    post_body_ = nullptr;
  }
  // Binary mode: the body must be uploaded byte-for-byte and its length must
  // match the Content-Length computed below.
  post_body_ = fopen(body_filepath.c_str(), "rb");
  if (!post_body_) {
    return errors::InvalidArgument("Couldnt' open the specified file: " +
                                   body_filepath);
  }
  // Measure the file by seeking to its end, then rewind for the upload.
  long size = -1;
  if (fseek(post_body_, 0, SEEK_END) == 0) {
    size = ftell(post_body_);
  }
  if (size < 0 || fseek(post_body_, 0, SEEK_SET) != 0) {
    fclose(post_body_);
    post_body_ = nullptr;
    return errors::Internal("Couldn't determine the size of the file: " +
                            body_filepath);
  }
  is_method_set_ = true;
  curl_headers_ = libcurl_->curl_slist_append(
      curl_headers_, strings::StrCat("Content-Length: ", size).c_str());
  libcurl_->curl_easy_setopt(curl_, CURLOPT_POST, 1);
  libcurl_->curl_easy_setopt(curl_, CURLOPT_READDATA,
                             reinterpret_cast<void*>(post_body_));
  return Status::OK();
}
// Turns the request into an HTTP POST with an empty body, announced through
// an explicit "Content-Length: 0" header.
Status HttpRequest::SetPostRequest() {
  TF_RETURN_IF_ERROR(CheckInitialized());
  TF_RETURN_IF_ERROR(CheckNotSent());
  TF_RETURN_IF_ERROR(CheckMethodNotSet());
  libcurl_->curl_easy_setopt(curl_, CURLOPT_POST, 1);
  curl_slist* const with_length =
      libcurl_->curl_slist_append(curl_headers_, "Content-Length: 0");
  curl_headers_ = with_length;
  is_method_set_ = true;
  return Status::OK();
}
// Registers the caller's buffer as the destination of the response body.
//
// 'scratch' receives at most 'size' bytes; 'result' is pointed at the received
// data by Send(). Returns InvalidArgument for a null pointer or an empty
// buffer.
Status HttpRequest::SetResultBuffer(char* scratch, size_t size,
                                    StringPiece* result) {
  TF_RETURN_IF_ERROR(CheckInitialized());
  TF_RETURN_IF_ERROR(CheckNotSent());
  if (scratch == nullptr) {
    return errors::InvalidArgument("scratch cannot be null");
  }
  if (result == nullptr) {
    return errors::InvalidArgument("result cannot be null");
  }
  if (size == 0) {
    return errors::InvalidArgument("buffer size should be positive");
  }

  // Start with an empty buffer: nothing has been written into it yet.
  response_buffer_written_ = 0;
  response_string_piece_ = result;
  response_buffer_size_ = size;
  response_buffer_ = scratch;

  // Route the response body through WriteCallback with this object as context.
  void* const callback_context = reinterpret_cast<void*>(this);
  libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA, callback_context);
  libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION,
                             &HttpRequest::WriteCallback);
  return Status::OK();
}
size_t HttpRequest::WriteCallback(const void* ptr, size_t size, size_t nmemb,
void* this_object) {
CHECK(ptr);
auto that = reinterpret_cast<HttpRequest*>(this_object);
CHECK(that->response_buffer_);
CHECK(that->response_buffer_size_ >= that->response_buffer_written_);
const size_t bytes_to_copy =
std::min(size * nmemb,
that->response_buffer_size_ - that->response_buffer_written_);
memcpy(that->response_buffer_ + that->response_buffer_written_, ptr,
bytes_to_copy);
that->response_buffer_written_ += bytes_to_copy;
return bytes_to_copy;
}
// Performs the request and translates the outcome into a Status.
//
// On 200/204/206 the result StringPiece registered via SetResultBuffer() is
// pointed at the bytes stored in the response buffer. A 416 response is
// reported as OK with an empty result. 401 maps to PermissionDenied, 404 to
// NotFound, a transport-level curl failure or any other response code to
// Internal. The request is marked as sent even when it fails, so it cannot be
// re-used.
Status HttpRequest::Send() {
  TF_RETURN_IF_ERROR(CheckInitialized());
  TF_RETURN_IF_ERROR(CheckNotSent());
  is_sent_ = true;
  if (!is_uri_set_) {
    return errors::FailedPrecondition("URI has not been set.");
  }
  if (curl_headers_) {
    libcurl_->curl_easy_setopt(curl_, CURLOPT_HTTPHEADER, curl_headers_);
  }

  // Zero-initialized: libcurl may leave the buffer untouched on failure, and
  // it is read as a C string below.
  char error_buffer[CURL_ERROR_SIZE] = {0};
  libcurl_->curl_easy_setopt(curl_, CURLOPT_ERRORBUFFER, error_buffer);

  const auto curl_result = libcurl_->curl_easy_perform(curl_);

  // The buffer lives on this stack frame; detach it so the curl handle does
  // not keep a dangling pointer after this function returns.
  libcurl_->curl_easy_setopt(curl_, CURLOPT_ERRORBUFFER,
                             static_cast<const char*>(nullptr));

  if (curl_result != CURLE_OK) {
    return errors::Internal(string("curl error: ") + error_buffer);
  }

  // Initialized so that a failed getinfo falls into the default branch below
  // instead of switching on an indeterminate value.
  uint64 response_code = 0;
  libcurl_->curl_easy_getinfo(curl_, CURLINFO_RESPONSE_CODE, &response_code);

  switch (response_code) {
    case 200:  // OK
    case 204:  // No Content
    case 206:  // Partial Content
      if (response_buffer_ && response_string_piece_) {
        // Use the number of bytes actually stored by WriteCallback, which can
        // never exceed the buffer size.
        *response_string_piece_ =
            StringPiece(response_buffer_, response_buffer_written_);
      }
      return Status::OK();
    case 401:
      return errors::PermissionDenied(
          "Not authorized to access the given HTTP resource.");
    case 404:
      return errors::NotFound("The requested URL was not found.");
    case 416:  // Requested Range Not Satisfiable
      if (response_string_piece_) {
        *response_string_piece_ = StringPiece();
      }
      return Status::OK();
    default:
      return errors::Internal(
          strings::StrCat("Unexpected HTTP response code ", response_code));
  }
}
// Returns FailedPrecondition unless Init() has completed successfully.
Status HttpRequest::CheckInitialized() const {
  return is_initialized_
             ? Status::OK()
             : errors::FailedPrecondition(
                   "The object has not been initialized.");
}
// Returns FailedPrecondition if an HTTP method was already selected.
Status HttpRequest::CheckMethodNotSet() const {
  return is_method_set_
             ? errors::FailedPrecondition("HTTP method has been already set.")
             : Status::OK();
}
// Returns FailedPrecondition once Send() has been called on this request.
Status HttpRequest::CheckNotSent() const {
  return is_sent_
             ? errors::FailedPrecondition("The request has already been sent.")
             : Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,156 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
#include <functional>
#include <string>
#include <vector>
#include <curl/curl.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class LibCurl; // libcurl interface as a class, for dependency injection.
/// \brief A basic HTTP client based on the libcurl library.
///
/// The usage pattern for the class reflects the one of the libcurl library:
/// create a request object, initialize it, set request parameters and call
/// Send().
///
/// For example:
///   HttpRequest request;
///   request.Init();
///   request.SetUri("http://www.google.com");
///   request.SetResultBuffer(scratch, 1000, &result);
///   request.Send();
///
/// Every setter and Send() fail with FailedPrecondition if Init() has not
/// succeeded or if the request has already been sent.
class HttpRequest {
public:
/// Creates HttpRequest objects; virtual so that a different implementation
/// can be substituted.
class Factory {
public:
virtual ~Factory() {}
/// Returns a newly allocated, not yet initialized request.
/// NOTE(review): the raw pointer is presumably owned by the caller - confirm.
virtual HttpRequest* Create() { return new HttpRequest(); }
};
/// Creates a request backed by the real libcurl library.
HttpRequest();
/// Creates a request backed by the given libcurl implementation.
explicit HttpRequest(std::unique_ptr<LibCurl> libcurl);
virtual ~HttpRequest();
/// Loads libcurl if needed and opens a curl session; must be called first.
virtual Status Init();
/// Sets the request URI.
virtual Status SetUri(const string& uri);
/// \brief Sets the Range header.
///
/// Used for random seeks, for example "0-999" returns the first 1000 bytes
/// (note that the right border is included).
virtual Status SetRange(uint64 start, uint64 end);
/// Sets the 'Authorization' header to the value of 'Bearer ' + auth_token.
/// An empty token adds no header.
virtual Status AddAuthBearerHeader(const string& auth_token);
/// Makes the request a DELETE request.
virtual Status SetDeleteRequest();
/// \brief Makes the request a POST request.
///
/// The request body will be taken from the specified file.
virtual Status SetPostRequest(const string& body_filepath);
/// Makes the request a POST request with an empty body.
virtual Status SetPostRequest();
/// \brief Specifies the buffer for receiving the response body.
///
/// The interface is made similar to RandomAccessFile::Read.
virtual Status SetResultBuffer(char* scratch, size_t size,
StringPiece* result);
/// \brief Sends the formed request.
///
/// If the result buffer was defined, the response will be written there.
/// The object is not designed to be re-used after Send() is executed.
virtual Status Send();
private:
/// A callback in the form which can be accepted by libcurl.
static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb,
void* userdata);
// Each returns FailedPrecondition when the corresponding usage rule is broken.
Status CheckInitialized() const;
Status CheckMethodNotSet() const;
Status CheckNotSent() const;
// The libcurl implementation all calls are forwarded to.
std::unique_ptr<LibCurl> libcurl_;
// File providing the POST body; closed by the destructor.
FILE* post_body_ = nullptr;
// Destination of the response body and its bookkeeping.
char* response_buffer_ = nullptr;
size_t response_buffer_size_ = 0;
size_t response_buffer_written_ = 0;
StringPiece* response_string_piece_ = nullptr;
// The curl session and the list of request headers.
CURL* curl_ = nullptr;
curl_slist* curl_headers_ = nullptr;
// Buffer and result installed by Init() until SetResultBuffer() is called.
std::unique_ptr<char[]> default_response_buffer_;
StringPiece default_response_string_piece_;
// Members to enforce the usage flow.
bool is_initialized_ = false;
bool is_uri_set_ = false;
bool is_method_set_ = false;
bool is_sent_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(HttpRequest);
};
/// \brief A proxy to the libcurl C interface as a dependency injection measure.
///
/// This class is meant as a very thin wrapper for the libcurl C library.
/// Each method mirrors the libcurl function of the same name; the variadic
/// curl_easy_setopt and curl_easy_getinfo are exposed as one overload per
/// argument type used by HttpRequest.
class LibCurl {
public:
virtual ~LibCurl() {}
/// Lazy initialization of the dynamic libcurl library.
virtual Status MaybeLoadDll() = 0;
virtual CURL* curl_easy_init() = 0;
// Integer-valued options.
virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
uint64 param) = 0;
// String-valued options.
virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
const char* param) = 0;
// Pointer-valued options.
virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
void* param) = 0;
// Callback with the signature of fread().
virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
size_t (*param)(void*, size_t, size_t,
FILE*)) = 0;
// Callback with the signature of HttpRequest::WriteCallback.
virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
size_t (*param)(const void*, size_t, size_t,
void*)) = 0;
virtual CURLcode curl_easy_perform(CURL* curl) = 0;
// Integer-valued info.
virtual CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
uint64* value) = 0;
// Double-valued info.
virtual CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
double* value) = 0;
virtual void curl_easy_cleanup(CURL* curl) = 0;
virtual curl_slist* curl_slist_append(curl_slist* list, const char* str) = 0;
virtual void curl_slist_free_all(curl_slist* list) = 0;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_

View File

@ -0,0 +1,323 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/cloud/http_request.h"
#include <fstream>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// A fake proxy that pretends to be libcurl.
//
// It records the options a request sets, replays a canned response through
// the registered write callback, and captures whatever the request posts.
class FakeLibCurl : public LibCurl {
 public:
  FakeLibCurl(const string& response_content, uint64 response_code)
      : response_content(response_content), response_code(response_code) {}

  Status MaybeLoadDll() override { return Status::OK(); }

  CURL* curl_easy_init() override {
    is_initialized = true;
    // The result just needs to be non-null.
    return reinterpret_cast<CURL*>(this);
  }

  // Integer options: only CURLOPT_POST is tracked.
  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            uint64 param) override {
    if (option == CURLOPT_POST) {
      is_post = param;
    }
    return CURLE_OK;
  }

  // String options are handled by the generic pointer overload.
  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            const char* param) override {
    void* const untyped = reinterpret_cast<void*>(const_cast<char*>(param));
    return curl_easy_setopt(curl, option, untyped);
  }

  // Pointer options: remembers the values the tests inspect afterwards.
  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            void* param) override {
    if (option == CURLOPT_URL) {
      url = static_cast<char*>(param);
    } else if (option == CURLOPT_RANGE) {
      range = static_cast<char*>(param);
    } else if (option == CURLOPT_CUSTOMREQUEST) {
      custom_request = static_cast<char*>(param);
    } else if (option == CURLOPT_HTTPHEADER) {
      // Header lists are produced by curl_slist_append() below, so the
      // "slist" is really a vector of strings.
      headers = static_cast<std::vector<string>*>(param);
    } else if (option == CURLOPT_ERRORBUFFER) {
      error_buffer = static_cast<char*>(param);
    } else if (option == CURLOPT_WRITEDATA) {
      write_data = param;
    } else if (option == CURLOPT_READDATA) {
      read_data = static_cast<FILE*>(param);
    }
    return CURLE_OK;
  }

  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            size_t (*param)(void*, size_t, size_t,
                                            FILE*)) override {
    EXPECT_EQ(param, &fread) << "Expected the standard fread() function.";
    return CURLE_OK;
  }

  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
                            size_t (*param)(const void*, size_t, size_t,
                                            void*)) override {
    if (option == CURLOPT_WRITEFUNCTION) {
      write_callback = param;
    }
    return CURLE_OK;
  }

  // Simulates the transfer: drains the request body (if any) into
  // posted_content and feeds the canned response to the write callback.
  CURLcode curl_easy_perform(CURL* curl) override {
    if (read_data) {
      posted_content.clear();
      char chunk[100];
      for (;;) {
        const size_t got = fread(chunk, 1, sizeof(chunk), read_data);
        if (got == 0) {
          break;
        }
        posted_content.append(chunk, got);
      }
    }
    if (write_data) {
      write_callback(response_content.c_str(), 1, response_content.size(),
                     write_data);
    }
    return CURLE_OK;
  }

  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
                             uint64* value) override {
    if (info == CURLINFO_RESPONSE_CODE) {
      *value = response_code;
    }
    return CURLE_OK;
  }

  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
                             double* value) override {
    if (info == CURLINFO_SIZE_DOWNLOAD) {
      *value = response_content.size();
    }
    return CURLE_OK;
  }

  void curl_easy_cleanup(CURL* curl) override { is_cleaned_up = true; }

  // The fake header list is a heap-allocated vector of strings disguised as a
  // curl_slist pointer; it is created on the first append.
  curl_slist* curl_slist_append(curl_slist* list, const char* str) override {
    std::vector<string>* entries;
    if (list) {
      entries = reinterpret_cast<std::vector<string>*>(list);
    } else {
      entries = new std::vector<string>();
    }
    entries->push_back(str);
    return reinterpret_cast<curl_slist*>(entries);
  }

  void curl_slist_free_all(curl_slist* list) override {
    delete reinterpret_cast<std::vector<string>*>(list);
  }

  // Variables defining the behavior of this fake.
  string response_content;
  uint64 response_code;

  // Internal variables to store the libcurl state.
  string url;
  string range;
  string custom_request;
  char* error_buffer = nullptr;
  bool is_initialized = false;
  bool is_cleaned_up = false;
  std::vector<string>* headers = nullptr;
  FILE* read_data = nullptr;
  bool is_post = false;
  void* write_data = nullptr;
  size_t (*write_callback)(const void* ptr, size_t size, size_t nmemb,
                           void* userdata) = nullptr;

  // Outcome of performing the request.
  string posted_content;
};
// A plain GET with an auth header, a byte range and a caller-provided result
// buffer: checks the returned body and the options handed to libcurl.
TEST(HttpRequestTest, GetRequest) {
  // The request takes ownership of the fake; the raw pointer is kept only to
  // inspect the recorded state.
  FakeLibCurl* fake_curl = new FakeLibCurl("get response", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
  TF_EXPECT_OK(request.Init());

  // Both are pre-filled with junk so that stale content would be noticed.
  char scratch[100] = "random original scratch content";
  StringPiece body = "random original string piece";

  TF_EXPECT_OK(request.SetUri("http://www.testuri.com"));
  TF_EXPECT_OK(request.AddAuthBearerHeader("fake-bearer"));
  TF_EXPECT_OK(request.SetRange(100, 199));
  TF_EXPECT_OK(request.SetResultBuffer(scratch, 100, &body));
  TF_EXPECT_OK(request.Send());

  EXPECT_EQ("get response", body);

  // Check interactions with libcurl.
  EXPECT_TRUE(fake_curl->is_initialized);
  EXPECT_EQ("http://www.testuri.com", fake_curl->url);
  EXPECT_EQ("100-199", fake_curl->range);
  EXPECT_EQ("", fake_curl->custom_request);
  EXPECT_EQ(1, fake_curl->headers->size());
  EXPECT_EQ("Authorization: Bearer fake-bearer", (*fake_curl->headers)[0]);
  EXPECT_FALSE(fake_curl->is_post);
}
// A POST whose body comes from a file: checks the Content-Length header and
// that the file content is what gets posted.
TEST(HttpRequestTest, PostRequest_WithBody) {
  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
  TF_EXPECT_OK(request.Init());

  // Write the body into a temporary file; the stream is closed when it goes
  // out of scope, before the request opens the file.
  const string body_path = io::JoinPath(testing::TmpDir(), "content");
  {
    std::ofstream body_file(body_path, std::ofstream::binary);
    body_file << "post body content";
  }

  TF_EXPECT_OK(request.SetUri("http://www.testuri.com"));
  TF_EXPECT_OK(request.AddAuthBearerHeader("fake-bearer"));
  TF_EXPECT_OK(request.SetPostRequest(body_path));
  TF_EXPECT_OK(request.Send());

  // Check interactions with libcurl.
  EXPECT_TRUE(fake_curl->is_initialized);
  EXPECT_EQ("http://www.testuri.com", fake_curl->url);
  EXPECT_EQ("", fake_curl->custom_request);
  EXPECT_EQ(2, fake_curl->headers->size());
  EXPECT_EQ("Authorization: Bearer fake-bearer", (*fake_curl->headers)[0]);
  EXPECT_EQ("Content-Length: 17", (*fake_curl->headers)[1]);
  EXPECT_TRUE(fake_curl->is_post);
  EXPECT_EQ("post body content", fake_curl->posted_content);

  std::remove(body_path.c_str());
}
// A POST without a body: expects "Content-Length: 0" and nothing posted.
TEST(HttpRequestTest, PostRequest_WithoutBody) {
  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
  TF_EXPECT_OK(request.Init());

  TF_EXPECT_OK(request.SetUri("http://www.testuri.com"));
  TF_EXPECT_OK(request.AddAuthBearerHeader("fake-bearer"));
  TF_EXPECT_OK(request.SetPostRequest());
  TF_EXPECT_OK(request.Send());

  // Check interactions with libcurl.
  EXPECT_TRUE(fake_curl->is_initialized);
  EXPECT_EQ("http://www.testuri.com", fake_curl->url);
  EXPECT_EQ("", fake_curl->custom_request);
  EXPECT_EQ(2, fake_curl->headers->size());
  EXPECT_EQ("Authorization: Bearer fake-bearer", (*fake_curl->headers)[0]);
  EXPECT_EQ("Content-Length: 0", (*fake_curl->headers)[1]);
  EXPECT_TRUE(fake_curl->is_post);
  EXPECT_EQ("", fake_curl->posted_content);
}
// A DELETE request: expects the custom request verb to be set and no POST.
TEST(HttpRequestTest, DeleteRequest) {
  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
  TF_EXPECT_OK(request.Init());

  TF_EXPECT_OK(request.SetUri("http://www.testuri.com"));
  TF_EXPECT_OK(request.AddAuthBearerHeader("fake-bearer"));
  TF_EXPECT_OK(request.SetDeleteRequest());
  TF_EXPECT_OK(request.Send());

  // Check interactions with libcurl.
  EXPECT_TRUE(fake_curl->is_initialized);
  EXPECT_EQ("http://www.testuri.com", fake_curl->url);
  EXPECT_EQ("DELETE", fake_curl->custom_request);
  EXPECT_EQ(1, fake_curl->headers->size());
  EXPECT_EQ("Authorization: Bearer fake-bearer", (*fake_curl->headers)[0]);
  EXPECT_FALSE(fake_curl->is_post);
}
// Sending without a URI must be rejected with FailedPrecondition.
TEST(HttpRequestTest, WrongSequenceOfCalls_NoUri) {
  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
  TF_EXPECT_OK(request.Init());

  const Status send_status = request.Send();
  ASSERT_TRUE(errors::IsFailedPrecondition(send_status));
  const StringPiece message(send_status.error_message());
  EXPECT_TRUE(message.contains("URI has not been set"));
}
// A second Send() on the same request must be rejected.
TEST(HttpRequestTest, WrongSequenceOfCalls_TwoSends) {
  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
  TF_EXPECT_OK(request.Init());

  // First, regular round trip; its outcome is not under test here.
  request.SetUri("http://www.google.com");
  request.Send();

  const Status second_send = request.Send();
  ASSERT_TRUE(errors::IsFailedPrecondition(second_send));
  const StringPiece message(second_send.error_message());
  EXPECT_TRUE(message.contains("The request has already been sent"));
}
// Changing the URI after the request was sent must be rejected.
TEST(HttpRequestTest, WrongSequenceOfCalls_ReusingAfterSend) {
  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
  TF_EXPECT_OK(request.Init());

  // Regular round trip; its outcome is not under test here.
  request.SetUri("http://www.google.com");
  request.Send();

  const Status reuse_status = request.SetUri("http://mail.google.com");
  ASSERT_TRUE(errors::IsFailedPrecondition(reuse_status));
  const StringPiece message(reuse_status.error_message());
  EXPECT_TRUE(message.contains("The request has already been sent"));
}
// Selecting a second HTTP method on the same request must be rejected.
TEST(HttpRequestTest, WrongSequenceOfCalls_SettingMethodTwice) {
  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
  TF_EXPECT_OK(request.Init());

  request.SetDeleteRequest();

  const Status second_method = request.SetPostRequest();
  ASSERT_TRUE(errors::IsFailedPrecondition(second_method));
  const StringPiece message(second_method.error_message());
  EXPECT_TRUE(message.contains("HTTP method has been already set"));
}
// Using the request before Init() must be rejected.
TEST(HttpRequestTest, WrongSequenceOfCalls_NotInitialized) {
  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));

  // Init() is deliberately skipped.
  const Status premature = request.SetPostRequest();
  ASSERT_TRUE(errors::IsFailedPrecondition(premature));
  const StringPiece message(premature.error_message());
  EXPECT_TRUE(message.contains("The object has not been initialized"));
}
} // namespace
} // namespace tensorflow

View File

@ -3,6 +3,9 @@
load("//google/protobuf:protobuf.bzl", "cc_proto_library")
load("//google/protobuf:protobuf.bzl", "py_proto_library")
# configure may change the following line to True
WITH_GCP_SUPPORT = False
# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
tf_deps = []
@ -91,3 +94,7 @@ def tf_additional_test_srcs():
def tf_kernel_tests_linkstatic():
return 0
# Returns the extra library deps: the GCS file system target when
# WITH_GCP_SUPPORT is True, otherwise an empty list.
def tf_additional_lib_deps():
return (["//tensorflow/core/platform/cloud:gcs_file_system"]
if WITH_GCP_SUPPORT else [])

View File

@ -32,6 +32,7 @@ apt-get install -y \
gfortran \
libatlas-base-dev \
libblas-dev \
libcurl4-openssl-dev \
liblapack-dev \
libtool \
openjdk-8-jdk \

View File

@ -96,3 +96,15 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
name = "grpc_lib",
actual = "@grpc//:grpc++_unsecure",
)
native.new_git_repository(
name = "jsoncpp_git",
remote = "https://github.com/open-source-parsers/jsoncpp.git",
commit = "11086dd6a7eba04289944367ca82cea71299ed70",
build_file = path_prefix + "jsoncpp.BUILD",
)
native.bind(
name = "jsoncpp",
actual = "@jsoncpp_git//:jsoncpp",
)