Enable platform selection in GRPC service.
PiperOrigin-RevId: 214120578
This commit is contained in:
parent
adea2433eb
commit
1a8dd7910e
@ -41,6 +41,7 @@ cc_library(
|
|||||||
":grpc_service",
|
":grpc_service",
|
||||||
"//tensorflow:grpc++",
|
"//tensorflow:grpc++",
|
||||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||||
|
"//tensorflow/compiler/xla/service:platform_util",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "grpcpp/server_builder.h"
|
#include "grpcpp/server_builder.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
|
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
@ -30,7 +31,10 @@ namespace {
|
|||||||
int RealMain(int argc, char** argv) {
|
int RealMain(int argc, char** argv) {
|
||||||
int32 port = 1685;
|
int32 port = 1685;
|
||||||
bool any_address = false;
|
bool any_address = false;
|
||||||
|
string platform_str;
|
||||||
std::vector<tensorflow::Flag> flag_list = {
|
std::vector<tensorflow::Flag> flag_list = {
|
||||||
|
tensorflow::Flag("platform", &platform_str,
|
||||||
|
"The XLA platform this service should be bound to"),
|
||||||
tensorflow::Flag("port", &port, "The TCP port to listen on"),
|
tensorflow::Flag("port", &port, "The TCP port to listen on"),
|
||||||
tensorflow::Flag(
|
tensorflow::Flag(
|
||||||
"any", &any_address,
|
"any", &any_address,
|
||||||
@ -44,8 +48,12 @@ int RealMain(int argc, char** argv) {
|
|||||||
}
|
}
|
||||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||||
|
|
||||||
|
se::Platform* platform = nullptr;
|
||||||
|
if (!platform_str.empty()) {
|
||||||
|
platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie();
|
||||||
|
}
|
||||||
std::unique_ptr<xla::GRPCService> service =
|
std::unique_ptr<xla::GRPCService> service =
|
||||||
xla::GRPCService::NewService().ConsumeValueOrDie();
|
xla::GRPCService::NewService(platform).ConsumeValueOrDie();
|
||||||
|
|
||||||
::grpc::ServerBuilder builder;
|
::grpc::ServerBuilder builder;
|
||||||
string server_address(
|
string server_address(
|
||||||
|
Loading…
Reference in New Issue
Block a user