[tf.data service] Move device stripping to the server side.
This consolidates all dataset graph preparation logic to happen on the dispatcher server when the dataset graph is registered. PiperOrigin-RevId: 337190481 Change-Id: I6cfaac0e5f889422b9d46ca46b62f604127ea12c
This commit is contained in:
parent
60424aaaeb
commit
ffd60185f9
@ -94,6 +94,23 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol,
|
||||
stub = WorkerService::NewStub(channel);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void PrepareGraph(GraphDef* graph) {
|
||||
for (NodeDef& node : *graph->mutable_node()) {
|
||||
for (const auto& op : kNodeNameSharingOps) {
|
||||
// Set `use_node_name_sharing` to `true` so that resources aren't deleted
|
||||
// prematurely. Otherwise, resources may be deleted when their ops are
|
||||
// deleted at the end of the GraphRunner::Run used by standalone::Dataset.
|
||||
if (node.op() == op) {
|
||||
(*node.mutable_attr())["use_node_name_sharing"].set_b(true);
|
||||
}
|
||||
if (!node.device().empty()) {
|
||||
*node.mutable_device() = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
StripDevicePlacement(graph->mutable_library());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
DataServiceDispatcherImpl::DataServiceDispatcherImpl(
|
||||
@ -324,23 +341,16 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
|
||||
TF_RETURN_IF_ERROR(CheckStarted());
|
||||
uint64 fingerprint;
|
||||
DatasetDef dataset_def = request->dataset();
|
||||
TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint));
|
||||
// Set `use_node_name_sharing` to `true` so that resources aren't deleted
|
||||
// prematurely. Otherwise, resources may be deleted when their ops are
|
||||
// deleted at the end of the GraphRunner::Run used by standalone::Dataset.
|
||||
for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) {
|
||||
for (const auto& op : kNodeNameSharingOps) {
|
||||
if (node.op() == op) {
|
||||
(*node.mutable_attr())["use_node_name_sharing"].set_b(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
GraphDef* graph = dataset_def.mutable_graph();
|
||||
PrepareGraph(graph);
|
||||
TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));
|
||||
|
||||
mutex_lock l(mu_);
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
VLOG_LINES(4, absl::StrCat("Registering dataset graph: ",
|
||||
dataset_def.graph().DebugString()));
|
||||
VLOG_LINES(4,
|
||||
absl::StrCat("Registering dataset graph: ", graph->DebugString()));
|
||||
#else
|
||||
VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString();
|
||||
VLOG(4) << "Registering dataset graph: " << graph->DebugString();
|
||||
#endif
|
||||
std::shared_ptr<const Dataset> dataset;
|
||||
Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
|
||||
|
@ -57,7 +57,6 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
|
||||
GraphDef graph_def;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
|
||||
StripDevicePlacement(graph_def.mutable_library());
|
||||
|
||||
DataServiceDispatcherClient client(address, protocol);
|
||||
int64 dataset_id;
|
||||
|
Loading…
Reference in New Issue
Block a user