diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index 20049bf51f7..5166d97e75f 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -40,6 +40,10 @@ namespace data { /* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef; /* static */ constexpr const char* const DatasetFromGraphOp::kHandle; +namespace { +constexpr char kPyFunc[] = "PyFunc"; +} // namespace + // See documentation in ../../ops/dataset_ops.cc for a high-level // description of the following op. DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx) @@ -89,7 +93,9 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) { auto library = graph_def.mutable_library(); for (auto& function : (*library->mutable_function())) { for (auto& node : (*function.mutable_node_def())) { - if (!node.device().empty()) { + // We do not strip the device assignment from `PyFunc` ops because they + // need to be pinned to a host that is known to have Python interpreter. + if (!node.device().empty() && node.op() != kPyFunc) { *node.mutable_device() = DeviceNameUtils::LocalName(node.device()); } }