[tf.data] Updating logic for stripping device during input pipeline graph serialization to exclude `PyFunc` ops as this ops need to be pinned to hosts that are known to have a Python interpreter.

PiperOrigin-RevId: 306755919
Change-Id: I129828d947e3700b93eadf9491ac642737bb7da0
This commit is contained in:
Jiri Simsa 2020-04-15 17:48:59 -07:00 committed by TensorFlower Gardener
parent dba59540c7
commit 85e23ada07
1 changed files with 7 additions and 1 deletions

View File

@ -40,6 +40,10 @@ namespace data {
/* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef;
/* static */ constexpr const char* const DatasetFromGraphOp::kHandle;
namespace {
constexpr char kPyFunc[] = "PyFunc";
} // namespace
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
@ -89,7 +93,9 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
auto library = graph_def.mutable_library();
for (auto& function : (*library->mutable_function())) {
for (auto& node : (*function.mutable_node_def())) {
if (!node.device().empty()) {
// We do not strip the device assignment from `PyFunc` ops because they
// need to be pinned to a host that is known to have Python interpreter.
if (!node.device().empty() && node.op() != kPyFunc) {
*node.mutable_device() = DeviceNameUtils::LocalName(node.device());
}
}