[tf.data service] Add client/server version mismatch error.

We raise an error instead of logging a warning because client/server mismatch is likely to cause hangs or other failures that will be easier to debug if the user gets a direct error message instead of needing to discover the warning in their logs.

PiperOrigin-RevId: 357863794
Change-Id: Ie5cc0d7906b10c404b1df47f8a5544f95cbe5e33
This commit is contained in:
Andrew Audibert 2021-02-16 20:21:51 -08:00 committed by TensorFlower Gardener
parent cc5e6d395e
commit 8e421f5cc6
7 changed files with 45 additions and 0 deletions

View File

@ -236,6 +236,28 @@ Status DataServiceDispatcherClient::EnsureInitialized() {
args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
stub_ = DispatcherService::NewStub(channel);
GetVersionRequest req;
GetVersionResponse resp;
TF_RETURN_IF_ERROR(grpc_util::Retry(
[&] {
grpc::ClientContext ctx;
grpc::Status s = stub_->GetVersion(&ctx, req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to get dispatcher version", s);
}
return Status::OK();
},
"checking service version",
/*deadline_micros=*/kint64max));
if (resp.version() != kDataServiceVersion) {
return errors::FailedPrecondition(
"Version mismatch with tf.data service server. The server is running "
"version ",
resp.version(), ", while the client is running version ",
kDataServiceVersion,
". Please ensure that the client and server side are running the "
"same version of TensorFlow.");
}
return Status::OK();
}

View File

@ -28,6 +28,10 @@ limitations under the License.
namespace tensorflow {
namespace data {
// Increment this when making backwards-incompatible changes to communication
// between tf.data servers.
constexpr int kDataServiceVersion = 1;
// Modes for how a tf.data service job should process a dataset.
enum class ProcessingMode : int64 {
UNSET = 0,

View File

@ -48,6 +48,12 @@ message GetSplitResponse {
bool end_of_splits = 2;
}
message GetVersionRequest {}
message GetVersionResponse {
int64 version = 1;
}
message GetOrRegisterDatasetRequest {
// The dataset to register.
DatasetDef dataset = 1;
@ -146,6 +152,9 @@ service DispatcherService {
// Gets the next split for a given job.
rpc GetSplit(GetSplitRequest) returns (GetSplitResponse);
// Returns the API version of the server.
rpc GetVersion(GetVersionRequest) returns (GetVersionResponse);
// Registers a dataset with the server, or returns its id if it is already
// registered.
//

View File

@ -361,6 +361,12 @@ Status DataServiceDispatcherImpl::MakeSplitProvider(
return Status::OK();
}
Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request,
GetVersionResponse* response) {
response->set_version(kDataServiceVersion);
return Status::OK();
}
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
const GetOrRegisterDatasetRequest* request,
GetOrRegisterDatasetResponse* response) {

View File

@ -135,6 +135,8 @@ class DataServiceDispatcherImpl {
Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);
/// Client-facing API.
Status GetVersion(const GetVersionRequest* request,
GetVersionResponse* response);
Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
GetOrRegisterDatasetResponse* response);
Status GetOrCreateJob(const GetOrCreateJobRequest* request,

View File

@ -44,6 +44,7 @@ HANDLER(WorkerHeartbeat);
HANDLER(WorkerUpdate);
HANDLER(GetDatasetDef);
HANDLER(GetSplit);
HANDLER(GetVersion);
HANDLER(GetOrRegisterDataset);
HANDLER(ReleaseJobClient);
HANDLER(GetOrCreateJob);

View File

@ -43,6 +43,7 @@ class GrpcDispatcherImpl : public DispatcherService::Service {
HANDLER(WorkerUpdate);
HANDLER(GetDatasetDef);
HANDLER(GetSplit);
HANDLER(GetVersion);
HANDLER(GetOrRegisterDataset);
HANDLER(ReleaseJobClient);
HANDLER(GetOrCreateJob);