[tf.data service] Add framework for pluggable credentials.
PiperOrigin-RevId: 302069447 Change-Id: Ifd01c60826967b03693395668672bf836dd40f61
This commit is contained in:
parent
70d1d30b5a
commit
f8da7c2b15
@ -78,6 +78,29 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "credentials_factory",
|
||||
srcs = ["credentials_factory.cc"],
|
||||
hdrs = ["credentials_factory.h"],
|
||||
deps = [
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "credentials_factory_test",
|
||||
srcs = ["credentials_factory_test.cc"],
|
||||
deps = [
|
||||
":credentials_factory",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_grpc_library(
|
||||
name = "master_cc_grpc_proto",
|
||||
srcs = [":master_proto"],
|
||||
|
111
tensorflow/core/data/service/credentials_factory.cc
Normal file
111
tensorflow/core/data/service/credentials_factory.cc
Normal file
@ -0,0 +1,111 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
namespace {
|
||||
mutex* get_lock() {
|
||||
static mutex lock(LINKER_INITIALIZED);
|
||||
return &lock;
|
||||
}
|
||||
|
||||
using CredentialsFactories =
|
||||
std::unordered_map<std::string, CredentialsFactory*>;
|
||||
CredentialsFactories& credentials_factories() {
|
||||
static auto& factories = *new CredentialsFactories();
|
||||
return factories;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void CredentialsFactory::Register(CredentialsFactory* factory) {
|
||||
mutex_lock l(*get_lock());
|
||||
if (!credentials_factories().insert({factory->Protocol(), factory}).second) {
|
||||
LOG(ERROR)
|
||||
<< "Two credentials factories are being registered with protocol "
|
||||
<< factory->Protocol() << ". Which one gets used is undefined.";
|
||||
}
|
||||
}
|
||||
|
||||
Status CredentialsFactory::Get(absl::string_view protocol,
|
||||
CredentialsFactory** out) {
|
||||
mutex_lock l(*get_lock());
|
||||
auto it = credentials_factories().find(std::string(protocol));
|
||||
if (it != credentials_factories().end()) {
|
||||
*out = it->second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<string> available_types;
|
||||
for (const auto& factory : credentials_factories()) {
|
||||
available_types.push_back(factory.first);
|
||||
}
|
||||
|
||||
return errors::NotFound("No credentials factory has been registered for ",
|
||||
"protocol ", protocol,
|
||||
". The available types are: [ ",
|
||||
absl::StrJoin(available_types, ", "), " ]");
|
||||
}
|
||||
|
||||
Status CredentialsFactory::CreateServerCredentials(
|
||||
absl::string_view protocol, std::shared_ptr<grpc::ServerCredentials>* out) {
|
||||
CredentialsFactory* factory;
|
||||
TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory));
|
||||
TF_RETURN_IF_ERROR(factory->CreateServerCredentials(out));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CredentialsFactory::CreateClientCredentials(
|
||||
absl::string_view protocol,
|
||||
std::shared_ptr<grpc::ChannelCredentials>* out) {
|
||||
CredentialsFactory* factory;
|
||||
TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory));
|
||||
TF_RETURN_IF_ERROR(factory->CreateClientCredentials(out));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class InsecureCredentialsFactory : public CredentialsFactory {
|
||||
public:
|
||||
std::string Protocol() override { return "grpc"; }
|
||||
|
||||
Status CreateServerCredentials(
|
||||
std::shared_ptr<grpc::ServerCredentials>* out) override {
|
||||
*out = grpc::InsecureServerCredentials();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateClientCredentials(
|
||||
std::shared_ptr<grpc::ChannelCredentials>* out) override {
|
||||
*out = grpc::InsecureChannelCredentials();
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
class InsecureCredentialsRegistrar {
|
||||
public:
|
||||
InsecureCredentialsRegistrar() {
|
||||
auto factory = new InsecureCredentialsFactory();
|
||||
CredentialsFactory::Register(factory);
|
||||
}
|
||||
};
|
||||
static InsecureCredentialsRegistrar registrar;
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
69
tensorflow/core/data/service/credentials_factory.h
Normal file
69
tensorflow/core/data/service/credentials_factory.h
Normal file
@ -0,0 +1,69 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_
|
||||
#define TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_
|
||||
|
||||
#include "grpcpp/grpcpp.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// Credential factory implementations should be threadsafe since all callers
|
||||
// to `GetCredentials` will get the same instance of `CredentialsFactory`.
|
||||
class CredentialsFactory {
|
||||
public:
|
||||
virtual ~CredentialsFactory() = default;
|
||||
|
||||
// Returns a protocol name for the credentials factory. This is the string to
|
||||
// look up with `GetCredentials` to find the registered credentials factory.
|
||||
virtual std::string Protocol() = 0;
|
||||
|
||||
// Stores server credentials to `*out`.
|
||||
virtual Status CreateServerCredentials(
|
||||
std::shared_ptr<grpc::ServerCredentials>* out) = 0;
|
||||
|
||||
// Stores client credentials to `*out`.
|
||||
virtual Status CreateClientCredentials(
|
||||
std::shared_ptr<grpc::ChannelCredentials>* out) = 0;
|
||||
|
||||
// Registers a credentials factory.
|
||||
static void Register(CredentialsFactory* factory);
|
||||
|
||||
// Creates server credentials using the credentials factory registered as
|
||||
// `protocol`, and stores them to `*out`.
|
||||
static Status CreateServerCredentials(
|
||||
absl::string_view protocol,
|
||||
std::shared_ptr<grpc::ServerCredentials>* out);
|
||||
|
||||
// Creates client credentials using the credentials factory registered as
|
||||
// `protocol`, and stores them to `*out`.
|
||||
static Status CreateClientCredentials(
|
||||
absl::string_view protocol,
|
||||
std::shared_ptr<grpc::ChannelCredentials>* out);
|
||||
|
||||
private:
|
||||
// Gets the credentials factory registered via `Register` for the specified
|
||||
// protocol, and stores it to `*out`.
|
||||
static Status Get(const absl::string_view protocol, CredentialsFactory** out);
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_
|
91
tensorflow/core/data/service/credentials_factory_test.cc
Normal file
91
tensorflow/core/data/service/credentials_factory_test.cc
Normal file
@ -0,0 +1,91 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
namespace {
|
||||
constexpr char kFailedToCreateServerCredentials[] =
|
||||
"Failed to create server credentials.";
|
||||
constexpr char kFailedToCreateClientCredentials[] =
|
||||
"Failed to create client credentials.";
|
||||
|
||||
class TestCredentialsFactory : public CredentialsFactory {
|
||||
public:
|
||||
std::string Protocol() override { return "test"; }
|
||||
|
||||
Status CreateServerCredentials(
|
||||
std::shared_ptr<grpc::ServerCredentials>* out) override {
|
||||
return errors::Internal(kFailedToCreateServerCredentials);
|
||||
}
|
||||
|
||||
Status CreateClientCredentials(
|
||||
std::shared_ptr<grpc::ChannelCredentials>* out) override {
|
||||
return errors::Internal(kFailedToCreateClientCredentials);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(CredentialsFactory, Register) {
|
||||
TestCredentialsFactory test_factory;
|
||||
CredentialsFactory::Register(&test_factory);
|
||||
std::shared_ptr<grpc::ServerCredentials> server_credentials;
|
||||
ASSERT_EQ(errors::Internal(kFailedToCreateServerCredentials),
|
||||
CredentialsFactory::CreateServerCredentials(test_factory.Protocol(),
|
||||
&server_credentials));
|
||||
std::shared_ptr<grpc::ChannelCredentials> client_credentials;
|
||||
ASSERT_EQ(errors::Internal(kFailedToCreateClientCredentials),
|
||||
CredentialsFactory::CreateClientCredentials(test_factory.Protocol(),
|
||||
&client_credentials));
|
||||
}
|
||||
|
||||
TEST(CredentialsFactory, DefaultGrpcProtocol) {
|
||||
std::shared_ptr<grpc::ServerCredentials> server_credentials;
|
||||
TF_ASSERT_OK(
|
||||
CredentialsFactory::CreateServerCredentials("grpc", &server_credentials));
|
||||
std::shared_ptr<grpc::ChannelCredentials> client_credentials;
|
||||
TF_ASSERT_OK(
|
||||
CredentialsFactory::CreateClientCredentials("grpc", &client_credentials));
|
||||
}
|
||||
|
||||
TEST(CredentialsFactory, MissingServerProtocol) {
|
||||
std::shared_ptr<grpc::ServerCredentials> server_credentials;
|
||||
Status s = CredentialsFactory::CreateServerCredentials("unknown_protocol",
|
||||
&server_credentials);
|
||||
ASSERT_EQ(error::Code::NOT_FOUND, s.code());
|
||||
ASSERT_TRUE(
|
||||
absl::StrContains(s.ToString(),
|
||||
"No credentials factory has been registered for "
|
||||
"protocol unknown_protocol"));
|
||||
}
|
||||
|
||||
TEST(CredentialsFactory, MissingClientProtocol) {
|
||||
std::shared_ptr<grpc::ChannelCredentials> client_credentials;
|
||||
Status s = CredentialsFactory::CreateClientCredentials("unknown_protocol",
|
||||
&client_credentials);
|
||||
ASSERT_EQ(error::Code::NOT_FOUND, s.code());
|
||||
ASSERT_TRUE(
|
||||
absl::StrContains(s.ToString(),
|
||||
"No credentials factory has been registered for "
|
||||
"protocol unknown_protocol"));
|
||||
}
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user