TpuDriver: Switch configuration object to protobuf for easier serialization/parsing.
PiperOrigin-RevId: 281329177 Change-Id: I01b767f2415d0e469d6b1166d96918f72755c0f8
This commit is contained in:
parent
9182431be6
commit
a6750e3dc5
|
@ -48,7 +48,7 @@ static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
|
|||
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||
const std::string& worker) {
|
||||
tpu_driver::TpuDriverConfig driver_config;
|
||||
driver_config.worker = worker;
|
||||
driver_config.set_worker(worker);
|
||||
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
|
||||
|
||||
if (!client_status.ok()) {
|
||||
|
|
|
@ -907,8 +907,8 @@ GrpcTpuDriver::CreateTpuDriverStub(const TpuDriverConfig& config) {
|
|||
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
|
||||
|
||||
// Send at least 20 keep-alives before giving up.
|
||||
int keepalive_timeout_ms = config.keepalive_timeout_secs * 1000;
|
||||
int keepalive_interval_ms = config.keepalive_timeout_secs / 20;
|
||||
int keepalive_timeout_ms = config.grpc().keepalive_timeout_secs() * 1000;
|
||||
int keepalive_interval_ms = keepalive_timeout_ms / 20;
|
||||
|
||||
grpc_arg client_arg_vals[] = {
|
||||
{.type = GRPC_ARG_INTEGER,
|
||||
|
@ -935,7 +935,7 @@ GrpcTpuDriver::CreateTpuDriverStub(const TpuDriverConfig& config) {
|
|||
args.SetChannelArgs(&client_args);
|
||||
|
||||
// strips out 'grpc://'
|
||||
auto worker_addr = absl::StripPrefix(config.worker, kGrpcProtocol);
|
||||
auto worker_addr = absl::StripPrefix(config.worker(), kGrpcProtocol);
|
||||
std::shared_ptr<::grpc::Channel> channel =
|
||||
::grpc::CreateCustomChannel(std::string(worker_addr), creds, args);
|
||||
return grpc::CloudTpuDriver::NewStub(channel);
|
||||
|
@ -977,8 +977,9 @@ REGISTER_TPU_DRIVER(
|
|||
auto stub = GrpcTpuDriver::CreateTpuDriverStub(config);
|
||||
::grpc::ClientContext ctx;
|
||||
ctx.set_fail_fast(false);
|
||||
ctx.set_deadline(std::chrono::system_clock::now() +
|
||||
std::chrono::seconds(config.connection_timeout_secs));
|
||||
ctx.set_deadline(
|
||||
std::chrono::system_clock::now() +
|
||||
std::chrono::seconds(config.grpc().connection_timeout_secs()));
|
||||
OpenRequest req;
|
||||
OpenResponse resp;
|
||||
::grpc::Status status = stub->Open(&ctx, req, &resp);
|
||||
|
@ -988,7 +989,7 @@ REGISTER_TPU_DRIVER(
|
|||
return xla::Status(
|
||||
tensorflow::error::Code(status.error_code()),
|
||||
absl::StrCat("Failed to connect to remote server at address: ",
|
||||
config.worker,
|
||||
config.worker(),
|
||||
". Error from gRPC: ", status.error_details()));
|
||||
}
|
||||
return std::unique_ptr<TpuDriver>(
|
||||
|
|
|
@ -516,7 +516,7 @@ class RecordingTpuDriver : public TpuDriver {
|
|||
|
||||
xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
|
||||
const TpuDriverConfig& config) {
|
||||
std::vector<std::string> configs = absl::StrSplit(config.worker, '|');
|
||||
std::vector<std::string> configs = absl::StrSplit(config.worker(), '|');
|
||||
|
||||
std::string file;
|
||||
std::string worker;
|
||||
|
@ -533,7 +533,7 @@ xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
|
|||
}
|
||||
|
||||
TpuDriverConfig worker_config;
|
||||
worker_config.worker = worker;
|
||||
worker_config.set_worker(worker);
|
||||
|
||||
auto driver_status = TpuDriverRegistry::Open(worker_config);
|
||||
if (!driver_status.ok()) return driver_status.status();
|
||||
|
|
|
@ -88,12 +88,12 @@ uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
|
|||
/*static*/ xla::StatusOr<std::unique_ptr<TpuDriver>> TpuDriverRegistry::Open(
|
||||
const TpuDriverConfig& config) {
|
||||
for (const auto& driver : *GetDriverRegistryMap()) {
|
||||
if (absl::StartsWith(config.worker, driver.first)) {
|
||||
if (absl::StartsWith(config.worker(), driver.first)) {
|
||||
return driver.second(config);
|
||||
}
|
||||
}
|
||||
return xla::NotFound("Unable to find driver in registry given worker: %s",
|
||||
config.worker);
|
||||
config.worker());
|
||||
}
|
||||
|
||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||
|
|
|
@ -37,6 +37,9 @@
|
|||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
// This API is EXPERIMENTAL and under active developement. It is subject to
|
||||
// change without notice.
|
||||
|
||||
namespace tpu_driver {
|
||||
|
||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
||||
|
@ -227,17 +230,6 @@ class TpuDriver {
|
|||
virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
|
||||
};
|
||||
|
||||
struct TpuDriverConfig {
|
||||
std::string worker;
|
||||
|
||||
// Time in seconds before the initial connection to the server will timeout.
|
||||
int64_t connection_timeout_secs = 10;
|
||||
|
||||
// Time in seconds the server may be unresponsive before terminating the
|
||||
// connection.
|
||||
int64_t keepalive_timeout_secs = 30;
|
||||
};
|
||||
|
||||
class TpuDriverRegistry {
|
||||
public:
|
||||
static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(
|
||||
|
|
|
@ -41,3 +41,18 @@ message SystemInfo {
|
|||
repeated TpuChipInfo tpu_chip = 1;
|
||||
required CpuInfo cpu = 2;
|
||||
}
|
||||
|
||||
message TpuDriverConfig {
|
||||
optional string worker = 1;
|
||||
|
||||
message GrpcConfig {
|
||||
// Time in seconds before the initial connection to the server will timeout.
|
||||
optional int64 connection_timeout_secs = 1 [default = 30];
|
||||
|
||||
// Time in seconds the server may be unresponsive before terminating the
|
||||
// connection.
|
||||
optional int64 keepalive_timeout_secs = 2 [default = 30];
|
||||
}
|
||||
|
||||
optional GrpcConfig grpc = 2;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue