From 17dc4b07a0114e3912907ffa0c000fa29f1d5c44 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Thu, 8 Oct 2020 11:42:39 -0700 Subject: [PATCH] [tf.data service] Set use_node_name_sharing for lookup tables. This avoids an issue where lookup table resources are deleted as soon as GraphRunner::Run creates the dataset in standalone::Dataset. With `use_node_name_sharing`, resources won't be deleted until an explicit delete op is run. PiperOrigin-RevId: 336130997 Change-Id: I452ed9eb4a93f3709d60304752d688e70f475e00 --- .../core/data/service/dispatcher_impl.cc | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index f95baaf0b7c..2973548ab19 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -54,6 +54,17 @@ constexpr char kJournalDir[] = "tf_data_dispatcher_journal"; // The name of the datasets directory inside the dispatcher's working directory. constexpr char kDatasetsDir[] = "datasets"; +constexpr std::array kNodeNameSharingOps = { + "HashTable", + "HashTableV2", + "MutableHashTable", + "MutableHashTableV2", + "MutableDenseHashTable", + "MutableDenseHashTableV2", + "MutableHashTableOfTensors", + "MutableHashTableOfTensorsV2", +}; + using Dataset = DispatcherState::Dataset; using Worker = DispatcherState::Worker; using NamedJobKey = DispatcherState::NamedJobKey; @@ -312,14 +323,24 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( GetOrRegisterDatasetResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); uint64 fingerprint; - const GraphDef& graph = request->dataset().graph(); - TF_RETURN_IF_ERROR(HashGraph(graph, &fingerprint)); + DatasetDef dataset_def = request->dataset(); + TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint)); + // Set `use_node_name_sharing` to `true` so that resources aren't deleted + // prematurely. Otherwise, resources may be deleted when their ops are + // deleted at the end of the GraphRunner::Run used by standalone::Dataset. + for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) { + for (const auto& op : kNodeNameSharingOps) { + if (node.op() == op) { + (*node.mutable_attr())["use_node_name_sharing"].set_b(true); + } + } + } mutex_lock l(mu_); #if defined(PLATFORM_GOOGLE) - VLOG_LINES(4, - absl::StrCat("Registering dataset graph: ", graph.DebugString())); + VLOG_LINES(4, absl::StrCat("Registering dataset graph: ", + dataset_def.graph().DebugString())); #else - VLOG(4) << "Registering dataset graph: " << graph.DebugString(); + VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString(); #endif std::shared_ptr dataset; Status s = state_.DatasetFromFingerprint(fingerprint, dataset); @@ -334,7 +355,7 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( } int64 id; - TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), id)); + TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id)); response->set_dataset_id(id); VLOG(3) << "Registered new dataset with id " << id; return Status::OK();