Enable platform selection in GRPC service.

PiperOrigin-RevId: 214120578
This commit is contained in:
A. Unique TensorFlower 2018-09-22 10:31:37 -07:00 committed by TensorFlower Gardener
parent adea2433eb
commit 1a8dd7910e
2 changed files with 10 additions and 1 deletions

View File

@ -41,6 +41,7 @@ cc_library(
":grpc_service",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings:str_format",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "grpcpp/server_builder.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -30,7 +31,10 @@ namespace {
int RealMain(int argc, char** argv) {
int32 port = 1685;
bool any_address = false;
string platform_str;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("platform", &platform_str,
"The XLA platform this service should be bound to"),
tensorflow::Flag("port", &port, "The TCP port to listen on"),
tensorflow::Flag(
"any", &any_address,
@ -44,8 +48,12 @@ int RealMain(int argc, char** argv) {
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
se::Platform* platform = nullptr;
if (!platform_str.empty()) {
platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie();
}
std::unique_ptr<xla::GRPCService> service =
xla::GRPCService::NewService().ConsumeValueOrDie();
xla::GRPCService::NewService(platform).ConsumeValueOrDie();
::grpc::ServerBuilder builder;
string server_address(