TpuDriver: Switch configuration object to protobuf for easier serialization/parsing.
PiperOrigin-RevId: 281329177 Change-Id: I01b767f2415d0e469d6b1166d96918f72755c0f8
This commit is contained in:
parent
9182431be6
commit
a6750e3dc5
|
@ -48,7 +48,7 @@ static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
|
||||||
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||||
const std::string& worker) {
|
const std::string& worker) {
|
||||||
tpu_driver::TpuDriverConfig driver_config;
|
tpu_driver::TpuDriverConfig driver_config;
|
||||||
driver_config.worker = worker;
|
driver_config.set_worker(worker);
|
||||||
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
|
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
|
||||||
|
|
||||||
if (!client_status.ok()) {
|
if (!client_status.ok()) {
|
||||||
|
|
|
@ -907,8 +907,8 @@ GrpcTpuDriver::CreateTpuDriverStub(const TpuDriverConfig& config) {
|
||||||
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
|
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
|
||||||
|
|
||||||
// Send at least 20 keep-alives before giving up.
|
// Send at least 20 keep-alives before giving up.
|
||||||
int keepalive_timeout_ms = config.keepalive_timeout_secs * 1000;
|
int keepalive_timeout_ms = config.grpc().keepalive_timeout_secs() * 1000;
|
||||||
int keepalive_interval_ms = config.keepalive_timeout_secs / 20;
|
int keepalive_interval_ms = keepalive_timeout_ms / 20;
|
||||||
|
|
||||||
grpc_arg client_arg_vals[] = {
|
grpc_arg client_arg_vals[] = {
|
||||||
{.type = GRPC_ARG_INTEGER,
|
{.type = GRPC_ARG_INTEGER,
|
||||||
|
@ -935,7 +935,7 @@ GrpcTpuDriver::CreateTpuDriverStub(const TpuDriverConfig& config) {
|
||||||
args.SetChannelArgs(&client_args);
|
args.SetChannelArgs(&client_args);
|
||||||
|
|
||||||
// strips out 'grpc://'
|
// strips out 'grpc://'
|
||||||
auto worker_addr = absl::StripPrefix(config.worker, kGrpcProtocol);
|
auto worker_addr = absl::StripPrefix(config.worker(), kGrpcProtocol);
|
||||||
std::shared_ptr<::grpc::Channel> channel =
|
std::shared_ptr<::grpc::Channel> channel =
|
||||||
::grpc::CreateCustomChannel(std::string(worker_addr), creds, args);
|
::grpc::CreateCustomChannel(std::string(worker_addr), creds, args);
|
||||||
return grpc::CloudTpuDriver::NewStub(channel);
|
return grpc::CloudTpuDriver::NewStub(channel);
|
||||||
|
@ -977,8 +977,9 @@ REGISTER_TPU_DRIVER(
|
||||||
auto stub = GrpcTpuDriver::CreateTpuDriverStub(config);
|
auto stub = GrpcTpuDriver::CreateTpuDriverStub(config);
|
||||||
::grpc::ClientContext ctx;
|
::grpc::ClientContext ctx;
|
||||||
ctx.set_fail_fast(false);
|
ctx.set_fail_fast(false);
|
||||||
ctx.set_deadline(std::chrono::system_clock::now() +
|
ctx.set_deadline(
|
||||||
std::chrono::seconds(config.connection_timeout_secs));
|
std::chrono::system_clock::now() +
|
||||||
|
std::chrono::seconds(config.grpc().connection_timeout_secs()));
|
||||||
OpenRequest req;
|
OpenRequest req;
|
||||||
OpenResponse resp;
|
OpenResponse resp;
|
||||||
::grpc::Status status = stub->Open(&ctx, req, &resp);
|
::grpc::Status status = stub->Open(&ctx, req, &resp);
|
||||||
|
@ -988,7 +989,7 @@ REGISTER_TPU_DRIVER(
|
||||||
return xla::Status(
|
return xla::Status(
|
||||||
tensorflow::error::Code(status.error_code()),
|
tensorflow::error::Code(status.error_code()),
|
||||||
absl::StrCat("Failed to connect to remote server at address: ",
|
absl::StrCat("Failed to connect to remote server at address: ",
|
||||||
config.worker,
|
config.worker(),
|
||||||
". Error from gRPC: ", status.error_details()));
|
". Error from gRPC: ", status.error_details()));
|
||||||
}
|
}
|
||||||
return std::unique_ptr<TpuDriver>(
|
return std::unique_ptr<TpuDriver>(
|
||||||
|
|
|
@ -516,7 +516,7 @@ class RecordingTpuDriver : public TpuDriver {
|
||||||
|
|
||||||
xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
|
xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
|
||||||
const TpuDriverConfig& config) {
|
const TpuDriverConfig& config) {
|
||||||
std::vector<std::string> configs = absl::StrSplit(config.worker, '|');
|
std::vector<std::string> configs = absl::StrSplit(config.worker(), '|');
|
||||||
|
|
||||||
std::string file;
|
std::string file;
|
||||||
std::string worker;
|
std::string worker;
|
||||||
|
@ -533,7 +533,7 @@ xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterRecordingTpuDriver(
|
||||||
}
|
}
|
||||||
|
|
||||||
TpuDriverConfig worker_config;
|
TpuDriverConfig worker_config;
|
||||||
worker_config.worker = worker;
|
worker_config.set_worker(worker);
|
||||||
|
|
||||||
auto driver_status = TpuDriverRegistry::Open(worker_config);
|
auto driver_status = TpuDriverRegistry::Open(worker_config);
|
||||||
if (!driver_status.ok()) return driver_status.status();
|
if (!driver_status.ok()) return driver_status.status();
|
||||||
|
|
|
@ -88,12 +88,12 @@ uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) {
|
||||||
/*static*/ xla::StatusOr<std::unique_ptr<TpuDriver>> TpuDriverRegistry::Open(
|
/*static*/ xla::StatusOr<std::unique_ptr<TpuDriver>> TpuDriverRegistry::Open(
|
||||||
const TpuDriverConfig& config) {
|
const TpuDriverConfig& config) {
|
||||||
for (const auto& driver : *GetDriverRegistryMap()) {
|
for (const auto& driver : *GetDriverRegistryMap()) {
|
||||||
if (absl::StartsWith(config.worker, driver.first)) {
|
if (absl::StartsWith(config.worker(), driver.first)) {
|
||||||
return driver.second(config);
|
return driver.second(config);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return xla::NotFound("Unable to find driver in registry given worker: %s",
|
return xla::NotFound("Unable to find driver in registry given worker: %s",
|
||||||
config.worker);
|
config.worker());
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
|
||||||
|
|
|
@ -37,6 +37,9 @@
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
// This API is EXPERIMENTAL and under active development. It is subject to
|
||||||
|
// change without notice.
|
||||||
|
|
||||||
namespace tpu_driver {
|
namespace tpu_driver {
|
||||||
|
|
||||||
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
|
||||||
|
@ -227,17 +230,6 @@ class TpuDriver {
|
||||||
virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
|
virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TpuDriverConfig {
|
|
||||||
std::string worker;
|
|
||||||
|
|
||||||
// Time in seconds before the initial connection to the server will timeout.
|
|
||||||
int64_t connection_timeout_secs = 10;
|
|
||||||
|
|
||||||
// Time in seconds the server may be unresponsive before terminating the
|
|
||||||
// connection.
|
|
||||||
int64_t keepalive_timeout_secs = 30;
|
|
||||||
};
|
|
||||||
|
|
||||||
class TpuDriverRegistry {
|
class TpuDriverRegistry {
|
||||||
public:
|
public:
|
||||||
static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(
|
static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(
|
||||||
|
|
|
@ -41,3 +41,18 @@ message SystemInfo {
|
||||||
repeated TpuChipInfo tpu_chip = 1;
|
repeated TpuChipInfo tpu_chip = 1;
|
||||||
required CpuInfo cpu = 2;
|
required CpuInfo cpu = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message TpuDriverConfig {
|
||||||
|
optional string worker = 1;
|
||||||
|
|
||||||
|
message GrpcConfig {
|
||||||
|
// Time in seconds before the initial connection to the server will time out.
|
||||||
|
optional int64 connection_timeout_secs = 1 [default = 30];
|
||||||
|
|
||||||
|
// Time in seconds the server may be unresponsive before terminating the
|
||||||
|
// connection.
|
||||||
|
optional int64 keepalive_timeout_secs = 2 [default = 30];
|
||||||
|
}
|
||||||
|
|
||||||
|
optional GrpcConfig grpc = 2;
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue