TensorFlow: Initial support in SimplePlacer for colocation groups,
to be used to colocate based on attributes rather than on the names of ops or devices (op names and devices aren't portable). A follow-up change will add an ops.colocate_with() to Python that adds this attribute to nodes; it will be used to replace calls to 'with tf.device(foo.device)' in TF library code, which assume that devices have been specified. Change: 115463464
This commit is contained in:
parent
92383c8754
commit
9d84271a20
@ -67,6 +67,7 @@ std::vector<Device*> FilterSupportedDevices(
|
||||
return filtered_devices;
|
||||
}
|
||||
|
||||
// TODO(vrv): Remove "@" syntax capability.
|
||||
bool HasColocatedNodeName(const Node& node) {
|
||||
return StringPiece(node.def().device()).starts_with("@");
|
||||
}
|
||||
@ -83,6 +84,30 @@ Status ParseColocatedNodeName(const Node& node,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns the name of the colocation group of the node by inspecting
|
||||
// the "_class" attribute of the NodeDef. Returns "" if it doesn't
|
||||
// exist.
|
||||
Status ColocationGroup(const Node& node, string* colocation_group) {
|
||||
string class_spec;
|
||||
// TODO(vrv): We should consider adding a GetNodeAttr that returns a
|
||||
// StringPiece, to avoid a copy.
|
||||
Status s = GetNodeAttr(node.def(), "_class", &class_spec);
|
||||
if (!s.ok()) {
|
||||
// No "_class" attribute is equivalent to the empty colocation_group.
|
||||
*colocation_group = "";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StringPiece spec(class_spec);
|
||||
if (!spec.Consume("loc:")) {
|
||||
return errors::InvalidArgument("Node had an invalid _class attribute: ",
|
||||
class_spec);
|
||||
}
|
||||
|
||||
*colocation_group = spec.ToString();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// This class maintains the connected components of a colocation
|
||||
// constraint graph, and uses this information to assign a satisfying
|
||||
// device placement to the nodes of the graph.
|
||||
@ -134,6 +159,24 @@ class ColocationGraph {
|
||||
CHECK_GE(member.parent, 0);
|
||||
members_.resize(member.parent + 1);
|
||||
members_[member.parent] = std::move(member);
|
||||
|
||||
// When adding the node, identify whether it is part of a
|
||||
// colocation group.
|
||||
string colocation_group;
|
||||
TF_RETURN_IF_ERROR(ColocationGroup(node, &colocation_group));
|
||||
if (!colocation_group.empty()) {
|
||||
// Node has a colocation group specified.
|
||||
auto it = colocation_group_root_.find(colocation_group);
|
||||
if (it == colocation_group_root_.end()) {
|
||||
// This is the first node of the colocation group, so
|
||||
// designate this node as the 'root' of that colocation group.
|
||||
colocation_group_root_[colocation_group] = &node;
|
||||
} else {
|
||||
// Colocate this node with the root.
|
||||
ColocateNodes(node, *(it->second));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -447,6 +490,10 @@ class ColocationGraph {
|
||||
const DeviceSet* device_set_; // Not owned.
|
||||
const std::vector<DeviceType> device_types_;
|
||||
const SessionOptions* options_; // Not owned;
|
||||
|
||||
// Maps from a colocation group identifier to the 'root' of that
|
||||
// colocation group.
|
||||
std::unordered_map<string, const Node*> colocation_group_root_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
@ -218,6 +218,13 @@ class SimplePlacerTest : public ::testing::Test {
|
||||
GetNodeByName(g_, (name_b))->assigned_device_name()); \
|
||||
} while (0)
|
||||
|
||||
// Expects that the nodes named `name_a` and `name_b` in graph `g` have
// different assigned device names. `g` is evaluated exactly once.
#define EXPECT_NOT_COLOCATED(g, name_a, name_b)                       \
  do {                                                                \
    Graph& g_ = (g);                                                  \
    EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(),    \
              GetNodeByName(g_, (name_b))->assigned_device_name());   \
  } while (0)
|
||||
|
||||
#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \
|
||||
EXPECT_EQ(DeviceType(expected_device_type).type(), \
|
||||
devices_.FindDeviceByName( \
|
||||
@ -473,6 +480,90 @@ TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that two nodes tagged with the same "_class" colocation group end
// up on the same device, while an untagged consumer of the same input does
// not.
TEST_F(SimplePlacerTest, TestColocationGroup) {
  Graph graph(OpRegistry::Global());
  {  // Builder and node pointers are only needed while constructing graph.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);

    // Both "in" and "colocated_1" carry the group "loc:ti".
    Node* source = ops::SourceOp(
        "TestInput",
        builder.opts().WithName("in").WithAttr("_class", "loc:ti"));
    Node* grouped_relu = ops::UnaryOp(
        "TestRelu", source,
        builder.opts().WithName("colocated_1").WithAttr("_class", "loc:ti"));

    // "foo" has no colocation group. It will not be colocated with the
    // input because TestInput is only available on CPU and TestRelu will
    // default to GPU.
    Node* ungrouped_relu =
        ops::UnaryOp("TestRelu", source, builder.opts().WithName("foo"));

    CHECK(grouped_relu);
    CHECK(ungrouped_relu);
    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  TF_EXPECT_OK(Place(&graph));
  EXPECT_COLOCATED(graph, "in", "colocated_1");
  EXPECT_NOT_COLOCATED(graph, "in", "foo");
}
|
||||
|
||||
// Checks that placement succeeds when assign nodes belonging to different
// colocation groups are also tied to their variables by reference
// connections.
TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) {
  Graph graph(OpRegistry::Global());
  {  // Builder and node pointers are only needed while constructing graph.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* source = ops::SourceOp("TestInput", builder.opts().WithName("in"));
    Node* first_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var1"));
    Node* second_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var2"));

    // Each assign takes a reference connection to its variable and names
    // its own colocation group ("loc:1" and "loc:2"). Because their
    // colocation groups all map to the same device, this is a valid
    // assignment.
    ops::BinaryOp(
        "TestAssign", first_var, source,
        builder.opts().WithName("assign1").WithAttr("_class", "loc:1"));
    ops::BinaryOp(
        "TestAssign", second_var, source,
        builder.opts().WithName("assign2").WithAttr("_class", "loc:2"));
    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  TF_EXPECT_OK(Place(&graph));
  EXPECT_COLOCATED(graph, "in", "var1");
  EXPECT_COLOCATED(graph, "in", "var2");
  EXPECT_COLOCATED(graph, "var1", "assign2");
  EXPECT_COLOCATED(graph, "var2", "assign1");
}
|
||||
|
||||
// Checks that placement reports an error when a colocation group is shared
// by assigns whose reference connections demand different devices.
TEST_F(SimplePlacerTest,
       TestColocationGroupWithUnsatisfiableReferenceConnections) {
  Graph graph(OpRegistry::Global());
  {  // Builder and node pointers are only needed while constructing graph.
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* source = ops::SourceOp("TestInput", builder.opts().WithName("in"));

    Node* first_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var1"));
    Node* second_var =
        ops::SourceOp("VariableCPU", builder.opts().WithName("var2"));
    // Unlike the two variables above, var3 is on GPU.
    Node* gpu_var =
        ops::SourceOp("VariableGPU", builder.opts().WithName("var3"));

    // Two assigns (reference connections) with two different colocation
    // groups. Because their colocation groups all map to the same device,
    // these two on their own form a valid assignment.
    ops::BinaryOp(
        "TestAssign", first_var, source,
        builder.opts().WithName("assign1").WithAttr("_class", "loc:1"));
    ops::BinaryOp(
        "TestAssign", second_var, source,
        builder.opts().WithName("assign2").WithAttr("_class", "loc:2"));

    // assign3 reuses the colocation group of assign2 while assigning to
    // var3. This should fail: assign2 must be on CPU (it has a reference
    // edge on var2) and assign3 must be on GPU, hence the conflict.
    ops::BinaryOp(
        "TestAssign", gpu_var, source,
        builder.opts().WithName("assign3").WithAttr("_class", "loc:2"));
    TF_EXPECT_OK(BuildGraph(builder, &graph));
  }

  const Status placement = Place(&graph);
  const StringPiece message(placement.error_message());
  EXPECT_TRUE(
      message.contains("Cannot assign a device to node 'var3': Node had no "
                       "OpKernel registered"));
}
|
||||
|
||||
TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
|
||||
Graph g(OpRegistry::Global());
|
||||
{ // Scope for temporary variables used to construct g.
|
||||
|
Loading…
Reference in New Issue
Block a user