Ensuring that the Switch op used as a pivot is always placed on the CPU. For this we set a private attribute _PivotSwitch while creating this op and then make sure that the device overwriting logic in GraphPartition isn't executed for this op.
Note: Had to fix up control_flow_ops_py_test so that we don't expect a GPU graph when we don't get one. The reason is that now since we already know the switch_pred is going to be placed on CPU, the placer ensures that its input is placed on the CPU as well and we end up saving a copy. This means there is no GPU graph when we partition. PiperOrigin-RevId: 338246477 Change-Id: I5641c9ae1b2d593a2996947bafe92b22cb63371d
This commit is contained in:
parent
be24f6daef
commit
673b993983
tensorflow
core
python/kernel_tests
@ -2522,10 +2522,11 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
tf_cc_test_gpu(
|
||||
name = "lower_if_op_test",
|
||||
size = "small",
|
||||
srcs = ["lower_if_op_test.cc"],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":core_cpu",
|
||||
":core_cpu_internal",
|
||||
|
@ -148,13 +148,22 @@ Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
|
||||
Status CondBuilder::CreatePivotNodes() {
|
||||
// Construct the basic cond body (consisting of feeding in the predicate to
|
||||
// create pivot nodes).
|
||||
|
||||
// This is a special pivot switch node for lowering. We mark this with a
|
||||
// special _PivotSwitch attr on it as later on in the graph partitioner we
|
||||
// do some special placement for Switch nodes and its necessary to distinguish
|
||||
// between a "normal" Switch node and one of these pivot switches. We would
|
||||
// like to place this node on the CPU always as the pred_ will be on the CPU
|
||||
// as well (either a CPU op output or a GPU op with HostMemory annotation).
|
||||
// TODO(b/171321391): Fix this for NUMA cases.
|
||||
Node* switch_pred;
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
|
||||
graph_->op_registry(), &debug_info_)
|
||||
.Input(NodeOut(pred_))
|
||||
.Input(NodeOut(pred_))
|
||||
.Device(if_op_->requested_device()),
|
||||
.Attr("_PivotSwitch", true)
|
||||
.Device("/CPU:0"),
|
||||
graph_, &switch_pred));
|
||||
control_predecessor_ = switch_pred;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -147,6 +147,115 @@ TEST(LowerIfOpTest, Simple) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LowerIfOpTest, GPUPlacement) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
|
||||
// Add test functions for then and else branch.
|
||||
FunctionDefLibrary f_lib_proto;
|
||||
*(f_lib_proto.add_function()) = test::function::XTimesTwo();
|
||||
*(f_lib_proto.add_function()) = test::function::XTimesFour();
|
||||
|
||||
// Construct simple conditional that switches on `pred` and operates only on
|
||||
// single input `A`.
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
|
||||
auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
|
||||
auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32);
|
||||
auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32);
|
||||
Node* pred;
|
||||
TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def())
|
||||
.Input(x.node())
|
||||
.Input(y.node())
|
||||
.Device("/GPU:0")
|
||||
.Finalize(root.graph(), &pred));
|
||||
Node* written_if;
|
||||
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
|
||||
TF_ASSERT_OK(
|
||||
NodeBuilder("if", "If", &root.graph()->flib_def())
|
||||
.Input(pred)
|
||||
.Input(inputs)
|
||||
.Attr("then_branch", FuncAttr("XTimesTwo"))
|
||||
.Attr("else_branch", FuncAttr("XTimesFour"))
|
||||
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
|
||||
.Attr("Tout", {DT_INT32})
|
||||
.Device("/GPU:0")
|
||||
.Finalize(root.graph(), &written_if));
|
||||
TF_ASSERT_OK(root.DoShapeInference(written_if));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
// The input graph has no switch or merge nodes.
|
||||
int node_called_if_count = 0;
|
||||
for (const auto* op : graph->op_nodes()) {
|
||||
ASSERT_FALSE(op->IsSwitch());
|
||||
ASSERT_FALSE(op->IsMerge());
|
||||
if (op->name() == "if") {
|
||||
++node_called_if_count;
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(node_called_if_count, 1);
|
||||
|
||||
TF_ASSERT_OK(Rewrite(&graph));
|
||||
|
||||
// Verify the resultant graph has switch and merge nodes, and a node called
|
||||
// `if` (but not If nodes).
|
||||
int switch_count = 0;
|
||||
int merge_count = 0;
|
||||
node_called_if_count = 0;
|
||||
for (const auto* op : graph->op_nodes()) {
|
||||
if (op->IsSwitch()) {
|
||||
++switch_count;
|
||||
}
|
||||
if (op->IsMerge()) {
|
||||
++merge_count;
|
||||
}
|
||||
ASSERT_NE(op->type_string(), "If");
|
||||
if (op->name() == "if") {
|
||||
++node_called_if_count;
|
||||
}
|
||||
}
|
||||
// One switch for predicate and one for input (A).
|
||||
ASSERT_EQ(switch_count, 2);
|
||||
// One merge for the single output value of then and else, and one more merge
|
||||
// to enforce then and else function call execution (`branch_executed` node).
|
||||
ASSERT_EQ(merge_count, 2);
|
||||
ASSERT_EQ(node_called_if_count, 1);
|
||||
|
||||
// Verify execution.
|
||||
ClientSession session(root, SessionOptionsWithInlining());
|
||||
{
|
||||
RunMetadata metadata;
|
||||
RunOptions options;
|
||||
options.set_output_partition_graphs(true);
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(x.node()), Input::Initializer(5));
|
||||
feeds.emplace(Output(y.node()), Input::Initializer(10));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {},
|
||||
&out_tensors, &metadata));
|
||||
GraphDef cpu_graph = metadata.partition_graphs(1);
|
||||
int num_cpu_switch = 0;
|
||||
for (const auto& node : cpu_graph.node()) {
|
||||
if (node.op() == "Switch") {
|
||||
++num_cpu_switch;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(num_cpu_switch, 2);
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
|
||||
}
|
||||
{
|
||||
ClientSession::FeedType feeds;
|
||||
feeds.emplace(Output(x.node()), Input::Initializer(10));
|
||||
feeds.emplace(Output(y.node()), Input::Initializer(5));
|
||||
feeds.emplace(Output(a.node()), Input::Initializer(10));
|
||||
std::vector<Tensor> out_tensors;
|
||||
TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
|
||||
EXPECT_EQ(out_tensors.size(), 1);
|
||||
EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
|
||||
using ::tensorflow::test::function::GDef;
|
||||
using ::tensorflow::test::function::NDef;
|
||||
|
@ -371,6 +371,13 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
|
||||
void OptimizeControlFlowColocation(Graph* graph) {
|
||||
auto visit = [](Node* node) {
|
||||
if (IsSwitch(node)) {
|
||||
// Pivot Switch nodes (which are also of type Switch) are already placed
|
||||
// on the CPU and colocated with its inputs that are also already on the
|
||||
// CPU (or might be placed on GPU but in host memory).
|
||||
if (HasNodeAttr(node->def(), "_PivotSwitch")) {
|
||||
DCHECK(node->requested_device().find("CPU") != string::npos);
|
||||
return;
|
||||
}
|
||||
for (const Edge* in_edge : node->in_edges()) {
|
||||
if (in_edge->dst_input() == 0) {
|
||||
// Colocate with the data input.
|
||||
|
@ -730,6 +730,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
||||
g for g in run_metadata.partition_graphs
|
||||
if device_str in g.node[0].device
|
||||
]
|
||||
if not device_graphs:
|
||||
return 0
|
||||
self.assertLen(device_graphs, 1)
|
||||
switch_nodes = [
|
||||
n for n in device_graphs[0].node
|
||||
@ -759,7 +761,6 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
||||
options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
sess.run(
|
||||
r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
|
||||
self.assertLen(run_metadata.partition_graphs, 2)
|
||||
# Check that the Switch for `arg` gets placed on CPU.
|
||||
self.assertEqual(
|
||||
self._count_matching_switch_nodes_on_device(run_metadata, "CPU",
|
||||
|
Loading…
Reference in New Issue
Block a user