Enable platform selection in GRPC service.
PiperOrigin-RevId: 214120578
This commit is contained in:
parent
adea2433eb
commit
1a8dd7910e
@ -41,6 +41,7 @@ cc_library(
|
||||
":grpc_service",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "grpcpp/server_builder.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
@ -30,7 +31,10 @@ namespace {
|
||||
int RealMain(int argc, char** argv) {
|
||||
int32 port = 1685;
|
||||
bool any_address = false;
|
||||
string platform_str;
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
tensorflow::Flag("platform", &platform_str,
|
||||
"The XLA platform this service should be bound to"),
|
||||
tensorflow::Flag("port", &port, "The TCP port to listen on"),
|
||||
tensorflow::Flag(
|
||||
"any", &any_address,
|
||||
@ -44,8 +48,12 @@ int RealMain(int argc, char** argv) {
|
||||
}
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
|
||||
se::Platform* platform = nullptr;
|
||||
if (!platform_str.empty()) {
|
||||
platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie();
|
||||
}
|
||||
std::unique_ptr<xla::GRPCService> service =
|
||||
xla::GRPCService::NewService().ConsumeValueOrDie();
|
||||
xla::GRPCService::NewService(platform).ConsumeValueOrDie();
|
||||
|
||||
::grpc::ServerBuilder builder;
|
||||
string server_address(
|
||||
|
Loading…
Reference in New Issue
Block a user