From eaa1729a52952f2a541491a1ce2c34af7ab66fc8 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev <ezhulenev@google.com> Date: Wed, 18 Mar 2020 13:03:19 -0700 Subject: [PATCH] Force CPU placement for ops that has DT_VARIANT inputs with host-only underlying data type. Fix for #28007 PiperOrigin-RevId: 301650148 Change-Id: I47fa9c1b0b7a7d56c5a519095687f36651892644 --- tensorflow/core/BUILD | 1 + .../core/common_runtime/colocation_graph.cc | 80 +++++++++++++++++++ .../core/common_runtime/colocation_graph.h | 8 ++ tensorflow/core/graph/algorithm.cc | 27 ++++--- tensorflow/core/graph/algorithm.h | 11 ++- .../python/kernel_tests/list_ops_test.py | 30 ++++++- 6 files changed, 142 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 188988d92c4..8efada20e24 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2668,6 +2668,7 @@ tf_cuda_library( "@com_google_absl//absl/base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc index ab58902f415..41058ae208a 100644 --- a/tensorflow/core/common_runtime/colocation_graph.cc +++ b/tensorflow/core/common_runtime/colocation_graph.cc @@ -23,7 +23,9 @@ limitations under the License. #include <vector> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/function.h" @@ -39,6 +41,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_node_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -726,6 +729,82 @@ Status ColocationGraph::ColocateResourceAndRefEdges( return Status::OK(); } +namespace { +// Returns tensor list element data type, if the node is one of the ops that +// operate with TensorLists. Otherwise returns DT_INVALID. +DataType GetElementDataType(const Node& node) { + static absl::flat_hash_set<std::string>* tensor_list_ops = + new absl::flat_hash_set<std::string>( + {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList", + "TensorListSplit", "TensorListScatter", "TensorListScatterV2", + "TensorListScatterIntoExistingList", "TensorListPushBack", + "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack", + "TensorListConcat", "TensorListConcatV2", "TensorListGetItem", + "TensorListSetItem", "TensorListGather", "TensorListConcatLists"}); + + if (tensor_list_ops->contains(node.type_string())) { + DataType element_type; + if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) { + return element_type; + } + } + + return DT_INVALID; +} +} // namespace + +Status ColocationGraph::AddHostOnlyDataTypesConstraints() { + auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; }; + + auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool { + return entry.first == DEVICE_CPU; + }; + + for (Node* node : graph_.nodes()) { + // Skip nodes that do not have DT_VARIANT inputs. + if (absl::c_none_of(node->input_types(), is_variant)) continue; + + // Skip nodes that can't be placed on GPU anyway. + Member& root = members_[FindAndUpdateRoot(node->id())]; + if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) continue; + + // Stop DFS traversal when found the underlying data type of a variant. + absl::optional<bool> is_host_data_type; + + auto edge_filter = [&](const Edge& edge) -> bool { + return !is_host_data_type.has_value(); + }; + + auto enter = [&](Node* n) -> void { + DataType element_type = GetElementDataType(*n); + // To handle nested lists continue traversal after finding a TensorList + // operation that uses DT_VARIANT for element type. + if (element_type == DT_INVALID || element_type == DT_VARIANT) return; + is_host_data_type = DataTypeAlwaysOnHost(element_type); + }; + + ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr, + /*stable_comparator=*/nullptr, edge_filter); + + if (is_host_data_type.has_value() && *is_host_data_type) { + VLOG(2) << "Limit node possible devices to CPU only, because it has a " + "DT_VARIANT input with host-only underlying data type: " + << "node=" << node->name(); + + // Restrict possible device types to CPU only. + PossibleDevices possible_devices; + absl::c_copy_if(root.supported_device_types(), + std::back_inserter(possible_devices.device_types), + is_cpu_device); + + TF_RETURN_IF_ERROR(root.LimitToPossibleDevices( + possible_devices, /*allow_soft_placement=*/false)); + } + } + + return Status::OK(); +} + Status ColocationGraph::AddInspectionConstraints( const std::unordered_set<Node*>& inspection_required) { for (Node* node : inspection_required) { @@ -744,6 +823,7 @@ Status ColocationGraph::Initialize() { std::unordered_set<Node*> inspection_required; TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required)); + TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints()); TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required)); TF_RETURN_IF_ERROR(ColocateAllNodes()); diff --git a/tensorflow/core/common_runtime/colocation_graph.h b/tensorflow/core/common_runtime/colocation_graph.h index 65fddf931ef..d0714d54a5a 100644 --- a/tensorflow/core/common_runtime/colocation_graph.h +++ b/tensorflow/core/common_runtime/colocation_graph.h @@ -283,6 +283,14 @@ class ColocationGraph { Status ColocateResourceAndRefEdges( std::unordered_set<Node*>* inspection_required); + // Updates this ColocationGraph by making sure that all nodes having inputs of + // a DT_VARIANT data type with a host-only underlying types (e.g. strings) can + // be placed only on CPU device. We do that by reverse-DFS traversal from all + // nodes that take variant inputs to the node that produces that variant. + // TODO(ezhulenev): This function does not yet support "deep op" inspection, + // that we have for DT_RESOURCE edges. + Status AddHostOnlyDataTypesConstraints(); + Status AddInspectionConstraints( const std::unordered_set<Node*>& inspection_required); diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc index 5524ab53c5a..f80822d5b00 100644 --- a/tensorflow/core/graph/algorithm.cc +++ b/tensorflow/core/graph/algorithm.cc @@ -112,8 +112,10 @@ void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator) { - ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator, + edge_filter); } namespace { @@ -122,7 +124,8 @@ template <typename T> void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, const std::function<void(T)>& enter, const std::function<void(T)>& leave, - const NodeComparator& stable_comparator) { + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { // Stack of work to do. struct Work { T node; @@ -161,7 +164,9 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, if (stable_comparator) { std::vector<T> nodes_sorted; for (const Edge* in_edge : n->in_edges()) { - nodes_sorted.emplace_back(in_edge->src()); + if (!edge_filter || edge_filter(*in_edge)) { + nodes_sorted.emplace_back(in_edge->src()); + } } std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); for (T in : nodes_sorted) { @@ -169,7 +174,9 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, } } else { for (const Edge* in_edge : n->in_edges()) { - add_work(in_edge->src()); + if (!edge_filter || edge_filter(*in_edge)) { + add_work(in_edge->src()); + } } } } @@ -180,15 +187,17 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, const std::function<void(const Node*)>& enter, const std::function<void(const Node*)>& leave, - const NodeComparator& stable_comparator) { - ReverseDFSFromHelper(g, start, enter, leave, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter); } void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator) { - ReverseDFSFromHelper(g, start, enter, leave, stable_comparator); + const NodeComparator& stable_comparator, + const EdgeFilter& edge_filter) { + ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter); } void GetPostOrder(const Graph& g, std::vector<Node*>* order, diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h index 8774a67a91e..9a9595a86d6 100644 --- a/tensorflow/core/graph/algorithm.h +++ b/tensorflow/core/graph/algorithm.h @@ -77,23 +77,28 @@ extern void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, // If leave is not empty, calls leave(n) after visiting all parents of n. // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. extern void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Perform a reverse depth-first-search on g starting at the 'start' nodes. // If enter is not empty, calls enter(n) before visiting any parents of n. // If leave is not empty, calls leave(n) after visiting all parents of n. // If stable_comparator is set, a stable ordering of visit is achieved by // sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, const std::function<void(Node*)>& enter, const std::function<void(Node*)>& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, const std::function<void(const Node*)>& enter, const std::function<void(const Node*)>& leave, - const NodeComparator& stable_comparator = {}); + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); // Stores in *order the post-order numbering of all nodes // in graph found via a depth first search starting at the source node. diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 11f882b5bf3..e618e21ed9d 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -1632,14 +1633,37 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllEqual(t, [1.0, 2.0, 3.0]) def testTensorListStrings(self): - self.skipTest("b/150742232") - @def_function.function def f(): return map_fn.map_fn(string_ops.string_upper, constant_op.constant(["a", "b", "c"])) - self.assertAllEqual(f(), ["A", "B", "C"]) + self.assertAllEqual(f(), [b"A", b"B", b"C"]) + + def testTensorListStringsNoInline(self): + # Generator function output type is a variant with a host-only underlying + # data type. "ColocationGraph::AddHostOnlyDataTypesConstraints" needs to + # have "deep op inspection" to be able to correctly place the while loop + # generated from map_fn. + self.skipTest("b/150742232") + + @function.defun_with_attributes(attributes={"_noinline": True}) + def generator(): + c = constant_op.constant(["a", "b", "c"]) + return list_ops.tensor_list_from_tensor(c, element_shape=[]) + + @def_function.function + def f(): + l = generator() + + def upper(i): + e = list_ops.tensor_list_get_item(l, i, element_dtype=dtypes.string) + return string_ops.string_upper(e) + + return map_fn.map_fn( + upper, constant_op.constant([0, 1, 2]), dtype=dtypes.string) + + self.assertAllEqual(f(), [b"A", b"B", b"C"]) if __name__ == "__main__":