[tf.data service] Set use_node_name_sharing for lookup tables.
This avoids an issue where lookup table resources are deleted as soon as GraphRunner::Run creates the dataset in standalone::Dataset. With `use_node_name_sharing`, resources won't be deleted until an explicit delete op is run.

PiperOrigin-RevId: 336130997
Change-Id: I452ed9eb4a93f3709d60304752d688e70f475e00
This commit is contained in:
parent
28a88cc4f7
commit
17dc4b07a0
@ -54,6 +54,17 @@ constexpr char kJournalDir[] = "tf_data_dispatcher_journal";
|
||||
// The name of the datasets directory inside the dispatcher's working directory.
|
||||
constexpr char kDatasetsDir[] = "datasets";
|
||||
|
||||
// Op names of lookup-table nodes. When a dataset is registered, every graph
// node whose op equals one of these entries has its `use_node_name_sharing`
// attr set to true (see GetOrRegisterDataset).
// NOTE: keep the array size (8) in sync with the number of entries; any
// unfilled slot would be value-initialized to a null pointer.
constexpr std::array<const char*, 8> kNodeNameSharingOps = {
|
||||
"HashTable",
|
||||
"HashTableV2",
|
||||
"MutableHashTable",
|
||||
"MutableHashTableV2",
|
||||
"MutableDenseHashTable",
|
||||
"MutableDenseHashTableV2",
|
||||
"MutableHashTableOfTensors",
|
||||
"MutableHashTableOfTensorsV2",
|
||||
};
|
||||
|
||||
// Local aliases for types nested in DispatcherState.
using Dataset = DispatcherState::Dataset;
|
||||
using Worker = DispatcherState::Worker;
|
||||
using NamedJobKey = DispatcherState::NamedJobKey;
|
||||
@ -312,14 +323,24 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
|
||||
GetOrRegisterDatasetResponse* response) {
|
||||
TF_RETURN_IF_ERROR(CheckStarted());
|
||||
uint64 fingerprint;
|
||||
const GraphDef& graph = request->dataset().graph();
|
||||
TF_RETURN_IF_ERROR(HashGraph(graph, &fingerprint));
|
||||
DatasetDef dataset_def = request->dataset();
|
||||
TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint));
|
||||
// Set `use_node_name_sharing` to `true` so that resources aren't deleted
|
||||
// prematurely. Otherwise, resources may be deleted when their ops are
|
||||
// deleted at the end of the GraphRunner::Run used by standalone::Dataset.
|
||||
for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) {
|
||||
for (const auto& op : kNodeNameSharingOps) {
|
||||
if (node.op() == op) {
|
||||
(*node.mutable_attr())["use_node_name_sharing"].set_b(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
VLOG_LINES(4,
|
||||
absl::StrCat("Registering dataset graph: ", graph.DebugString()));
|
||||
VLOG_LINES(4, absl::StrCat("Registering dataset graph: ",
|
||||
dataset_def.graph().DebugString()));
|
||||
#else
|
||||
VLOG(4) << "Registering dataset graph: " << graph.DebugString();
|
||||
VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString();
|
||||
#endif
|
||||
std::shared_ptr<const Dataset> dataset;
|
||||
Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
|
||||
@ -334,7 +355,7 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
|
||||
}
|
||||
|
||||
int64 id;
|
||||
TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), id));
|
||||
TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id));
|
||||
response->set_dataset_id(id);
|
||||
VLOG(3) << "Registered new dataset with id " << id;
|
||||
return Status::OK();
|
||||
|
Loading…
Reference in New Issue
Block a user