diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc index 33a7a631074..37a5ad5efaa 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -67,6 +67,7 @@ std::vector FilterSupportedDevices( return filtered_devices; } +// TODO(vrv): Remove "@" syntax capability. bool HasColocatedNodeName(const Node& node) { return StringPiece(node.def().device()).starts_with("@"); } @@ -83,6 +84,30 @@ Status ParseColocatedNodeName(const Node& node, return Status::OK(); } +// Stores in *colocation_group the name of the node's colocation group, read +// from the "_class" attribute of the NodeDef, or "" if that attribute cannot +// be read. Returns InvalidArgument if the attribute lacks the "loc:" prefix. +Status ColocationGroup(const Node& node, string* colocation_group) { + string class_spec; + // TODO(vrv): We should consider adding a GetNodeAttr that returns a + // StringPiece, to avoid a copy. + Status s = GetNodeAttr(node.def(), "_class", &class_spec); + if (!s.ok()) { + // No "_class" attribute is equivalent to the empty colocation_group. + *colocation_group = ""; + return Status::OK(); + } + + StringPiece spec(class_spec); + if (!spec.Consume("loc:")) { + return errors::InvalidArgument("Node had an invalid _class attribute: ", + class_spec); + } + + *colocation_group = spec.ToString(); + return Status::OK(); +} + // This class maintains the connected components of a colocation // constraint graph, and uses this information to assign a satisfying // device placement to the nodes of the graph. @@ -134,6 +159,24 @@ class ColocationGraph { CHECK_GE(member.parent, 0); members_.resize(member.parent + 1); members_[member.parent] = std::move(member); + + // When adding the node, identify whether it is part of a + // colocation group. + string colocation_group; + TF_RETURN_IF_ERROR(ColocationGroup(node, &colocation_group)); + if (!colocation_group.empty()) { + // Node has a colocation group specified. 
+ auto it = colocation_group_root_.find(colocation_group); + if (it == colocation_group_root_.end()) { + // This is the first node of the colocation group, so + // designate this node as the 'root' of that colocation group. + colocation_group_root_[colocation_group] = &node; + } else { + // Colocate this node with the root. + ColocateNodes(node, *(it->second)); + } + } + return Status::OK(); } @@ -447,6 +490,10 @@ class ColocationGraph { const DeviceSet* device_set_; // Not owned. const std::vector device_types_; const SessionOptions* options_; // Not owned; + + // Maps from a colocation group identifier to the 'root' of that + // colocation group. + std::unordered_map colocation_group_root_; }; } // namespace diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc index 14871c7a292..5dfbb60c2bb 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -218,6 +218,13 @@ class SimplePlacerTest : public ::testing::Test { GetNodeByName(g_, (name_b))->assigned_device_name()); \ } while (0) +#define EXPECT_NOT_COLOCATED(g, name_a, name_b) \ + do { \ + Graph& g_ = (g); \ + EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(), \ + GetNodeByName(g_, (name_b))->assigned_device_name()); \ + } while (0) + #define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \ EXPECT_EQ(DeviceType(expected_device_type).type(), \ devices_.FindDeviceByName( \ @@ -473,6 +480,90 @@ TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) { } } +TEST_F(SimplePlacerTest, TestColocationGroup) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp( + "TestInput", b.opts().WithName("in").WithAttr("_class", "loc:ti")); + Node* colocated_with_input = ops::UnaryOp( + "TestRelu", input, + b.opts().WithName("colocated_1").WithAttr("_class", "loc:ti")); + + // This will not be colocated with the input because TestInput is + // only available on CPU and TestRelu will default to GPU. + Node* not_colocated_with_input = + ops::UnaryOp("TestRelu", input, b.opts().WithName("foo")); + CHECK(colocated_with_input); + CHECK(not_colocated_with_input); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "in", "colocated_1"); + EXPECT_NOT_COLOCATED(g, "in", "foo"); +} + +TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1")); + Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2")); + + // Two assigns (reference connections) with two different + // colocation groups. Because their colocation groups all map to the + // same device, this is a valid assignment. + ops::BinaryOp("TestAssign", var1, input, + b.opts().WithName("assign1").WithAttr("_class", "loc:1")); + ops::BinaryOp("TestAssign", var2, input, + b.opts().WithName("assign2").WithAttr("_class", "loc:2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "in", "var1"); + EXPECT_COLOCATED(g, "in", "var2"); + EXPECT_COLOCATED(g, "var1", "assign2"); + EXPECT_COLOCATED(g, "var2", "assign1"); +} + +TEST_F(SimplePlacerTest, + TestColocationGroupWithUnsatisfiableReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + + Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1")); + Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2")); + // Var 3 is on GPU + Node* var3 = ops::SourceOp("VariableGPU", b.opts().WithName("var3")); + + // Two assigns (reference connections) with two different + // colocation groups. Because their colocation groups all map to the + // same device, this is a valid assignment. + ops::BinaryOp("TestAssign", var1, input, + b.opts().WithName("assign1").WithAttr("_class", "loc:1")); + ops::BinaryOp("TestAssign", var2, input, + b.opts().WithName("assign2").WithAttr("_class", "loc:2")); + // Assign to var3, but try to use a colocation group that matches + // the assign of var2. This should fail because assign2 must be on CPU + // (it has a reference edge on var2), and assign3 must be on GPU, + // hence the conflict. + ops::BinaryOp("TestAssign", var3, input, + b.opts().WithName("assign3").WithAttr("_class", "loc:2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains("Cannot assign a device to node 'var3': Node had no " + "OpKernel registered")); +} + TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g.