TensorFlow: Initial support in SimplePlacer for colocation groups,

to be used to colocate based on attributes rather than either
names of ops or devices (op names and devices aren't portable).

A follow-up change will add an ops.colocate_with() to Python that adds
this attribute to nodes; it will be used to replace calls to 'with
tf.device(foo.device)' in TF library code, which assume that devices
have been specified.
Change: 115463464
This commit is contained in:
Vijay Vasudevan 2016-02-24 09:59:50 -08:00 committed by TensorFlower Gardener
parent 92383c8754
commit 9d84271a20
2 changed files with 138 additions and 0 deletions

View File

@ -67,6 +67,7 @@ std::vector<Device*> FilterSupportedDevices(
return filtered_devices;
}
// TODO(vrv): Remove "@" syntax capability.
bool HasColocatedNodeName(const Node& node) {
return StringPiece(node.def().device()).starts_with("@");
}
@ -83,6 +84,30 @@ Status ParseColocatedNodeName(const Node& node,
return Status::OK();
}
// Extracts the colocation group of `node` from the "_class" attribute
// of its NodeDef.
//
// On an OK return, *colocation_group holds the text that follows the
// "loc:" prefix of the attribute, or is empty when the attribute could
// not be read. Returns InvalidArgument when the attribute was read but
// does not start with "loc:".
Status ColocationGroup(const Node& node, string* colocation_group) {
  // TODO(vrv): We should consider adding a GetNodeAttr that returns a
  // StringPiece, to avoid a copy.
  string class_attr;
  const Status lookup = GetNodeAttr(node.def(), "_class", &class_attr);
  if (!lookup.ok()) {
    // No "_class" attribute is equivalent to the empty colocation_group.
    colocation_group->clear();
    return Status::OK();
  }

  // Strip the "loc:" prefix; whatever remains is the group name.
  StringPiece group_name(class_attr);
  if (group_name.Consume("loc:")) {
    *colocation_group = group_name.ToString();
    return Status::OK();
  }
  return errors::InvalidArgument("Node had an invalid _class attribute: ",
                                 class_attr);
}
// This class maintains the connected components of a colocation
// constraint graph, and uses this information to assign a satisfying
// device placement to the nodes of the graph.
@ -134,6 +159,24 @@ class ColocationGraph {
CHECK_GE(member.parent, 0);
members_.resize(member.parent + 1);
members_[member.parent] = std::move(member);
// When adding the node, identify whether it is part of a
// colocation group.
string colocation_group;
TF_RETURN_IF_ERROR(ColocationGroup(node, &colocation_group));
if (!colocation_group.empty()) {
// Node has a colocation group specified.
auto it = colocation_group_root_.find(colocation_group);
if (it == colocation_group_root_.end()) {
// This is the first node of the colocation group, so
// designate this node as the 'root' of that colocation group.
colocation_group_root_[colocation_group] = &node;
} else {
// Colocate this node with the root.
ColocateNodes(node, *(it->second));
}
}
return Status::OK();
}
@ -447,6 +490,10 @@ class ColocationGraph {
const DeviceSet* device_set_; // Not owned.
const std::vector<DeviceType> device_types_;
const SessionOptions* options_; // Not owned;
// Maps from a colocation group identifier to the 'root' of that
// colocation group.
std::unordered_map<string, const Node*> colocation_group_root_;
};
} // namespace

View File

@ -218,6 +218,13 @@ class SimplePlacerTest : public ::testing::Test {
GetNodeByName(g_, (name_b))->assigned_device_name()); \
} while (0)
// Expects that the nodes named `name_a` and `name_b` in graph `g` have
// different assigned_device_name() values. `g` is evaluated once and
// bound to a local reference; the do/while(0) wrapper makes the macro
// usable as a single statement.
#define EXPECT_NOT_COLOCATED(g, name_a, name_b) \
do { \
Graph& g_ = (g); \
EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(), \
GetNodeByName(g_, (name_b))->assigned_device_name()); \
} while (0)
#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \
EXPECT_EQ(DeviceType(expected_device_type).type(), \
devices_.FindDeviceByName( \
@ -473,6 +480,90 @@ TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) {
}
}
// Checks that two nodes carrying the same "_class" value ("loc:ti") end
// up with the same assigned device, while a node without the attribute
// that consumes the same input is assigned a different device.
TEST_F(SimplePlacerTest, TestColocationGroup) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
// "in" is created first and carries the colocation group "ti".
Node* input = ops::SourceOp(
"TestInput", b.opts().WithName("in").WithAttr("_class", "loc:ti"));
// "colocated_1" carries the same "_class" value as "in".
Node* colocated_with_input = ops::UnaryOp(
"TestRelu", input,
b.opts().WithName("colocated_1").WithAttr("_class", "loc:ti"));
// This will not be colocated with the input because TestInput is
// only available on CPU and TestRelu will default to GPU.
Node* not_colocated_with_input =
ops::UnaryOp("TestRelu", input, b.opts().WithName("foo"));
// Both builder calls must have produced a node.
CHECK(colocated_with_input);
CHECK(not_colocated_with_input);
TF_EXPECT_OK(BuildGraph(b, &g));
}
TF_EXPECT_OK(Place(&g));
// Same colocation group => same assigned device.
EXPECT_COLOCATED(g, "in", "colocated_1");
// No colocation group on "foo" => it is expected on a different device.
EXPECT_NOT_COLOCATED(g, "in", "foo");
}
// Checks that placement succeeds when two assign nodes, each tagged
// with a different colocation group ("loc:1" and "loc:2"), write to two
// different variables, and that "in", the variables and the assigns the
// expectations name all share one assigned device.
TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
// NOTE(review): presumably "VariableCPU" is a test op registered only
// for CPU -- confirm against the op registrations in this test file.
Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
// Two assigns (reference connections) with two different
// colocation groups. Because their colocation groups all map to the
// same device, this is a valid assignment.
ops::BinaryOp("TestAssign", var1, input,
b.opts().WithName("assign1").WithAttr("_class", "loc:1"));
ops::BinaryOp("TestAssign", var2, input,
b.opts().WithName("assign2").WithAttr("_class", "loc:2"));
TF_EXPECT_OK(BuildGraph(b, &g));
}
TF_EXPECT_OK(Place(&g));
// Both variables share a device with the input...
EXPECT_COLOCATED(g, "in", "var1");
EXPECT_COLOCATED(g, "in", "var2");
// ...and each assign shares a device with the *other* assign's
// variable, so the two colocation groups landed on the same device.
EXPECT_COLOCATED(g, "var1", "assign2");
EXPECT_COLOCATED(g, "var2", "assign1");
}
// Checks that Place() reports an error when a colocation group
// ("loc:2") is shared by two assigns whose target variables are built
// from different variable ops ("VariableCPU" vs. "VariableGPU").
TEST_F(SimplePlacerTest,
TestColocationGroupWithUnsatisfiableReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
// Var 3 is on GPU
Node* var3 = ops::SourceOp("VariableGPU", b.opts().WithName("var3"));
// Two assigns (reference connections) with two different
// colocation groups. Taken on their own, these two are satisfiable
// because their colocation groups all map to the same device; the
// conflict is introduced by assign3 below.
ops::BinaryOp("TestAssign", var1, input,
b.opts().WithName("assign1").WithAttr("_class", "loc:1"));
ops::BinaryOp("TestAssign", var2, input,
b.opts().WithName("assign2").WithAttr("_class", "loc:2"));
// Assign to var3, but try to use a colocation group that matches
// the assign of var2. This should fail because assign2 must be on CPU
// (it has a reference edge on var2), and assign3 must be on GPU,
// hence the conflict.
ops::BinaryOp("TestAssign", var3, input,
b.opts().WithName("assign3").WithAttr("_class", "loc:2"));
TF_EXPECT_OK(BuildGraph(b, &g));
}
// Placement is expected to fail. There is no separate !s.ok() check:
// an OK status would fail the substring expectation below.
// NOTE(review): the expected substring names 'var3' and a missing
// OpKernel, whereas the comment above describes a CPU/GPU conflict
// between assign2 and assign3 -- confirm this is the intended failure
// mode and not an artifact of how "VariableGPU" is registered.
Status s = Place(&g);
EXPECT_TRUE(
StringPiece(s.error_message())
.contains("Cannot assign a device to node 'var3': Node had no "
"OpKernel registered"));
}
TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.