[tf.data] Updating logic for stripping device during input pipeline graph serialization to exclude `PyFunc` ops as this ops need to be pinned to hosts that are known to have a Python interpreter.
PiperOrigin-RevId: 306755919 Change-Id: I129828d947e3700b93eadf9491ac642737bb7da0
This commit is contained in:
parent
dba59540c7
commit
85e23ada07
|
@ -40,6 +40,10 @@ namespace data {
|
|||
/* static */ constexpr const char* const DatasetFromGraphOp::kGraphDef;
|
||||
/* static */ constexpr const char* const DatasetFromGraphOp::kHandle;
|
||||
|
||||
namespace {
|
||||
constexpr char kPyFunc[] = "PyFunc";
|
||||
} // namespace
|
||||
|
||||
// See documentation in ../../ops/dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
|
||||
|
@ -89,7 +93,9 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
|
|||
auto library = graph_def.mutable_library();
|
||||
for (auto& function : (*library->mutable_function())) {
|
||||
for (auto& node : (*function.mutable_node_def())) {
|
||||
if (!node.device().empty()) {
|
||||
// We do not strip the device assignment from `PyFunc` ops because they
|
||||
// need to be pinned to a host that is known to have Python interpreter.
|
||||
if (!node.device().empty() && node.op() != kPyFunc) {
|
||||
*node.mutable_device() = DeviceNameUtils::LocalName(node.device());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue