[tf.data service] Move device stripping to the server side.
This consolidates all dataset graph preparation logic to happen on the dispatcher server when the dataset graph is registered. PiperOrigin-RevId: 337190481 Change-Id: I6cfaac0e5f889422b9d46ca46b62f604127ea12c
This commit is contained in:
parent
60424aaaeb
commit
ffd60185f9
@ -94,6 +94,23 @@ Status CreateWorkerStub(const std::string& address, const std::string& protocol,
|
|||||||
stub = WorkerService::NewStub(channel);
|
stub = WorkerService::NewStub(channel);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PrepareGraph(GraphDef* graph) {
|
||||||
|
for (NodeDef& node : *graph->mutable_node()) {
|
||||||
|
for (const auto& op : kNodeNameSharingOps) {
|
||||||
|
// Set `use_node_name_sharing` to `true` so that resources aren't deleted
|
||||||
|
// prematurely. Otherwise, resources may be deleted when their ops are
|
||||||
|
// deleted at the end of the GraphRunner::Run used by standalone::Dataset.
|
||||||
|
if (node.op() == op) {
|
||||||
|
(*node.mutable_attr())["use_node_name_sharing"].set_b(true);
|
||||||
|
}
|
||||||
|
if (!node.device().empty()) {
|
||||||
|
*node.mutable_device() = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StripDevicePlacement(graph->mutable_library());
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
DataServiceDispatcherImpl::DataServiceDispatcherImpl(
|
DataServiceDispatcherImpl::DataServiceDispatcherImpl(
|
||||||
@ -324,23 +341,16 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
|
|||||||
TF_RETURN_IF_ERROR(CheckStarted());
|
TF_RETURN_IF_ERROR(CheckStarted());
|
||||||
uint64 fingerprint;
|
uint64 fingerprint;
|
||||||
DatasetDef dataset_def = request->dataset();
|
DatasetDef dataset_def = request->dataset();
|
||||||
TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint));
|
GraphDef* graph = dataset_def.mutable_graph();
|
||||||
// Set `use_node_name_sharing` to `true` so that resources aren't deleted
|
PrepareGraph(graph);
|
||||||
// prematurely. Otherwise, resources may be deleted when their ops are
|
TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));
|
||||||
// deleted at the end of the GraphRunner::Run used by standalone::Dataset.
|
|
||||||
for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) {
|
|
||||||
for (const auto& op : kNodeNameSharingOps) {
|
|
||||||
if (node.op() == op) {
|
|
||||||
(*node.mutable_attr())["use_node_name_sharing"].set_b(true);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
#if defined(PLATFORM_GOOGLE)
|
#if defined(PLATFORM_GOOGLE)
|
||||||
VLOG_LINES(4, absl::StrCat("Registering dataset graph: ",
|
VLOG_LINES(4,
|
||||||
dataset_def.graph().DebugString()));
|
absl::StrCat("Registering dataset graph: ", graph->DebugString()));
|
||||||
#else
|
#else
|
||||||
VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString();
|
VLOG(4) << "Registering dataset graph: " << graph->DebugString();
|
||||||
#endif
|
#endif
|
||||||
std::shared_ptr<const Dataset> dataset;
|
std::shared_ptr<const Dataset> dataset;
|
||||||
Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
|
Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
|
||||||
|
@ -57,7 +57,6 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
|
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
|
||||||
StripDevicePlacement(graph_def.mutable_library());
|
|
||||||
|
|
||||||
DataServiceDispatcherClient client(address, protocol);
|
DataServiceDispatcherClient client(address, protocol);
|
||||||
int64 dataset_id;
|
int64 dataset_id;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user