[xprof:oss] Added a gRPC utility for retrieving default credentials.

* Created a directory of Bazel build macros used by the profiler. [go/xprof-oss-self-contained, go/tfsl]
* profiler_client.cc
  * Removed redundant prefix "dns:///" which is already gRPC's default when not specified. Previously, it would prepend and result in a bad service address if URI had already been provided.
* Added log points for address binding and channel failures.

PiperOrigin-RevId: 327442840
Change-Id: Ia9e41bcbaff8d28267e0ca9ad8429151e51be3c4
This commit is contained in:
Yi Situ 2020-08-19 09:01:13 -07:00 committed by TensorFlower Gardener
parent 654b45cd56
commit ad66131d3a
11 changed files with 177 additions and 6 deletions

View File

@ -0,0 +1,10 @@
package(
default_visibility = ["//tensorflow/core/profiler:internal"],
licenses = ["notice"], # Apache 2.0
)
# ONLY FOR DEV TESTING. DO NOT USE IF YOU DO NOT KNOW ABOUT IT ALREADY.
config_setting(
name = "profiler_build_oss",
values = {"define": "profiler_build=oss"},
)

View File

@ -0,0 +1,14 @@
"""Provides a redirection point for platform specific implementations of Starlark utilities."""
load(
"//tensorflow/core/profiler/builds/oss:build_config.bzl",
_tf_profiler_alias = "tf_profiler_alias",
)
tf_profiler_alias = _tf_profiler_alias
def if_profiler_oss(if_true, if_false = []):
return select({
"//tensorflow/core/profiler/builds:profiler_build_oss": if_true,
"//conditions:default": if_false,
})

View File

@ -0,0 +1,8 @@
# Tensorflow default + linux implementations of tensorflow/core/profiler libraries.
package(
default_visibility = [
"//tensorflow/core/profiler:internal",
],
licenses = ["notice"], # Apache 2.0
)

View File

@ -0,0 +1,7 @@
# Platform-specific build configurations.
"""
TF profiler build macros for use in OSS.
"""
def tf_profiler_alias(target_dir, name):
return target_dir + "oss:" + name

View File

@ -1,11 +1,31 @@
load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible") # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") # buildifier: disable=same-origin-load
load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_alias")
package(
default_visibility = [
"//tensorflow/core/profiler:internal",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "grpc",
hdrs = ["grpc.h"],
deps = [
tf_profiler_alias("//tensorflow/core/profiler/rpc/", "grpc"),
tf_grpc_cc_dependency(),
],
)
exports_files(
[
"grpc.h",
],
visibility = ["//tensorflow/core/profiler/rpc:__subpackages__"],
)
cc_library(
name = "profiler_service_impl",
srcs = ["profiler_service_impl.cc"],
@ -38,6 +58,7 @@ cc_library(
"//tensorflow/python/profiler/internal:__pkg__",
],
deps = [
":grpc",
":profiler_service_impl",
"//tensorflow/core:lib",
"//tensorflow/core/profiler:profiler_service_proto_cc",

View File

@ -56,6 +56,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler:profiler_analysis_proto_cc",
"//tensorflow/core/profiler:profiler_service_proto_cc",
"//tensorflow/core/profiler/rpc:grpc",
tf_grpc_cc_dependency(),
],
alwayslink = True,

View File

@ -18,8 +18,10 @@ limitations under the License.
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/rpc/grpc.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
@ -36,9 +38,13 @@ template <typename T>
std::unique_ptr<typename T::Stub> CreateStub(const std::string& service_addr) {
::grpc::ChannelArguments channel_args;
channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
return T::NewStub(::grpc::CreateCustomChannel(
"dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
channel_args));
// Default URI prefix is "dns:///" if not provided.
auto channel = ::grpc::CreateCustomChannel(
service_addr, ::grpc::InsecureChannelCredentials(), channel_args);
if (!channel) {
LOG(ERROR) << "Unable to create channel" << service_addr;
}
return T::NewStub(channel);
}
} // namespace

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// GRPC utilities
#ifndef TENSORFLOW_CORE_PROFILER_COMMON_GRPC_GRPC_H_
#define TENSORFLOW_CORE_PROFILER_COMMON_GRPC_GRPC_H_
#include <memory>
#include "grpcpp/security/credentials.h"
#include "grpcpp/security/server_credentials.h"
namespace tensorflow {
namespace profiler {
// Returns default credentials for use when creating a gRPC server.
std::shared_ptr<::grpc::ServerCredentials> GetDefaultServerCredentials();
// Returns default credentials for use when creating a gRPC channel.
std::shared_ptr<::grpc::ChannelCredentials> GetDefaultChannelCredentials();
} // namespace profiler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PROFILER_COMMON_GRPC_GRPC_H_

View File

@ -0,0 +1,27 @@
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
package(
default_visibility = [
"//tensorflow/core/profiler:internal",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "grpc",
srcs = [
"grpc.cc",
"//tensorflow/core/profiler/rpc:grpc.h",
],
deps = [
tf_grpc_cc_dependency(),
],
alwayslink = True,
)
exports_files(
[
"grpc.cc",
],
visibility = ["//tensorflow/core/profiler/rpc:__subpackages__"],
)

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/profiler/rpc/grpc.h"
namespace tensorflow {
namespace profiler {
std::shared_ptr<::grpc::ServerCredentials> GetDefaultServerCredentials() {
return ::grpc::InsecureServerCredentials();
}
std::shared_ptr<::grpc::ChannelCredentials> GetDefaultChannelCredentials() {
return ::grpc::InsecureChannelCredentials();
}
} // namespace profiler
} // namespace tensorflow

View File

@ -23,18 +23,28 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
#include "tensorflow/core/profiler/rpc/grpc.h"
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
namespace tensorflow {
void ProfilerServer::StartProfilerServer(int32 port) {
std::string server_address = absl::StrCat("0.0.0.0:", port);
std::string server_address = absl::StrCat("[::]:", port);
service_ = CreateProfilerService();
::grpc::ServerBuilder builder;
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
int selected_port = 0;
builder.AddListeningPort(
server_address, profiler::GetDefaultServerCredentials(), &selected_port);
builder.RegisterService(service_.get());
server_ = builder.BuildAndStart();
LOG(INFO) << "Profiling Server listening on " << server_address;
if (!selected_port) {
LOG(ERROR) << "Unable to bind to " << server_address << ":"
<< selected_port;
} else {
LOG(INFO) << "Profiling Server listening on " << server_address << ":"
<< selected_port;
}
}
ProfilerServer::~ProfilerServer() {