Ensuring that the Switch op used as a pivot is always placed on the CPU. For this we set a private attribute _PivotSwitch while creating this op and then make sure that the device overwriting logic in GraphPartition isn't executed for this op.
Note: Had to fix up control_flow_ops_py_test so that we don't expect a GPU graph when we don't get one. The reason is that now since we already know the switch_pred is going to be placed on CPU, the placer ensures that its input is placed on the CPU as well and we end up saving a copy. This means there is no GPU graph when we partition. PiperOrigin-RevId: 338246477 Change-Id: I5641c9ae1b2d593a2996947bafe92b22cb63371d
This commit is contained in:
parent
be24f6daef
commit
673b993983
@ -2522,10 +2522,11 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test_gpu(
|
||||||
name = "lower_if_op_test",
|
name = "lower_if_op_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["lower_if_op_test.cc"],
|
srcs = ["lower_if_op_test.cc"],
|
||||||
|
tags = tf_cuda_tests_tags(),
|
||||||
deps = [
|
deps = [
|
||||||
":core_cpu",
|
":core_cpu",
|
||||||
":core_cpu_internal",
|
":core_cpu_internal",
|
||||||
|
@ -148,13 +148,22 @@ Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
|
|||||||
Status CondBuilder::CreatePivotNodes() {
|
Status CondBuilder::CreatePivotNodes() {
|
||||||
// Construct the basic cond body (consisting of feeding in the predicate to
|
// Construct the basic cond body (consisting of feeding in the predicate to
|
||||||
// create pivot nodes).
|
// create pivot nodes).
|
||||||
|
|
||||||
|
// This is a special pivot switch node for lowering. We mark this with a
|
||||||
|
// special _PivotSwitch attr on it as later on in the graph partitioner we
|
||||||
|
// do some special placement for Switch nodes and its necessary to distinguish
|
||||||
|
// between a "normal" Switch node and one of these pivot switches. We would
|
||||||
|
// like to place this node on the CPU always as the pred_ will be on the CPU
|
||||||
|
// as well (either a CPU op output or a GPU op with HostMemory annotation).
|
||||||
|
// TODO(b/171321391): Fix this for NUMA cases.
|
||||||
Node* switch_pred;
|
Node* switch_pred;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
|
SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
|
||||||
graph_->op_registry(), &debug_info_)
|
graph_->op_registry(), &debug_info_)
|
||||||
.Input(NodeOut(pred_))
|
.Input(NodeOut(pred_))
|
||||||
.Input(NodeOut(pred_))
|
.Input(NodeOut(pred_))
|
||||||
.Device(if_op_->requested_device()),
|
.Attr("_PivotSwitch", true)
|
||||||
|
.Device("/CPU:0"),
|
||||||
graph_, &switch_pred));
|
graph_, &switch_pred));
|
||||||
control_predecessor_ = switch_pred;
|
control_predecessor_ = switch_pred;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
|
@ -147,6 +147,115 @@ TEST(LowerIfOpTest, Simple) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(LowerIfOpTest, GPUPlacement) {
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
|
||||||
|
// Add test functions for then and else branch.
|
||||||
|
FunctionDefLibrary f_lib_proto;
|
||||||
|
*(f_lib_proto.add_function()) = test::function::XTimesTwo();
|
||||||
|
*(f_lib_proto.add_function()) = test::function::XTimesFour();
|
||||||
|
|
||||||
|
// Construct simple conditional that switches on `pred` and operates only on
|
||||||
|
// single input `A`.
|
||||||
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
|
||||||
|
auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
|
||||||
|
auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32);
|
||||||
|
auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32);
|
||||||
|
Node* pred;
|
||||||
|
TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def())
|
||||||
|
.Input(x.node())
|
||||||
|
.Input(y.node())
|
||||||
|
.Device("/GPU:0")
|
||||||
|
.Finalize(root.graph(), &pred));
|
||||||
|
Node* written_if;
|
||||||
|
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
NodeBuilder("if", "If", &root.graph()->flib_def())
|
||||||
|
.Input(pred)
|
||||||
|
.Input(inputs)
|
||||||
|
.Attr("then_branch", FuncAttr("XTimesTwo"))
|
||||||
|
.Attr("else_branch", FuncAttr("XTimesFour"))
|
||||||
|
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
|
||||||
|
.Attr("Tout", {DT_INT32})
|
||||||
|
.Device("/GPU:0")
|
||||||
|
.Finalize(root.graph(), &written_if));
|
||||||
|
TF_ASSERT_OK(root.DoShapeInference(written_if));
|
||||||
|
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
// The input graph has no switch or merge nodes.
|
||||||
|
int node_called_if_count = 0;
|
||||||
|
for (const auto* op : graph->op_nodes()) {
|
||||||
|
ASSERT_FALSE(op->IsSwitch());
|
||||||
|
ASSERT_FALSE(op->IsMerge());
|
||||||
|
if (op->name() == "if") {
|
||||||
|
++node_called_if_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_EQ(node_called_if_count, 1);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(Rewrite(&graph));
|
||||||
|
|
||||||
|
// Verify the resultant graph has switch and merge nodes, and a node called
|
||||||
|
// `if` (but not If nodes).
|
||||||
|
int switch_count = 0;
|
||||||
|
int merge_count = 0;
|
||||||
|
node_called_if_count = 0;
|
||||||
|
for (const auto* op : graph->op_nodes()) {
|
||||||
|
if (op->IsSwitch()) {
|
||||||
|
++switch_count;
|
||||||
|
}
|
||||||
|
if (op->IsMerge()) {
|
||||||
|
++merge_count;
|
||||||
|
}
|
||||||
|
ASSERT_NE(op->type_string(), "If");
|
||||||
|
if (op->name() == "if") {
|
||||||
|
++node_called_if_count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// One switch for predicate and one for input (A).
|
||||||
|
ASSERT_EQ(switch_count, 2);
|
||||||
|
// One merge for the single output value of then and else, and one more merge
|
||||||
|
// to enforce then and else function call execution (`branch_executed` node).
|
||||||
|
ASSERT_EQ(merge_count, 2);
|
||||||
|
ASSERT_EQ(node_called_if_count, 1);
|
||||||
|
|
||||||
|
// Verify execution.
|
||||||
|
ClientSession session(root, SessionOptionsWithInlining());
|
||||||
|
{
|
||||||
|
RunMetadata metadata;
|
||||||
|
RunOptions options;
|
||||||
|
options.set_output_partition_graphs(true);
|
||||||
|
ClientSession::FeedType feeds;
|
||||||
|
feeds.emplace(Output(x.node()), Input::Initializer(5));
|
||||||
|
feeds.emplace(Output(y.node()), Input::Initializer(10));
|
||||||
|
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||||
|
std::vector<Tensor> out_tensors;
|
||||||
|
TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {},
|
||||||
|
&out_tensors, &metadata));
|
||||||
|
GraphDef cpu_graph = metadata.partition_graphs(1);
|
||||||
|
int num_cpu_switch = 0;
|
||||||
|
for (const auto& node : cpu_graph.node()) {
|
||||||
|
if (node.op() == "Switch") {
|
||||||
|
++num_cpu_switch;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_EQ(num_cpu_switch, 2);
|
||||||
|
EXPECT_EQ(out_tensors.size(), 1);
|
||||||
|
EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ClientSession::FeedType feeds;
|
||||||
|
feeds.emplace(Output(x.node()), Input::Initializer(10));
|
||||||
|
feeds.emplace(Output(y.node()), Input::Initializer(5));
|
||||||
|
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||||
|
std::vector<Tensor> out_tensors;
|
||||||
|
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
|
||||||
|
EXPECT_EQ(out_tensors.size(), 1);
|
||||||
|
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
|
TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
|
||||||
using ::tensorflow::test::function::GDef;
|
using ::tensorflow::test::function::GDef;
|
||||||
using ::tensorflow::test::function::NDef;
|
using ::tensorflow::test::function::NDef;
|
||||||
|
@ -371,6 +371,13 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
|
|||||||
void OptimizeControlFlowColocation(Graph* graph) {
|
void OptimizeControlFlowColocation(Graph* graph) {
|
||||||
auto visit = [](Node* node) {
|
auto visit = [](Node* node) {
|
||||||
if (IsSwitch(node)) {
|
if (IsSwitch(node)) {
|
||||||
|
// Pivot Switch nodes (which are also of type Switch) are already placed
|
||||||
|
// on the CPU and colocated with its inputs that are also already on the
|
||||||
|
// CPU (or might be placed on GPU but in host memory).
|
||||||
|
if (HasNodeAttr(node->def(), "_PivotSwitch")) {
|
||||||
|
DCHECK(node->requested_device().find("CPU") != string::npos);
|
||||||
|
return;
|
||||||
|
}
|
||||||
for (const Edge* in_edge : node->in_edges()) {
|
for (const Edge* in_edge : node->in_edges()) {
|
||||||
if (in_edge->dst_input() == 0) {
|
if (in_edge->dst_input() == 0) {
|
||||||
// Colocate with the data input.
|
// Colocate with the data input.
|
||||||
|
@ -730,6 +730,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
g for g in run_metadata.partition_graphs
|
g for g in run_metadata.partition_graphs
|
||||||
if device_str in g.node[0].device
|
if device_str in g.node[0].device
|
||||||
]
|
]
|
||||||
|
if not device_graphs:
|
||||||
|
return 0
|
||||||
self.assertLen(device_graphs, 1)
|
self.assertLen(device_graphs, 1)
|
||||||
switch_nodes = [
|
switch_nodes = [
|
||||||
n for n in device_graphs[0].node
|
n for n in device_graphs[0].node
|
||||||
@ -759,7 +761,6 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
options = config_pb2.RunOptions(output_partition_graphs=True)
|
options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||||
sess.run(
|
sess.run(
|
||||||
r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
|
r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
|
||||||
self.assertLen(run_metadata.partition_graphs, 2)
|
|
||||||
# Check that the Switch for `arg` gets placed on CPU.
|
# Check that the Switch for `arg` gets placed on CPU.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
|
self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user