From eaa1729a52952f2a541491a1ce2c34af7ab66fc8 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev <ezhulenev@google.com>
Date: Wed, 18 Mar 2020 13:03:19 -0700
Subject: [PATCH] Force CPU placement for ops that have DT_VARIANT inputs with
 host-only underlying data type.

Fix for #28007

PiperOrigin-RevId: 301650148
Change-Id: I47fa9c1b0b7a7d56c5a519095687f36651892644
---
 tensorflow/core/BUILD                         |  1 +
 .../core/common_runtime/colocation_graph.cc   | 80 +++++++++++++++++++
 .../core/common_runtime/colocation_graph.h    |  8 ++
 tensorflow/core/graph/algorithm.cc            | 27 ++++---
 tensorflow/core/graph/algorithm.h             | 11 ++-
 .../python/kernel_tests/list_ops_test.py      | 30 ++++++-
 6 files changed, 142 insertions(+), 15 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 188988d92c4..8efada20e24 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2668,6 +2668,7 @@ tf_cuda_library(
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc
index ab58902f415..41058ae208a 100644
--- a/tensorflow/core/common_runtime/colocation_graph.cc
+++ b/tensorflow/core/common_runtime/colocation_graph.cc
@@ -23,7 +23,9 @@ limitations under the License.
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -39,6 +41,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_node_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -726,6 +729,82 @@ Status ColocationGraph::ColocateResourceAndRefEdges(
   return Status::OK();
 }
 
+namespace {
+// Returns the tensor list element data type if the node is one of the ops that
+// operate on TensorLists. Otherwise returns DT_INVALID.
+DataType GetElementDataType(const Node& node) {
+  static absl::flat_hash_set<std::string>* tensor_list_ops =
+      new absl::flat_hash_set<std::string>(
+          {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList",
+           "TensorListSplit", "TensorListScatter", "TensorListScatterV2",
+           "TensorListScatterIntoExistingList", "TensorListPushBack",
+           "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack",
+           "TensorListConcat", "TensorListConcatV2", "TensorListGetItem",
+           "TensorListSetItem", "TensorListGather", "TensorListConcatLists"});
+
+  if (tensor_list_ops->contains(node.type_string())) {
+    DataType element_type;
+    if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) {
+      return element_type;
+    }
+  }
+
+  return DT_INVALID;
+}
+}  // namespace
+
+Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
+  auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };
+
+  auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
+    return entry.first == DEVICE_CPU;
+  };
+
+  for (Node* node : graph_.nodes()) {
+    // Skip nodes that do not have DT_VARIANT inputs.
+    if (absl::c_none_of(node->input_types(), is_variant)) continue;
+
+    // Skip nodes that can't be placed on GPU anyway.
+    Member& root = members_[FindAndUpdateRoot(node->id())];
+    if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) continue;
+
+    // Stop DFS traversal when the underlying data type of a variant is found.
+    absl::optional<bool> is_host_data_type;
+
+    auto edge_filter = [&](const Edge& edge) -> bool {
+      return !is_host_data_type.has_value();
+    };
+
+    auto enter = [&](Node* n) -> void {
+      DataType element_type = GetElementDataType(*n);
+      // To handle nested lists, continue traversal after finding a TensorList
+      // operation that uses DT_VARIANT for element type.
+      if (element_type == DT_INVALID || element_type == DT_VARIANT) return;
+      is_host_data_type = DataTypeAlwaysOnHost(element_type);
+    };
+
+    ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
+                   /*stable_comparator=*/nullptr, edge_filter);
+
+    if (is_host_data_type.has_value() && *is_host_data_type) {
+      VLOG(2) << "Limit node possible devices to CPU only, because it has a "
+                 "DT_VARIANT input with host-only underlying data type: "
+              << "node=" << node->name();
+
+      // Restrict possible device types to CPU only.
+      PossibleDevices possible_devices;
+      absl::c_copy_if(root.supported_device_types(),
+                      std::back_inserter(possible_devices.device_types),
+                      is_cpu_device);
+
+      TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
+          possible_devices, /*allow_soft_placement=*/false));
+    }
+  }
+
+  return Status::OK();
+}
+
 Status ColocationGraph::AddInspectionConstraints(
     const std::unordered_set<Node*>& inspection_required) {
   for (Node* node : inspection_required) {
@@ -744,6 +823,7 @@ Status ColocationGraph::Initialize() {
 
   std::unordered_set<Node*> inspection_required;
   TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
+  TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints());
   TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
   TF_RETURN_IF_ERROR(ColocateAllNodes());
 
diff --git a/tensorflow/core/common_runtime/colocation_graph.h b/tensorflow/core/common_runtime/colocation_graph.h
index 65fddf931ef..d0714d54a5a 100644
--- a/tensorflow/core/common_runtime/colocation_graph.h
+++ b/tensorflow/core/common_runtime/colocation_graph.h
@@ -283,6 +283,14 @@ class ColocationGraph {
   Status ColocateResourceAndRefEdges(
       std::unordered_set<Node*>* inspection_required);
 
+  // Updates this ColocationGraph by making sure that all nodes having inputs of
+  // a DT_VARIANT data type with host-only underlying types (e.g. strings) can
+  // be placed only on a CPU device. We do that by a reverse-DFS traversal from
+  // all nodes that take variant inputs to the node that produces that variant.
+  // TODO(ezhulenev): This function does not yet support the "deep op"
+  // inspection that we have for DT_RESOURCE edges.
+  Status AddHostOnlyDataTypesConstraints();
+
   Status AddInspectionConstraints(
       const std::unordered_set<Node*>& inspection_required);
 
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc
index 5524ab53c5a..f80822d5b00 100644
--- a/tensorflow/core/graph/algorithm.cc
+++ b/tensorflow/core/graph/algorithm.cc
@@ -112,8 +112,10 @@ void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
 
 void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                 const std::function<void(Node*)>& leave,
-                const NodeComparator& stable_comparator) {
-  ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator);
+                const NodeComparator& stable_comparator,
+                const EdgeFilter& edge_filter) {
+  ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator,
+                 edge_filter);
 }
 
 namespace {
@@ -122,7 +124,8 @@ template <typename T>
 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
                           const std::function<void(T)>& enter,
                           const std::function<void(T)>& leave,
-                          const NodeComparator& stable_comparator) {
+                          const NodeComparator& stable_comparator,
+                          const EdgeFilter& edge_filter) {
   // Stack of work to do.
   struct Work {
     T node;
@@ -161,7 +164,9 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
     if (stable_comparator) {
       std::vector<T> nodes_sorted;
       for (const Edge* in_edge : n->in_edges()) {
-        nodes_sorted.emplace_back(in_edge->src());
+        if (!edge_filter || edge_filter(*in_edge)) {
+          nodes_sorted.emplace_back(in_edge->src());
+        }
       }
       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
       for (T in : nodes_sorted) {
@@ -169,7 +174,9 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
       }
     } else {
       for (const Edge* in_edge : n->in_edges()) {
-        add_work(in_edge->src());
+        if (!edge_filter || edge_filter(*in_edge)) {
+          add_work(in_edge->src());
+        }
       }
     }
   }
@@ -180,15 +187,17 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
                     const std::function<void(const Node*)>& enter,
                     const std::function<void(const Node*)>& leave,
-                    const NodeComparator& stable_comparator) {
-  ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
+                    const NodeComparator& stable_comparator,
+                    const EdgeFilter& edge_filter) {
+  ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
 }
 
 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
                     const std::function<void(Node*)>& enter,
                     const std::function<void(Node*)>& leave,
-                    const NodeComparator& stable_comparator) {
-  ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
+                    const NodeComparator& stable_comparator,
+                    const EdgeFilter& edge_filter) {
+  ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
 }
 
 void GetPostOrder(const Graph& g, std::vector<Node*>* order,
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
index 8774a67a91e..9a9595a86d6 100644
--- a/tensorflow/core/graph/algorithm.h
+++ b/tensorflow/core/graph/algorithm.h
@@ -77,23 +77,28 @@ extern void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
 // If leave is not empty, calls leave(n) after visiting all parents of n.
 // If stable_comparator is set, a stable ordering of visit is achieved by
 // sorting a node's neighbors first before visiting them.
+// If edge_filter is set, ignores edges for which edge_filter returns false.
 extern void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                        const std::function<void(Node*)>& leave,
-                       const NodeComparator& stable_comparator = {});
+                       const NodeComparator& stable_comparator = {},
+                       const EdgeFilter& edge_filter = {});
 
 // Perform a reverse depth-first-search on g starting at the 'start' nodes.
 // If enter is not empty, calls enter(n) before visiting any parents of n.
 // If leave is not empty, calls leave(n) after visiting all parents of n.
 // If stable_comparator is set, a stable ordering of visit is achieved by
 // sorting a node's neighbors first before visiting them.
+// If edge_filter is set, ignores edges for which edge_filter returns false.
 extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
                            const std::function<void(Node*)>& enter,
                            const std::function<void(Node*)>& leave,
-                           const NodeComparator& stable_comparator = {});
+                           const NodeComparator& stable_comparator = {},
+                           const EdgeFilter& edge_filter = {});
 extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
                            const std::function<void(const Node*)>& enter,
                            const std::function<void(const Node*)>& leave,
-                           const NodeComparator& stable_comparator = {});
+                           const NodeComparator& stable_comparator = {},
+                           const EdgeFilter& edge_filter = {});
 
 // Stores in *order the post-order numbering of all nodes
 // in graph found via a depth first search starting at the source node.
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 11f882b5bf3..e618e21ed9d 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -1632,14 +1633,37 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       self.assertAllEqual(t, [1.0, 2.0, 3.0])
 
   def testTensorListStrings(self):
-    self.skipTest("b/150742232")
-
     @def_function.function
     def f():
       return map_fn.map_fn(string_ops.string_upper,
                            constant_op.constant(["a", "b", "c"]))
 
-    self.assertAllEqual(f(), ["A", "B", "C"])
+    self.assertAllEqual(f(), [b"A", b"B", b"C"])
+
+  def testTensorListStringsNoInline(self):
+    # Generator function output type is a variant with a host-only underlying
+    # data type. "ColocationGraph::AddHostOnlyDataTypesConstraints" needs to
+    # have "deep op inspection" to be able to correctly place the while loop
+    # generated from map_fn.
+    self.skipTest("b/150742232")
+
+    @function.defun_with_attributes(attributes={"_noinline": True})
+    def generator():
+      c = constant_op.constant(["a", "b", "c"])
+      return list_ops.tensor_list_from_tensor(c, element_shape=[])
+
+    @def_function.function
+    def f():
+      l = generator()
+
+      def upper(i):
+        e = list_ops.tensor_list_get_item(l, i, element_dtype=dtypes.string)
+        return string_ops.string_upper(e)
+
+      return map_fn.map_fn(
+          upper, constant_op.constant([0, 1, 2]), dtype=dtypes.string)
+
+    self.assertAllEqual(f(), [b"A", b"B", b"C"])
 
 
 if __name__ == "__main__":