diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index f95baaf0b7c..2973548ab19 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -54,6 +54,17 @@ constexpr char kJournalDir[] = "tf_data_dispatcher_journal"; // The name of the datasets directory inside the dispatcher's working directory. constexpr char kDatasetsDir[] = "datasets"; +constexpr std::array kNodeNameSharingOps = { + "HashTable", + "HashTableV2", + "MutableHashTable", + "MutableHashTableV2", + "MutableDenseHashTable", + "MutableDenseHashTableV2", + "MutableHashTableOfTensors", + "MutableHashTableOfTensorsV2", +}; + using Dataset = DispatcherState::Dataset; using Worker = DispatcherState::Worker; using NamedJobKey = DispatcherState::NamedJobKey; @@ -312,14 +323,24 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( GetOrRegisterDatasetResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); uint64 fingerprint; - const GraphDef& graph = request->dataset().graph(); - TF_RETURN_IF_ERROR(HashGraph(graph, &fingerprint)); + DatasetDef dataset_def = request->dataset(); + TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint)); + // Set `use_node_name_sharing` to `true` so that resources aren't deleted + // prematurely. Otherwise, resources may be deleted when their ops are + // deleted at the end of the GraphRunner::Run used by standalone::Dataset. + for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) { + for (const auto& op : kNodeNameSharingOps) { + if (node.op() == op) { + (*node.mutable_attr())["use_node_name_sharing"].set_b(true); + } + } + } mutex_lock l(mu_); #if defined(PLATFORM_GOOGLE) - VLOG_LINES(4, - absl::StrCat("Registering dataset graph: ", graph.DebugString())); + VLOG_LINES(4, absl::StrCat("Registering dataset graph: ", + dataset_def.graph().DebugString())); #else - VLOG(4) << "Registering dataset graph: " << graph.DebugString(); + VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString(); #endif std::shared_ptr dataset; Status s = state_.DatasetFromFingerprint(fingerprint, dataset); @@ -334,7 +355,7 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( } int64 id; - TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), id)); + TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id)); response->set_dataset_id(id); VLOG(3) << "Registered new dataset with id " << id; return Status::OK();