From 85e23ada07fe8a04eab0f946954f41601101d889 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Wed, 15 Apr 2020 17:48:59 -0700 Subject: [PATCH] [tf.data] Updating logic for stripping device during input pipeline graph serialization to exclude `PyFunc` ops as this ops need to be pinned to hosts that are known to have a Python interpreter. PiperOrigin-RevId: 306755919 Change-Id: I129828d947e3700b93eadf9491ac642737bb7da0 --- tensorflow/core/kernels/data/dataset_ops.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index 20049bf51f7..5166d97e75f 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -40,6 +40,10 @@ namespace data { /* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef; /* static */ constexpr const char* const DatasetFromGraphOp::kHandle; +namespace { +constexpr char kPyFunc[] = "PyFunc"; +} // namespace + // See documentation in ../../ops/dataset_ops.cc for a high-level // description of the following op. DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx) @@ -89,7 +93,9 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) { auto library = graph_def.mutable_library(); for (auto& function : (*library->mutable_function())) { for (auto& node : (*function.mutable_node_def())) { - if (!node.device().empty()) { + // We do not strip the device assignment from `PyFunc` ops because they + // need to be pinned to a host that is known to have Python interpreter. + if (!node.device().empty() && node.op() != kPyFunc) { *node.mutable_device() = DeviceNameUtils::LocalName(node.device()); } }