[tf.data service] Set use_node_name_sharing for lookup tables.

This avoids an issue where lookup table resources are deleted as soon as GraphRunner::Run creates the dataset in standalone::Dataset. With `use_node_name_sharing`, resources won't be deleted until an explicit delete op is run.

PiperOrigin-RevId: 336130997
Change-Id: I452ed9eb4a93f3709d60304752d688e70f475e00
This commit is contained in:
Andrew Audibert 2020-10-08 11:42:39 -07:00 committed by TensorFlower Gardener
parent 28a88cc4f7
commit 17dc4b07a0

View File

@ -54,6 +54,17 @@ constexpr char kJournalDir[] = "tf_data_dispatcher_journal";
// The name of the datasets directory inside the dispatcher's working directory.
constexpr char kDatasetsDir[] = "datasets";
constexpr std::array<const char*, 8> kNodeNameSharingOps = {
"HashTable",
"HashTableV2",
"MutableHashTable",
"MutableHashTableV2",
"MutableDenseHashTable",
"MutableDenseHashTableV2",
"MutableHashTableOfTensors",
"MutableHashTableOfTensorsV2",
};
using Dataset = DispatcherState::Dataset;
using Worker = DispatcherState::Worker;
using NamedJobKey = DispatcherState::NamedJobKey;
@ -312,14 +323,24 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
GetOrRegisterDatasetResponse* response) {
TF_RETURN_IF_ERROR(CheckStarted());
uint64 fingerprint;
const GraphDef& graph = request->dataset().graph();
TF_RETURN_IF_ERROR(HashGraph(graph, &fingerprint));
DatasetDef dataset_def = request->dataset();
TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint));
// Set `use_node_name_sharing` to `true` so that resources aren't deleted
// prematurely. Otherwise, resources may be deleted when their ops are
// deleted at the end of the GraphRunner::Run used by standalone::Dataset.
for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) {
for (const auto& op : kNodeNameSharingOps) {
if (node.op() == op) {
(*node.mutable_attr())["use_node_name_sharing"].set_b(true);
}
}
}
mutex_lock l(mu_);
#if defined(PLATFORM_GOOGLE)
VLOG_LINES(4,
absl::StrCat("Registering dataset graph: ", graph.DebugString()));
VLOG_LINES(4, absl::StrCat("Registering dataset graph: ",
dataset_def.graph().DebugString()));
#else
VLOG(4) << "Registering dataset graph: " << graph.DebugString();
VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString();
#endif
std::shared_ptr<const Dataset> dataset;
Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
@ -334,7 +355,7 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
}
int64 id;
TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), id));
TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id));
response->set_dataset_id(id);
VLOG(3) << "Registered new dataset with id " << id;
return Status::OK();