Rollback changelist 338246477

Roll back changelist 338246477 because it sometimes results in ops being placed on the local host when they should not be. In certain cases, the TPU worker will then try to retrieve those ops from the local host and fail.
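For intuition, here is a minimal sketch (TF1-style graph API; the worker device string is hypothetical, and this only illustrates the placement property at stake, not the C++ lowering pass itself): ops built under an explicit device scope should keep that requested device, whereas the rolled-back change pinned the lowered pivot Switch to the local "/CPU:0" regardless, leaving a remote (e.g. TPU) worker unable to find it.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.disable_control_flow_v2()  # use v1 lowering so tf.cond emits Switch/Merge

g = tf.Graph()
with g.as_default():
  # Hypothetical remote worker device; placement is only requested during
  # graph construction, not resolved, so this runs without a cluster.
  with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
    x = tf.placeholder(tf.float32, name="x")
    r = tf.cond(x > 0.0, lambda: x * 2.0, lambda: x * 4.0)

for op in g.get_operations():
  # Every op, including the generated Switch/Merge nodes, should report the
  # requested worker device rather than the local "/CPU:0".
  print(op.type, op.name, op.device)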

PiperOrigin-RevId: 346564821
Change-Id: I22eb50264825a0e6076b9b682568fef3b6b5c669
Author: Marissa Ikonomidis (2020-12-09 09:11:00 -08:00), committed by TensorFlower Gardener
parent: 070b02f441
commit: 7d7d7ec173
5 changed files with 3 additions and 132 deletions

@@ -2513,13 +2513,10 @@ tf_cc_test(
     ],
 )
-tf_cuda_cc_test(
+tf_cc_test(
     name = "lower_if_op_test",
     size = "small",
     srcs = ["lower_if_op_test.cc"],
-    tags = tf_cuda_tests_tags() + [
-        "no_cuda_asan",  # TODO(b/171575050): re-enable once fixed.
-    ],
    deps = [
        ":core_cpu",
        ":core_cpu_internal",

@@ -148,22 +148,13 @@ Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
 Status CondBuilder::CreatePivotNodes() {
   // Construct the basic cond body (consisting of feeding in the predicate to
   // create pivot nodes).
-  // This is a special pivot switch node for lowering. We mark this with a
-  // special _PivotSwitch attr on it as later on in the graph partitioner we
-  // do some special placement for Switch nodes and its necessary to distinguish
-  // between a "normal" Switch node and one of these pivot switches. We would
-  // like to place this node on the CPU always as the pred_ will be on the CPU
-  // as well (either a CPU op output or a GPU op with HostMemory annotation).
-  // TODO(b/171321391): Fix this for NUMA cases.
   Node* switch_pred;
   TF_RETURN_IF_ERROR(
       SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
                                            graph_->op_registry(), &debug_info_)
                                    .Input(NodeOut(pred_))
                                    .Input(NodeOut(pred_))
-                                   .Attr("_PivotSwitch", true)
-                                   .Device("/CPU:0"),
+                                   .Device(if_op_->requested_device()),
                                graph_, &switch_pred));
   control_predecessor_ = switch_pred;
   TF_RETURN_IF_ERROR(

@@ -147,115 +147,6 @@ TEST(LowerIfOpTest, Simple) {
   }
 }
 
-TEST(LowerIfOpTest, GPUPlacement) {
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-
-  // Add test functions for then and else branch.
-  FunctionDefLibrary f_lib_proto;
-  *(f_lib_proto.add_function()) = test::function::XTimesTwo();
-  *(f_lib_proto.add_function()) = test::function::XTimesFour();
-
-  // Construct simple conditional that switches on `pred` and operates only on
-  // single input `A`.
-  Scope root = Scope::NewRootScope().ExitOnError();
-  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
-  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
-  auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32);
-  auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32);
-  Node* pred;
-  TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def())
-                   .Input(x.node())
-                   .Input(y.node())
-                   .Device("/GPU:0")
-                   .Finalize(root.graph(), &pred));
-  Node* written_if;
-  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
-  TF_ASSERT_OK(
-      NodeBuilder("if", "If", &root.graph()->flib_def())
-          .Input(pred)
-          .Input(inputs)
-          .Attr("then_branch", FuncAttr("XTimesTwo"))
-          .Attr("else_branch", FuncAttr("XTimesFour"))
-          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
-          .Attr("Tout", {DT_INT32})
-          .Device("/GPU:0")
-          .Finalize(root.graph(), &written_if));
-  TF_ASSERT_OK(root.DoShapeInference(written_if));
-  TF_ASSERT_OK(root.ToGraph(graph.get()));
-
-  // The input graph has no switch or merge nodes.
-  int node_called_if_count = 0;
-  for (const auto* op : graph->op_nodes()) {
-    ASSERT_FALSE(op->IsSwitch());
-    ASSERT_FALSE(op->IsMerge());
-    if (op->name() == "if") {
-      ++node_called_if_count;
-    }
-  }
-  ASSERT_EQ(node_called_if_count, 1);
-
-  TF_ASSERT_OK(Rewrite(&graph));
-
-  // Verify the resultant graph has switch and merge nodes, and a node called
-  // `if` (but not If nodes).
-  int switch_count = 0;
-  int merge_count = 0;
-  node_called_if_count = 0;
-  for (const auto* op : graph->op_nodes()) {
-    if (op->IsSwitch()) {
-      ++switch_count;
-    }
-    if (op->IsMerge()) {
-      ++merge_count;
-    }
-    ASSERT_NE(op->type_string(), "If");
-    if (op->name() == "if") {
-      ++node_called_if_count;
-    }
-  }
-
-  // One switch for predicate and one for input (A).
-  ASSERT_EQ(switch_count, 2);
-  // One merge for the single output value of then and else, and one more merge
-  // to enforce then and else function call execution (`branch_executed` node).
-  ASSERT_EQ(merge_count, 2);
-  ASSERT_EQ(node_called_if_count, 1);
-
-  // Verify execution.
-  ClientSession session(root, SessionOptionsWithInlining());
-  {
-    RunMetadata metadata;
-    RunOptions options;
-    options.set_output_partition_graphs(true);
-    ClientSession::FeedType feeds;
-    feeds.emplace(Output(x.node()), Input::Initializer(5));
-    feeds.emplace(Output(y.node()), Input::Initializer(10));
-    feeds.emplace(Output(a.node()), Input::Initializer(10));
-    std::vector<Tensor> out_tensors;
-    TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {},
-                             &out_tensors, &metadata));
-    GraphDef cpu_graph = metadata.partition_graphs(1);
-    int num_cpu_switch = 0;
-    for (const auto& node : cpu_graph.node()) {
-      if (node.op() == "Switch") {
-        ++num_cpu_switch;
-      }
-    }
-    EXPECT_EQ(num_cpu_switch, 2);
-    EXPECT_EQ(out_tensors.size(), 1);
-    EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
-  }
-  {
-    ClientSession::FeedType feeds;
-    feeds.emplace(Output(x.node()), Input::Initializer(10));
-    feeds.emplace(Output(y.node()), Input::Initializer(5));
-    feeds.emplace(Output(a.node()), Input::Initializer(10));
-    std::vector<Tensor> out_tensors;
-    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
-    EXPECT_EQ(out_tensors.size(), 1);
-    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
-  }
-}
-
 TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
   using ::tensorflow::test::function::GDef;
   using ::tensorflow::test::function::NDef;
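The deleted test above verifies placement by counting Switch nodes in the partition graphs returned through RunMetadata. A rough Python analogue of that verification pattern (a toy local graph, not the GPU setup from the deleted test) might look like:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.disable_control_flow_v2()  # v1 control flow emits Switch/Merge nodes

with tf.Graph().as_default():
  x = tf.placeholder(tf.float32, name="x")
  r = tf.cond(x > 0.0, lambda: x * 2.0, lambda: x * 4.0)
  with tf.Session() as sess:
    options = tf.RunOptions(output_partition_graphs=True)
    metadata = tf.RunMetadata()
    sess.run(r, feed_dict={x: 5.0}, options=options, run_metadata=metadata)
    # Count Switch nodes per partitioned device graph, as the C++ test did.
    for graph_def in metadata.partition_graphs:
      device = graph_def.node[0].device if graph_def.node else "<empty>"
      switches = sum(1 for n in graph_def.node if n.op == "Switch")
      print(device, switches)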

@@ -371,13 +371,6 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
 void OptimizeControlFlowColocation(Graph* graph) {
   auto visit = [](Node* node) {
     if (IsSwitch(node)) {
-      // Pivot Switch nodes (which are also of type Switch) are already placed
-      // on the CPU and colocated with its inputs that are also already on the
-      // CPU (or might be placed on GPU but in host memory).
-      if (HasNodeAttr(node->def(), "_PivotSwitch")) {
-        DCHECK(node->requested_device().find("CPU") != string::npos);
-        return;
-      }
       for (const Edge* in_edge : node->in_edges()) {
         if (in_edge->dst_input() == 0) {
           // Colocate with the data input.
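For reference, the kind of constraint OptimizeControlFlowColocation applies to Switch nodes can also be expressed from Python; a small sketch (illustrative names only, not code from this change) of colocating a consumer with its data input:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

g = tf.Graph()
with g.as_default():
  data = tf.constant([1.0, 2.0], name="data")
  # Pin the consumer to wherever `data` is placed, mirroring how the
  # partitioner colocates a Switch with its data input (input 0).
  with tf.colocate_with(data.op):
    doubled = tf.multiply(data, 2.0, name="doubled")

print(doubled.op.colocation_groups())  # expect: [b'loc:@data']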

@@ -730,8 +730,6 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
         g for g in run_metadata.partition_graphs
         if device_str in g.node[0].device
     ]
-    if not device_graphs:
-      return 0
     self.assertLen(device_graphs, 1)
     switch_nodes = [
         n for n in device_graphs[0].node
@@ -761,6 +759,7 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
       options = config_pb2.RunOptions(output_partition_graphs=True)
       sess.run(
           r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
+      self.assertLen(run_metadata.partition_graphs, 2)
       # Check that the Switch for `arg` gets placed on CPU.
       self.assertEqual(
           self._count_matching_switch_nodes_on_device(run_metadata, "CPU",