[xprof:oss] Added a gRPC utility for retrieving default credentials.
* Created a directory of Bazel build macros used by the profiler. [go/xprof-oss-self-contained, go/tfsl] * profiler_client.cc * Removed redundant prefix "dns:///" which is already gRPC's default when not specified. Previously, it would prepend and result in a bad service address if URI had already been provided. * Added log points for address binding and channel failures. PiperOrigin-RevId: 327442840 Change-Id: Ia9e41bcbaff8d28267e0ca9ad8429151e51be3c4
This commit is contained in:
parent
654b45cd56
commit
ad66131d3a
10
tensorflow/core/profiler/builds/BUILD
Normal file
10
tensorflow/core/profiler/builds/BUILD
Normal file
@ -0,0 +1,10 @@
|
||||
package(
|
||||
default_visibility = ["//tensorflow/core/profiler:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# ONLY FOR DEV TESTING. DO NOT USE IF YOU DO NOT KNOW ABOUT IT ALREADY.
|
||||
config_setting(
|
||||
name = "profiler_build_oss",
|
||||
values = {"define": "profiler_build=oss"},
|
||||
)
|
14
tensorflow/core/profiler/builds/build_config.bzl
Normal file
14
tensorflow/core/profiler/builds/build_config.bzl
Normal file
@ -0,0 +1,14 @@
|
||||
"""Provides a redirection point for platform specific implementations of Starlark utilities."""
|
||||
|
||||
load(
|
||||
"//tensorflow/core/profiler/builds/oss:build_config.bzl",
|
||||
_tf_profiler_alias = "tf_profiler_alias",
|
||||
)
|
||||
|
||||
tf_profiler_alias = _tf_profiler_alias
|
||||
|
||||
def if_profiler_oss(if_true, if_false = []):
|
||||
return select({
|
||||
"//tensorflow/core/profiler/builds:profiler_build_oss": if_true,
|
||||
"//conditions:default": if_false,
|
||||
})
|
8
tensorflow/core/profiler/builds/oss/BUILD
Normal file
8
tensorflow/core/profiler/builds/oss/BUILD
Normal file
@ -0,0 +1,8 @@
|
||||
# Tensorflow default + linux implementations of tensorflow/core/profiler libraries.
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/core/profiler:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
7
tensorflow/core/profiler/builds/oss/build_config.bzl
Normal file
7
tensorflow/core/profiler/builds/oss/build_config.bzl
Normal file
@ -0,0 +1,7 @@
|
||||
# Platform-specific build configurations.
|
||||
"""
|
||||
TF profiler build macros for use in OSS.
|
||||
"""
|
||||
|
||||
def tf_profiler_alias(target_dir, name):
|
||||
return target_dir + "oss:" + name
|
@ -1,11 +1,31 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible") # buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") # buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") # buildifier: disable=same-origin-load
|
||||
load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_alias")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/core/profiler:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc",
|
||||
hdrs = ["grpc.h"],
|
||||
deps = [
|
||||
tf_profiler_alias("//tensorflow/core/profiler/rpc/", "grpc"),
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"grpc.h",
|
||||
],
|
||||
visibility = ["//tensorflow/core/profiler/rpc:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "profiler_service_impl",
|
||||
srcs = ["profiler_service_impl.cc"],
|
||||
@ -38,6 +58,7 @@ cc_library(
|
||||
"//tensorflow/python/profiler/internal:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":grpc",
|
||||
":profiler_service_impl",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler:profiler_service_proto_cc",
|
||||
|
@ -56,6 +56,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler:profiler_analysis_proto_cc",
|
||||
"//tensorflow/core/profiler:profiler_service_proto_cc",
|
||||
"//tensorflow/core/profiler/rpc:grpc",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
alwayslink = True,
|
||||
|
@ -18,8 +18,10 @@ limitations under the License.
|
||||
|
||||
#include "grpcpp/grpcpp.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/rpc/grpc.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -36,9 +38,13 @@ template <typename T>
|
||||
std::unique_ptr<typename T::Stub> CreateStub(const std::string& service_addr) {
|
||||
::grpc::ChannelArguments channel_args;
|
||||
channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
|
||||
return T::NewStub(::grpc::CreateCustomChannel(
|
||||
"dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
|
||||
channel_args));
|
||||
// Default URI prefix is "dns:///" if not provided.
|
||||
auto channel = ::grpc::CreateCustomChannel(
|
||||
service_addr, ::grpc::InsecureChannelCredentials(), channel_args);
|
||||
if (!channel) {
|
||||
LOG(ERROR) << "Unable to create channel" << service_addr;
|
||||
}
|
||||
return T::NewStub(channel);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
37
tensorflow/core/profiler/rpc/grpc.h
Normal file
37
tensorflow/core/profiler/rpc/grpc.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// GRPC utilities
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PROFILER_COMMON_GRPC_GRPC_H_
|
||||
#define TENSORFLOW_CORE_PROFILER_COMMON_GRPC_GRPC_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "grpcpp/security/server_credentials.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
// Returns default credentials for use when creating a gRPC server.
|
||||
std::shared_ptr<::grpc::ServerCredentials> GetDefaultServerCredentials();
|
||||
|
||||
// Returns default credentials for use when creating a gRPC channel.
|
||||
std::shared_ptr<::grpc::ChannelCredentials> GetDefaultChannelCredentials();
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PROFILER_COMMON_GRPC_GRPC_H_
|
27
tensorflow/core/profiler/rpc/oss/BUILD
Normal file
27
tensorflow/core/profiler/rpc/oss/BUILD
Normal file
@ -0,0 +1,27 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/core/profiler:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc",
|
||||
srcs = [
|
||||
"grpc.cc",
|
||||
"//tensorflow/core/profiler/rpc:grpc.h",
|
||||
],
|
||||
deps = [
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"grpc.cc",
|
||||
],
|
||||
visibility = ["//tensorflow/core/profiler/rpc:__subpackages__"],
|
||||
)
|
30
tensorflow/core/profiler/rpc/oss/grpc.cc
Normal file
30
tensorflow/core/profiler/rpc/oss/grpc.cc
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/profiler/rpc/grpc.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
std::shared_ptr<::grpc::ServerCredentials> GetDefaultServerCredentials() {
|
||||
return ::grpc::InsecureServerCredentials();
|
||||
}
|
||||
|
||||
std::shared_ptr<::grpc::ChannelCredentials> GetDefaultChannelCredentials() {
|
||||
return ::grpc::InsecureChannelCredentials();
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
@ -23,18 +23,28 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
|
||||
#include "tensorflow/core/profiler/rpc/grpc.h"
|
||||
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void ProfilerServer::StartProfilerServer(int32 port) {
|
||||
std::string server_address = absl::StrCat("0.0.0.0:", port);
|
||||
std::string server_address = absl::StrCat("[::]:", port);
|
||||
service_ = CreateProfilerService();
|
||||
::grpc::ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
|
||||
|
||||
int selected_port = 0;
|
||||
builder.AddListeningPort(
|
||||
server_address, profiler::GetDefaultServerCredentials(), &selected_port);
|
||||
builder.RegisterService(service_.get());
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(INFO) << "Profiling Server listening on " << server_address;
|
||||
if (!selected_port) {
|
||||
LOG(ERROR) << "Unable to bind to " << server_address << ":"
|
||||
<< selected_port;
|
||||
} else {
|
||||
LOG(INFO) << "Profiling Server listening on " << server_address << ":"
|
||||
<< selected_port;
|
||||
}
|
||||
}
|
||||
|
||||
ProfilerServer::~ProfilerServer() {
|
||||
|
Loading…
x
Reference in New Issue
Block a user