diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index f436b960846..2af86b67f29 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -95,7 +95,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
       EXPECT_EQ(5, cm->AllocationId(node, 0));
     } else if (node->name() == y_neg->name()) {
       EXPECT_LE(8, cm->MaxMemorySize(node, 0));
-      EXPECT_EQ(6, cm->AllocationId(node, 0));
+      EXPECT_EQ(7, cm->AllocationId(node, 0));
     }
     // Check the execution time. Since it's highly variable, we'll
     // use a large window: anything between 1 and 10000 microseconds is
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index b7f2479e65f..7a1de6f0df2 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -587,6 +587,17 @@ bool IsMetadataNode(const Node* node) {
   return (node_type == "Size" || node_type == "Shape" || node_type == "Rank");
 }
 
+// Returns true if the node has no inputs and produces outputs
+// that are consumed by a single node.
+//
+// TODO(vrv): Currently this handles only nodes with one output, but
+// this could be extended to handle the case where a node has many
+// outputs that are connected to nodes in the same colocation group.
+bool IsGeneratorNode(const Node* node) {
+  return node->num_inputs() == 0 && node->num_outputs() == 1 &&
+         node->out_edges().size() == 1 && !IsRefType(node->output_type(0));
+}
+
 }  // namespace
 
 SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
@@ -690,6 +701,7 @@ Status SimplePlacer::Run() {
   // 3. For each node, assign a device based on the constraints in the
   // disjoint node set.
   std::vector<Device*> devices;
+  std::vector<Node*> second_pass;
   for (Node* node : graph_->nodes()) {
     // Skip the source and sink nodes.
     if (!node->IsOp()) {
@@ -700,6 +712,17 @@ Status SimplePlacer::Run() {
       continue;
     }
 
+    // Heuristic A: prefer to place "generators" with their only
+    // consumers.
+    //
+    // If this is a node with no inputs and a single (non-ref)
+    // consumer, we save this for a second pass, so that the
+    // consumer's placement is chosen.
+    if (IsGeneratorNode(node)) {
+      second_pass.push_back(node);
+      continue;
+    }
+
     status = colocation_graph.GetDevicesForNode(node, &devices);
     if (!status.ok()) {
       return AttachDef(
@@ -719,35 +742,81 @@ Status SimplePlacer::Run() {
     // to perform good placement we can add an interface for this.
     string assigned_device = devices[0]->name();
 
-    // If the node only operates on metadata, not data, then it is
-    // desirable to place that metadata node with its input.
+    // Heuristic B: If the node only operates on metadata, not data,
+    // then it is desirable to place that metadata node with its
+    // input.
     if (IsMetadataNode(node)) {
       // Make sure that the input device type is in the list of supported
       // device types for this node.
       const Node* input = (*node->in_edges().begin())->src();
-
-      if (!input->assigned_device_name().empty()) {
-        const Device* input_device =
-            devices_->FindDeviceByName(input->assigned_device_name());
-        if (std::any_of(
-                devices.begin(), devices.end(), [input_device](Device* d) {
-                  return d->device_type() == input_device->device_type();
-                })) {
-          assigned_device = input->assigned_device_name();
-        }
+      // TODO(vrv): if the input is empty, consider postponing this
+      // node's assignment to the second pass, so that we handle the
+      // case where a metadata node's input comes from a backedge
+      // of a loop.
+      const string& input_device_name = input->assigned_device_name();
+      if (CanAssignToDevice(input_device_name, devices)) {
+        assigned_device = input_device_name;
       }
     }
 
-    node->set_assigned_device_name(assigned_device);
-    // Log placement if log_device_placement is set.
-    if (options_ && options_->config.log_device_placement()) {
-      printf("%s: %s\n", node->name().c_str(),
-             node->assigned_device_name().c_str());
-      LOG(INFO) << node->name() << ": " << node->assigned_device_name();
+    AssignAndLog(assigned_device, node);
+  }
+
+  // 4. Perform a second pass assignment for those nodes explicitly
+  // skipped during the first pass.
+  for (Node* node : second_pass) {
+    status = colocation_graph.GetDevicesForNode(node, &devices);
+    if (!status.ok()) {
+      return AttachDef(
+          errors::InvalidArgument("Cannot assign a device to node '",
+                                  node->name(), "': ", status.error_message()),
+          node->def());
     }
+
+    string assigned_device = devices[0]->name();
+
+    // Heuristic A application.
+    if (IsGeneratorNode(node)) {
+      const Node* output = (*node->out_edges().begin())->dst();
+      const string& output_device_name = output->assigned_device_name();
+      if (CanAssignToDevice(output_device_name, devices)) {
+        assigned_device = output_device_name;
+      }
+    }
+
+    AssignAndLog(assigned_device, node);
   }
 
   return Status::OK();
 }
 
+bool SimplePlacer::CanAssignToDevice(const string& candidate_device_name,
+                                     const std::vector<Device*> devices) const {
+  if (!candidate_device_name.empty()) {
+    // Can we assign to the same device? Check by validating that
+    // the device type of 'candidate_device_name' is present
+    // in 'devices'.
+    const Device* other_device =
+        devices_->FindDeviceByName(candidate_device_name);
+    if (std::any_of(devices.begin(), devices.end(), [other_device](Device* d) {
+          return d->device_type() == other_device->device_type();
+        })) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+void SimplePlacer::AssignAndLog(const string& assigned_device,
+                                Node* node) const {
+  node->set_assigned_device_name(assigned_device);
+  // Log placement if log_device_placement is set.
+  if (options_ && options_->config.log_device_placement()) {
+    printf("%s: %s\n", node->name().c_str(),
+           node->assigned_device_name().c_str());
+    LOG(INFO) << node->name() << ": " << node->assigned_device_name();
+  }
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/simple_placer.h
index 7bfd5a13110..1dfb3733078 100644
--- a/tensorflow/core/common_runtime/simple_placer.h
+++ b/tensorflow/core/common_runtime/simple_placer.h
@@ -79,6 +79,15 @@ class SimplePlacer {
   Status Run();
 
  private:
+  // Returns true if the device type of 'candidate_device_name' is
+  // found in 'devices'.
+  bool CanAssignToDevice(const string& candidate_device_name,
+                         const std::vector<Device*> devices) const;
+
+  // Assigns 'node's devices to 'assigned_device', and logs the
+  // placement if the SessionOptions entry in 'options_' requests it.
+  void AssignAndLog(const string& assigned_device, Node* node) const;
+
   Graph* const graph_;              // Not owned.
   const DeviceSet* const devices_;  // Not owned.
   const SessionOptions* options_;   // Not owned.
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index 1c8c86aaa13..cfddcbb052f 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -307,6 +307,34 @@ TEST_F(SimplePlacerTest, TestMetadataColocatedWithInput) {
   EXPECT_COLOCATED(g, "var_cpu", "shape_op");
 }
 
+// Heuristic A implements "Island fusing": if a node only generates
+// an output and it has only one consumer, we place the node
+// with its consumer.
+TEST_F(SimplePlacerTest, TestHeuristicA) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+
+    // A variable is only on CPU
+    Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
+
+    // The constant to be assigned can be on both GPU or CPU.
+    //
+    // Because of the heuristic, it gets placed on CPU to avoid a
+    // copy.
+    Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
+
+    // The assign is bound to CPU by the reference edge.
+    ops::BinaryOp("TestAssign", var_cpu, input, b.opts().WithName("assign"));
+
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_COLOCATED(g, "var_cpu", "in");
+  EXPECT_COLOCATED(g, "assign", "in");
+}
+
 // Test that a graph with partial device specifications on the ops
 // will successfully
 TEST_F(SimplePlacerTest, TestPartialSpec) {
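
The decision the patch adds in the second pass can be pictured in isolation. Below is a minimal standalone sketch, using hypothetical toy types rather than the TensorFlow Node/Device classes, of what Heuristic A does: a generator (a node with no inputs and exactly one consumer) is deferred, and then placed on its consumer's already-assigned device whenever that device's type is one the generator supports, mirroring the IsGeneratorNode/CanAssignToDevice pair above. It is a sketch of the idea, not the placer's real data structures.

// heuristic_a_sketch.cc -- toy illustration of the "place generators with
// their only consumer" rule; names and types here are invented for the sketch.
#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  int num_inputs = 0;
  const ToyNode* only_consumer = nullptr;  // non-null iff exactly one consumer
  std::vector<std::string> supported_device_types;  // e.g. {"CPU", "GPU"}
  std::string assigned_device;                      // e.g. "/cpu:0"
};

bool IsGeneratorNode(const ToyNode& n) {
  return n.num_inputs == 0 && n.only_consumer != nullptr;
}

// Toy analogue of CanAssignToDevice: the candidate device is usable if its
// device type appears in the node's supported device types.
bool CanAssignToDevice(const std::string& candidate_device,
                       const std::vector<std::string>& supported) {
  if (candidate_device.empty()) return false;
  const std::string type =
      candidate_device.find("gpu") != std::string::npos ? "GPU" : "CPU";
  for (const std::string& s : supported) {
    if (s == type) return true;
  }
  return false;
}

int main() {
  // "assign" is already pinned to CPU (by the reference edge in the real
  // placer); "in" is a constant that could run on either CPU or GPU.
  ToyNode assign{"assign", 2, nullptr, {"CPU"}, "/cpu:0"};
  ToyNode in{"in", 0, &assign, {"CPU", "GPU"}, ""};

  // Second-pass placement of the generator: default to the first supported
  // device, but prefer the consumer's device when it is compatible.
  std::string device = "/gpu:0";
  if (IsGeneratorNode(in) &&
      CanAssignToDevice(assign.assigned_device, in.supported_device_types)) {
    device = assign.assigned_device;
  }
  in.assigned_device = device;
  std::cout << in.name << " -> " << in.assigned_device << "\n";  // "/cpu:0"
  return 0;
}

Run as a plain C++ program, the sketch prints "/cpu:0" for the constant, which is the same outcome the new TestHeuristicA test checks with EXPECT_COLOCATED: the generator follows its consumer instead of defaulting to the GPU and forcing a copy.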