Force CPU placement for ops that have DT_VARIANT inputs with a host-only underlying data type.

Fix for #28007

PiperOrigin-RevId: 301650148
Change-Id: I47fa9c1b0b7a7d56c5a519095687f36651892644
This commit is contained in:
Eugene Zhulenev 2020-03-18 13:03:19 -07:00 committed by TensorFlower Gardener
parent 09cf2e6d20
commit eaa1729a52
6 changed files with 142 additions and 15 deletions

View File

@ -2668,6 +2668,7 @@ tf_cuda_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",

View File

@ -23,7 +23,9 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
@ -39,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@ -726,6 +729,82 @@ Status ColocationGraph::ColocateResourceAndRefEdges(
return Status::OK();
}
namespace {

// If `node` is one of the ops that operate on TensorLists, returns the data
// type of the list elements (its "element_dtype" attribute). For any other
// node, or if the attribute is missing, returns DT_INVALID.
DataType GetElementDataType(const Node& node) {
  // Ops that produce or consume TensorList (variant) values. Heap-allocated
  // and intentionally never deleted (lives for the process lifetime).
  static const absl::flat_hash_set<std::string>* const kTensorListOps =
      new absl::flat_hash_set<std::string>(
          {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList",
           "TensorListSplit", "TensorListScatter", "TensorListScatterV2",
           "TensorListScatterIntoExistingList", "TensorListPushBack",
           "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack",
           "TensorListConcat", "TensorListConcatV2", "TensorListGetItem",
           "TensorListSetItem", "TensorListGather", "TensorListConcatLists"});

  if (!kTensorListOps->contains(node.type_string())) return DT_INVALID;

  DataType element_type;
  if (!GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) {
    return DT_INVALID;
  }
  return element_type;
}

}  // namespace
// Restricts to CPU every node that consumes a DT_VARIANT value whose
// underlying element type is host-only (e.g. strings). For each node with a
// DT_VARIANT input, walks the graph backwards (reverse DFS) to the TensorList
// op that produced the variant and inspects its "element_dtype" attribute.
Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
// Predicate: data type is DT_VARIANT.
auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };
// Predicate: supported-device entry is a CPU device.
auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
return entry.first == DEVICE_CPU;
};
for (Node* node : graph_.nodes()) {
// Skip nodes that do not have DT_VARIANT inputs.
if (absl::c_none_of(node->input_types(), is_variant)) continue;
// Skip nodes that can't be placed on GPU anyway.
// `root` is the representative member of the node's colocation group; the
// device restriction below is applied to the root.
Member& root = members_[FindAndUpdateRoot(node->id())];
if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) continue;
// Stop DFS traversal when found the underlying data type of a variant.
// Unset until some visited node reveals a concrete element type.
absl::optional<bool> is_host_data_type;
// Once the answer is known, prune all remaining edges to end the traversal.
auto edge_filter = [&](const Edge& edge) -> bool {
return !is_host_data_type.has_value();
};
auto enter = [&](Node* n) -> void {
DataType element_type = GetElementDataType(*n);
// To handle nested lists continue traversal after finding a TensorList
// operation that uses DT_VARIANT for element type.
if (element_type == DT_INVALID || element_type == DT_VARIANT) return;
is_host_data_type = DataTypeAlwaysOnHost(element_type);
};
// Walk backwards from `node` toward the producers of its variant inputs.
ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
/*stable_comparator=*/nullptr, edge_filter);
if (is_host_data_type.has_value() && *is_host_data_type) {
VLOG(2) << "Limit node possible devices to CPU only, because it has a "
"DT_VARIANT input with host-only underlying data type: "
<< "node=" << node->name();
// Restrict possible device types to CPU only.
PossibleDevices possible_devices;
absl::c_copy_if(root.supported_device_types(),
std::back_inserter(possible_devices.device_types),
is_cpu_device);
TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
possible_devices, /*allow_soft_placement=*/false));
}
}
return Status::OK();
}
Status ColocationGraph::AddInspectionConstraints(
const std::unordered_set<Node*>& inspection_required) {
for (Node* node : inspection_required) {
@ -744,6 +823,7 @@ Status ColocationGraph::Initialize() {
std::unordered_set<Node*> inspection_required;
TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints());
TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
TF_RETURN_IF_ERROR(ColocateAllNodes());

View File

@ -283,6 +283,14 @@ class ColocationGraph {
Status ColocateResourceAndRefEdges(
std::unordered_set<Node*>* inspection_required);
// Updates this ColocationGraph by making sure that all nodes having inputs of
// a DT_VARIANT data type with a host-only underlying type (e.g. strings) can
// be placed only on CPU device. We do that by reverse-DFS traversal from all
// nodes that take variant inputs to the node that produces that variant.
// TODO(ezhulenev): This function does not yet support "deep op" inspection
// that we have for DT_RESOURCE edges.
Status AddHostOnlyDataTypesConstraints();
Status AddInspectionConstraints(
const std::unordered_set<Node*>& inspection_required);

View File

@ -112,8 +112,10 @@ void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave,
const NodeComparator& stable_comparator) {
ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator);
const NodeComparator& stable_comparator,
const EdgeFilter& edge_filter) {
ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator,
edge_filter);
}
namespace {
@ -122,7 +124,8 @@ template <typename T>
void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
const std::function<void(T)>& enter,
const std::function<void(T)>& leave,
const NodeComparator& stable_comparator) {
const NodeComparator& stable_comparator,
const EdgeFilter& edge_filter) {
// Stack of work to do.
struct Work {
T node;
@ -161,7 +164,9 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
if (stable_comparator) {
std::vector<T> nodes_sorted;
for (const Edge* in_edge : n->in_edges()) {
nodes_sorted.emplace_back(in_edge->src());
if (!edge_filter || edge_filter(*in_edge)) {
nodes_sorted.emplace_back(in_edge->src());
}
}
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
for (T in : nodes_sorted) {
@ -169,7 +174,9 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
}
} else {
for (const Edge* in_edge : n->in_edges()) {
add_work(in_edge->src());
if (!edge_filter || edge_filter(*in_edge)) {
add_work(in_edge->src());
}
}
}
}
@ -180,15 +187,17 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
const std::function<void(const Node*)>& enter,
const std::function<void(const Node*)>& leave,
const NodeComparator& stable_comparator) {
ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
const NodeComparator& stable_comparator,
const EdgeFilter& edge_filter) {
ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
}
void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave,
const NodeComparator& stable_comparator) {
ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
const NodeComparator& stable_comparator,
const EdgeFilter& edge_filter) {
ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
}
void GetPostOrder(const Graph& g, std::vector<Node*>* order,

View File

@ -77,23 +77,28 @@ extern void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
// If leave is not empty, calls leave(n) after visiting all parents of n.
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
// If edge_filter is set then ignores edges for which edge_filter returns false.
// Declaration with optional edge_filter (defaults preserve old call sites).
// Stale pre-change `stable_comparator = {});` terminator line from the
// flattened diff removed.
extern void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                       const std::function<void(Node*)>& leave,
                       const NodeComparator& stable_comparator = {},
                       const EdgeFilter& edge_filter = {});
// Perform a reverse depth-first-search on g starting at the 'start' nodes.
// If enter is not empty, calls enter(n) before visiting any parents of n.
// If leave is not empty, calls leave(n) after visiting all parents of n.
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
// If edge_filter is set then ignores edges for which edge_filter returns false.
// Declaration with optional edge_filter (defaults preserve old call sites).
// Stale pre-change `stable_comparator = {});` terminator line from the
// flattened diff removed.
extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
                           const std::function<void(Node*)>& enter,
                           const std::function<void(Node*)>& leave,
                           const NodeComparator& stable_comparator = {},
                           const EdgeFilter& edge_filter = {});
// Const-Node overload; optional edge_filter with default keeps backward
// compatibility. Stale pre-change `stable_comparator = {});` terminator line
// from the flattened diff removed.
extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
                           const std::function<void(const Node*)>& enter,
                           const std::function<void(const Node*)>& leave,
                           const NodeComparator& stable_comparator = {},
                           const EdgeFilter& edge_filter = {});
// Stores in *order the post-order numbering of all nodes
// in graph found via a depth first search starting at the source node.

View File

@ -26,6 +26,7 @@ from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -1632,14 +1633,37 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllEqual(t, [1.0, 2.0, 3.0])
def testTensorListStrings(self):
  # NOTE(review): disabled pending b/150742232.
  self.skipTest("b/150742232")

  @def_function.function
  def f():
    return map_fn.map_fn(string_ops.string_upper,
                         constant_op.constant(["a", "b", "c"]))

  # The flattened diff retained both the old unicode-literal assertion and the
  # new byte-literal one; only the byte-literal assertion is correct here.
  self.assertAllEqual(f(), [b"A", b"B", b"C"])
def testTensorListStringsNoInline(self):
  # The non-inlined generator returns a variant whose underlying data type is
  # host-only. Placing the while loop that map_fn builds correctly would
  # require "deep op inspection" in
  # "ColocationGraph::AddHostOnlyDataTypesConstraints", which it does not do
  # yet.
  self.skipTest("b/150742232")

  @function.defun_with_attributes(attributes={"_noinline": True})
  def generator():
    strings = constant_op.constant(["a", "b", "c"])
    return list_ops.tensor_list_from_tensor(strings, element_shape=[])

  @def_function.function
  def f():
    tensor_list = generator()

    def to_upper(index):
      item = list_ops.tensor_list_get_item(
          tensor_list, index, element_dtype=dtypes.string)
      return string_ops.string_upper(item)

    return map_fn.map_fn(
        to_upper, constant_op.constant([0, 1, 2]), dtype=dtypes.string)

  self.assertAllEqual(f(), [b"A", b"B", b"C"])
if __name__ == "__main__":