TpuDriver: Switch configuration object to protobuf for easier serialization/parsing.

PiperOrigin-RevId: 281329177
Change-Id: I01b767f2415d0e469d6b1166d96918f72755c0f8
This commit is contained in:
Russell Power 2019-11-19 10:49:08 -08:00 committed by TensorFlower Gardener
parent 9182431be6
commit a6750e3dc5
6 changed files with 30 additions and 22 deletions

View File

@ -48,7 +48,7 @@ static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
const std::string& worker) {
tpu_driver::TpuDriverConfig driver_config;
driver_config.worker = worker;
driver_config.set_worker(worker);
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
if (!client_status.ok()) {

View File

@ -907,8 +907,8 @@ GrpcTpuDriver::CreateTpuDriverStub(const TpuDriverConfig& config) {
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
// Send at least 20 keep-alives before giving up.
int keepalive_timeout_ms = config.keepalive_timeout_secs * 1000;
int keepalive_interval_ms = config.keepalive_timeout_secs / 20;
int keepalive_timeout_ms = config.grpc().keepalive_timeout_secs() * 1000;
int keepalive_interval_ms = keepalive_timeout_ms / 20;
grpc_arg client_arg_vals[] = {
{.type = GRPC_ARG_INTEGER,
@ -935,7 +935,7 @@ GrpcTpuDriver::CreateTpuDriverStub(const TpuDriverConfig& config) {
args.SetChannelArgs(&client_args);
// strips out 'grpc://'
auto worker_addr = absl::StripPrefix(config.worker, kGrpcProtocol);
auto worker_addr = absl::StripPrefix(config.worker(), kGrpcProtocol);
std::shared_ptr<::grpc::Channel> channel =
::grpc::CreateCustomChannel(std::string(worker_addr), creds, args);
return grpc::CloudTpuDriver::NewStub(channel);
@ -977,8 +977,9 @@ REGISTER_TPU_DRIVER(
auto stub = GrpcTpuDriver::CreateTpuDriverStub(config);
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
ctx.set_deadline(std::chrono::system_clock::now() +
std::chrono::seconds(config.connection_timeout_secs));
ctx.set_deadline(
std::chrono::system_clock::now() +
std::chrono::seconds(config.grpc().connection_timeout_secs()));
OpenRequest req;
OpenResponse resp;
::grpc::Status status = stub->Open(&ctx, req, &resp);
@ -988,7 +989,7 @@ REGISTER_TPU_DRIVER(
return xla::Status(
tensorflow::error::Code(status.error_code()),
absl::StrCat("Failed to connect to remote server at address: ",
config.worker,
config.worker(),
". Error from gRPC: ", status.error_details()));
}
return std::unique_ptr<TpuDriver>(

View File

@ -516,7 +516,7 @@ class RecordingTpuDriver : public TpuDriver {
xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
const TpuDriverConfig& config) {
std::vector<std::string> configs = absl::StrSplit(config.worker, '|');
std::vector<std::string> configs = absl::StrSplit(config.worker(), '|');
std::string file;
std::string worker;
@ -533,7 +533,7 @@ xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
}
TpuDriverConfig worker_config;
worker_config.worker = worker;
worker_config.set_worker(worker);
auto driver_status = TpuDriverRegistry::Open(worker_config);
if (!driver_status.ok()) return driver_status.status();

View File

@ -88,12 +88,12 @@ uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
/*static*/ xla::StatusOr<std::unique_ptr<TpuDriver>> TpuDriverRegistry::Open(
const TpuDriverConfig& config) {
for (const auto& driver : *GetDriverRegistryMap()) {
if (absl::StartsWith(config.worker, driver.first)) {
if (absl::StartsWith(config.worker(), driver.first)) {
return driver.second(config);
}
}
return xla::NotFound("Unable to find driver in registry given worker: %s",
config.worker);
config.worker());
}
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {

View File

@ -37,6 +37,9 @@
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
// This API is EXPERIMENTAL and under active development. It is subject to
// change without notice.
namespace tpu_driver {
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
@ -227,17 +230,6 @@ class TpuDriver {
virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
};
// Configuration handed to TpuDriverRegistry::Open to select and set up a
// driver.
struct TpuDriverConfig {
// Worker address string. The registry picks the driver whose registered
// prefix this value starts with.
std::string worker;
// Time in seconds before the initial connection to the server will time out.
int64_t connection_timeout_secs = 10;
// Time in seconds the server may be unresponsive before terminating the
// connection.
int64_t keepalive_timeout_secs = 30;
};
class TpuDriverRegistry {
public:
static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(

View File

@ -41,3 +41,18 @@ message SystemInfo {
repeated TpuChipInfo tpu_chip = 1;
required CpuInfo cpu = 2;
}
// Configuration used to open a TPU driver through TpuDriverRegistry::Open.
message TpuDriverConfig {
// Worker address string. The registry picks the driver whose registered
// prefix this value starts with; the gRPC driver strips its protocol prefix
// from it to obtain the channel address.
optional string worker = 1;
// Settings read by the gRPC driver.
message GrpcConfig {
// Time in seconds before the initial connection to the server will time out.
// Used as the deadline of the Open RPC.
// NOTE(review): the C++ struct this message replaces defaulted to 10 —
// confirm that the change to 30 is intended.
optional int64 connection_timeout_secs = 1 [default = 30];
// Time in seconds the server may be unresponsive before terminating the
// connection. The driver converts this to milliseconds and derives the
// keep-alive interval as one twentieth of it.
optional int64 keepalive_timeout_secs = 2 [default = 30];
}
// gRPC-specific settings, accessed as config.grpc() by the gRPC driver.
optional GrpcConfig grpc = 2;
}