[tf.data service] Move device stripping to the server side.

This consolidates all dataset graph preparation logic so that it happens on the dispatcher server when the dataset graph is registered.

PiperOrigin-RevId: 337190481
Change-Id: I6cfaac0e5f889422b9d46ca46b62f604127ea12c
This commit is contained in:
Andrew Audibert 2020-10-14 15:53:02 -07:00 committed by TensorFlower Gardener
parent 60424aaaeb
commit ffd60185f9
2 changed files with 24 additions and 15 deletions

View File

@ -94,6 +94,23 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol,
stub = WorkerService::NewStub(channel);
return Status::OK();
}
// Rewrites `graph` in place before it is hashed and registered.
//
// For every node in the graph:
//   * if the node's op matches an entry of `kNodeNameSharingOps`, its
//     `use_node_name_sharing` attr is set to `true`;
//   * its device assignment is cleared.
// Afterwards `StripDevicePlacement` is applied to the graph's function
// library.
void PrepareGraph(GraphDef* graph) {
  for (NodeDef& node : *graph->mutable_node()) {
    for (const auto& op : kNodeNameSharingOps) {
      // Set `use_node_name_sharing` to `true` so that resources aren't deleted
      // prematurely. Otherwise, resources may be deleted when their ops are
      // deleted at the end of the GraphRunner::Run used by standalone::Dataset.
      if (node.op() == op) {
        (*node.mutable_attr())["use_node_name_sharing"].set_b(true);
      }
    }
    // Clearing the device does not depend on the op being compared, so do it
    // once per node rather than once per `kNodeNameSharingOps` entry. This
    // also keeps device stripping working regardless of that list's size.
    // Clearing an already-empty device is a no-op.
    node.clear_device();
  }
  StripDevicePlacement(graph->mutable_library());
}
} // namespace
DataServiceDispatcherImpl::DataServiceDispatcherImpl(
@ -324,23 +341,16 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
TF_RETURN_IF_ERROR(CheckStarted());
uint64 fingerprint;
DatasetDef dataset_def = request->dataset();
TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint));
// Set `use_node_name_sharing` to `true` so that resources aren't deleted
// prematurely. Otherwise, resources may be deleted when their ops are
// deleted at the end of the GraphRunner::Run used by standalone::Dataset.
for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) {
for (const auto& op : kNodeNameSharingOps) {
if (node.op() == op) {
(*node.mutable_attr())["use_node_name_sharing"].set_b(true);
}
}
}
GraphDef* graph = dataset_def.mutable_graph();
PrepareGraph(graph);
TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));
mutex_lock l(mu_);
#if defined(PLATFORM_GOOGLE)
VLOG_LINES(4, absl::StrCat("Registering dataset graph: ",
dataset_def.graph().DebugString()));
VLOG_LINES(4,
absl::StrCat("Registering dataset graph: ", graph->DebugString()));
#else
VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString();
VLOG(4) << "Registering dataset graph: " << graph->DebugString();
#endif
std::shared_ptr<const Dataset> dataset;
Status s = state_.DatasetFromFingerprint(fingerprint, dataset);

View File

@ -57,7 +57,6 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
GraphDef graph_def;
OP_REQUIRES_OK(
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
StripDevicePlacement(graph_def.mutable_library());
DataServiceDispatcherClient client(address, protocol);
int64 dataset_id;