[tf.data service] Move device stripping to the server side.

This consolidates all dataset graph preparation logic to happen on the dispatcher server when the dataset graph is registered.

PiperOrigin-RevId: 337190481
Change-Id: I6cfaac0e5f889422b9d46ca46b62f604127ea12c
This commit is contained in:
Andrew Audibert 2020-10-14 15:53:02 -07:00 committed by TensorFlower Gardener
parent 60424aaaeb
commit ffd60185f9
2 changed files with 24 additions and 15 deletions

View File

@ -94,6 +94,23 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol,
stub = WorkerService::NewStub(channel); stub = WorkerService::NewStub(channel);
return Status::OK(); return Status::OK();
} }
// Prepares a dataset graph for registration. Mutates `graph` in place:
//   * sets the `use_node_name_sharing` attr to true on every node whose op is
//     listed in `kNodeNameSharingOps`;
//   * clears the device assignment of every node;
//   * calls `StripDevicePlacement` on the graph's function library.
void PrepareGraph(GraphDef* graph) {
  for (NodeDef& node : *graph->mutable_node()) {
    // Device stripping does not depend on the node's op, so do it exactly once
    // per node, outside the op-matching loop below.
    node.clear_device();
    for (const auto& op : kNodeNameSharingOps) {
      // Set `use_node_name_sharing` to `true` so that resources aren't deleted
      // prematurely. Otherwise, resources may be deleted when their ops are
      // deleted at the end of the GraphRunner::Run used by standalone::Dataset.
      if (node.op() == op) {
        (*node.mutable_attr())["use_node_name_sharing"].set_b(true);
        // A node has a single op; no further entries can change the result.
        break;
      }
    }
  }
  StripDevicePlacement(graph->mutable_library());
}
} // namespace } // namespace
DataServiceDispatcherImpl::DataServiceDispatcherImpl( DataServiceDispatcherImpl::DataServiceDispatcherImpl(
@ -324,23 +341,16 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
TF_RETURN_IF_ERROR(CheckStarted()); TF_RETURN_IF_ERROR(CheckStarted());
uint64 fingerprint; uint64 fingerprint;
DatasetDef dataset_def = request->dataset(); DatasetDef dataset_def = request->dataset();
TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint)); GraphDef* graph = dataset_def.mutable_graph();
// Set `use_node_name_sharing` to `true` so that resources aren't deleted PrepareGraph(graph);
// prematurely. Otherwise, resources may be deleted when their ops are TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));
// deleted at the end of the GraphRunner::Run used by standalone::Dataset.
for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) {
for (const auto& op : kNodeNameSharingOps) {
if (node.op() == op) {
(*node.mutable_attr())["use_node_name_sharing"].set_b(true);
}
}
}
mutex_lock l(mu_); mutex_lock l(mu_);
#if defined(PLATFORM_GOOGLE) #if defined(PLATFORM_GOOGLE)
VLOG_LINES(4, absl::StrCat("Registering dataset graph: ", VLOG_LINES(4,
dataset_def.graph().DebugString())); absl::StrCat("Registering dataset graph: ", graph->DebugString()));
#else #else
VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString(); VLOG(4) << "Registering dataset graph: " << graph->DebugString();
#endif #endif
std::shared_ptr<const Dataset> dataset; std::shared_ptr<const Dataset> dataset;
Status s = state_.DatasetFromFingerprint(fingerprint, dataset); Status s = state_.DatasetFromFingerprint(fingerprint, dataset);

View File

@ -57,7 +57,6 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
GraphDef graph_def; GraphDef graph_def;
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def)); ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
StripDevicePlacement(graph_def.mutable_library());
DataServiceDispatcherClient client(address, protocol); DataServiceDispatcherClient client(address, protocol);
int64 dataset_id; int64 dataset_id;