Enable a soft ordering for collectives which adds dependencies as node attrs.

Before this change, OrderCollectives would add control edges between collective
ops to make their execution sequential.  This change introduces a technique of
encoding the ordering information as a node attribute.  In future changes, the
collective executor will parse these attributes and enforce a fine-grained
ordering between collective ops.

Additionally, this change
* restricts ordering to CollectiveReduce nodes,
* and fixes a bug so that this ordering is invoked from both DirectSession and
MasterSession.

PiperOrigin-RevId: 228801501
Authored by Ayush Dubey on 2019-01-10 16:44:20 -08:00; committed by TensorFlower Gardener
parent 40874676a6
commit e2760cb89f
10 changed files with 268 additions and 64 deletions
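
As context for the attribute encoding described in the commit message: a minimal sketch, not part of this commit, of how an executor-side component could read the new "wait_for" attribute. The helper name CollectBlockingInstances is invented for illustration; only GetNodeAttr and Node::attrs() are existing TensorFlow APIs.

// Illustrative only, not from this commit. Reads the "wait_for" attribute
// written by OrderCollectives under GraphCollectiveOrder::kAttrs.
#include <vector>

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Returns the instance keys that must complete before `node` may launch.
// A missing attribute simply means the node has no ordering constraint.
std::vector<int32> CollectBlockingInstances(const Node* node) {
  std::vector<int32> wait_for;
  if (!GetNodeAttr(node->attrs(), "wait_for", &wait_for).ok()) {
    wait_for.clear();  // Attribute absent or malformed: no constraint.
  }
  return wait_for;
}

}  // namespace tensorflow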


@@ -2819,6 +2819,7 @@ tf_cuda_library(
":protos_all_cc",
"//third_party/eigen3",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
],
)


@@ -35,6 +35,19 @@ string BuildGraphOptions::DebugString() const {
if (collective_graph_key != kNoCollectiveGraphKey) {
strings::StrAppend(&rv, "\ncollective_graph_key: ", collective_graph_key);
}
string collective_order_str;
switch (collective_order) {
case GraphCollectiveOrder::kNone:
collective_order_str = "none";
break;
case GraphCollectiveOrder::kEdges:
collective_order_str = "edges";
break;
case GraphCollectiveOrder::kAttrs:
collective_order_str = "attrs";
break;
}
strings::StrAppend(&rv, "\ncollective_order: ", collective_order_str);
return rv;
}


@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
@@ -34,6 +35,11 @@ struct BuildGraphOptions {
static const int64 kNoCollectiveGraphKey = 0;
int64 collective_graph_key = kNoCollectiveGraphKey;
// If not `kNone`, order all CollectiveReduce operations statically and
// deterministically. If `kEdges`, encode dependencies as explicit control
// edges; if `kAttrs`, encode them as an attribute on the collective op.
GraphCollectiveOrder collective_order = GraphCollectiveOrder::kNone;
string DebugString() const;
};
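
For reference, a hedged usage sketch of the new field, not part of the diff: MakeOrderedOptions is a hypothetical helper, and the include path for BuildGraphOptions is an assumption.

// Illustrative only. Requests that OrderCollectives encode the ordering as a
// "wait_for" attribute rather than explicit control edges.
#include "tensorflow/core/common_runtime/build_graph_options.h"

namespace tensorflow {

BuildGraphOptions MakeOrderedOptions(int64 graph_key) {
  BuildGraphOptions opts;
  opts.collective_graph_key = graph_key;
  opts.collective_order = GraphCollectiveOrder::kAttrs;
  return opts;
}

}  // namespace tensorflow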


@@ -45,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
@@ -1192,6 +1191,10 @@ Status DirectSession::CreateExecutors(
options.use_function_convention = !run_state_args->is_partial_run;
options.collective_graph_key =
callable_options.run_options().experimental().collective_graph_key();
if (options_.config.experimental()
.collective_deterministic_sequential_execution()) {
options.collective_order = GraphCollectiveOrder::kEdges;
}
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1523,12 +1526,6 @@ Status DirectSession::CreateGraphs(
CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
}
// Make collective execution order deterministic if needed.
if (options_.config.experimental()
.collective_deterministic_sequential_execution()) {
TF_RETURN_IF_ERROR(OrderCollectives(&client_graph->graph));
}
// Partition the graph across devices.
PartitionOptions popts;
popts.node_to_loc = [](const Node* node) {


@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
@@ -819,6 +820,12 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
}
}
// Make collective execution order deterministic if needed.
if (options.collective_order != GraphCollectiveOrder::kNone) {
TF_RETURN_IF_ERROR(
OrderCollectives(optimized_graph.get(), options.collective_order));
}
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
// the largest node id.


@@ -292,8 +292,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
if (tot >= 0.1 * 1048576.0) {
bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
}
return strings::StrCat(bytes, stats.node_name(), " = ",
details.type_string, details.detail_text);
return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
details.detail_text);
}
// Send/Recv nodes that are the result of client-added
@@ -1081,16 +1081,17 @@ void CopyAndSortStrings(size_t size,
} // namespace
void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
const ConfigProto& config,
BuildGraphOptions* opts) {
CallableOptions* callable_opts = &opts->callable_options;
CopyAndSortStrings(req.num_feeds(),
[&req](size_t i) { return req.feed_name(i); },
CopyAndSortStrings(
req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
callable_opts->mutable_feed());
CopyAndSortStrings(req.num_fetches(),
[&req](size_t i) { return req.fetch_name(i); },
CopyAndSortStrings(
req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
callable_opts->mutable_fetch());
CopyAndSortStrings(req.num_targets(),
[&req](size_t i) { return req.target_name(i); },
CopyAndSortStrings(
req.num_targets(), [&req](size_t i) { return req.target_name(i); },
callable_opts->mutable_target());
if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
@@ -1100,18 +1101,22 @@ void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
opts->collective_graph_key =
req.options().experimental().collective_graph_key();
if (config.experimental().collective_deterministic_sequential_execution()) {
opts->collective_order = GraphCollectiveOrder::kEdges;
}
}
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
BuildGraphOptions* opts) {
CallableOptions* callable_opts = &opts->callable_options;
CopyAndSortStrings(req.feed_size(), [&req](size_t i) { return req.feed(i); },
CopyAndSortStrings(
req.feed_size(), [&req](size_t i) { return req.feed(i); },
callable_opts->mutable_feed());
CopyAndSortStrings(req.fetch_size(),
[&req](size_t i) { return req.fetch(i); },
CopyAndSortStrings(
req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
callable_opts->mutable_fetch());
CopyAndSortStrings(req.target_size(),
[&req](size_t i) { return req.target(i); },
CopyAndSortStrings(
req.target_size(), [&req](size_t i) { return req.target(i); },
callable_opts->mutable_target());
// TODO(cais): Add TFDBG support to partial runs.
@@ -1852,7 +1857,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Prepare.
BuildGraphOptions bgopts;
BuildBuildGraphOptions(req, &bgopts);
BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
ReffedClientGraph* rcg = nullptr;
int64 count;
TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));


@@ -14,55 +14,70 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/graph/collective_order.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/graph/algorithm.h"
namespace tensorflow {
namespace {
Status OrderCollectives(Graph* graph) {
// `instance_keys[i]` corresponds to `collective_nodes[i]`
std::vector<Node*> collective_nodes;
std::vector<int32> instance_keys;
// node -> set of collectives on which node depends.
std::unordered_map<Node*, std::unordered_set<int32>> node_dependencies;
// Find all CollectiveReduce nodes and the existing data dependencies between
// them.
Status DiscoverDataDependencies(
const Graph* graph, std::vector<Node*>* collective_nodes,
std::vector<int32>* instance_keys,
absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies) {
Status s;
// Algorithm: do Reverse DFS starting at sink. `node_leave` is called when
// all parents of `node` have been visited. At that point, the collectives
// on which this node depends on are up to date. For this node's children,
// add all these collectives. Also, if this node is collective, add as a
// dependency for the children.
auto node_leave = [&collective_nodes, &instance_keys, &node_dependencies,
// all parents of `node` have been visited. At that point,
// `data_dependencies[node]` is a list containing `instance_key` of every
// `CollectiveReduce` on which `node` has a data dependency.
// For this node's children, add all these instance keys. Also, if this node
// is collective, add as a dependency for the children.
auto node_leave = [collective_nodes, instance_keys, data_dependencies,
&s](Node* node) {
int32 instance_key;
if (node->IsCollective()) {
bool enter_node =
node->IsCollective() && node->type_string() == "CollectiveReduce";
if (enter_node) {
Status get_attr_status =
GetNodeAttr(node->attrs(), "instance_key", &instance_key);
s.Update(get_attr_status);
collective_nodes.push_back(node);
instance_keys.push_back(instance_key);
collective_nodes->push_back(node);
instance_keys->push_back(instance_key);
VLOG(2) << "collective node " << node->DebugString();
}
const auto& node_deps = node_dependencies[node];
const auto& node_deps = (*data_dependencies)[node];
for (const Edge* out_edge : node->out_edges()) {
auto& child_deps = node_dependencies[out_edge->dst()];
auto& child_deps = (*data_dependencies)[out_edge->dst()];
child_deps.insert(node_deps.begin(), node_deps.end());
if (node->IsCollective() && s.ok()) {
if (enter_node && s.ok()) {
child_deps.insert(instance_key);
}
}
};
ReverseDFS(*graph, nullptr, node_leave);
if (!s.ok()) return s;
return s;
}
// For all pairs of collective nodes n1 and n2 on the same device, if n1 does
// not depend on n2 and n2 does not depend on n1, then they are potentially
// concurrent. Add an arbitrary, deterministic control edge between them.
// Given a list of `collective_nodes` and `data_dependencies` between the
// collective nodes, create control dependencies between concurrent collectives
// and store in `dependency_edges`.
// If there exists an edge a -> b then `dependency_edges[a]` contains `b`
Status CreateControlDependencies(
const std::vector<Node*>& collective_nodes,
const std::vector<int32>& instance_keys,
absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies,
absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>* dependency_edges) {
// If there exists some path a -> ... -> b then `all_paths[a]` contains `b`
absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> all_paths;
for (int i = 0; i < collective_nodes.size() - 1; i++) {
if (!collective_nodes[i]->IsCollective()) {
if (!collective_nodes[i]->IsCollective() ||
collective_nodes[i]->type_string() != "CollectiveReduce") {
return errors::Internal("Unexpected node ",
collective_nodes[i]->DebugString());
}
const auto& deps_i = node_dependencies[collective_nodes[i]];
const auto& deps_i = (*data_dependencies)[collective_nodes[i]];
for (int j = i + 1; j < collective_nodes.size(); j++) {
if (collective_nodes[i]->requested_device() !=
collective_nodes[j]->requested_device()) {
@@ -74,17 +89,48 @@ Status OrderCollectives(Graph* graph) {
" on 2 nodes with the same device ",
collective_nodes[i]->requested_device());
}
const auto& deps_j = node_dependencies[collective_nodes[j]];
const auto& deps_j = (*data_dependencies)[collective_nodes[j]];
if (deps_i.find(instance_keys[j]) == deps_i.end() &&
deps_j.find(instance_keys[i]) == deps_j.end()) {
int src_idx = instance_keys[i] < instance_keys[j] ? i : j;
int dst_idx = instance_keys[i] < instance_keys[j] ? j : i;
Node* src_node = collective_nodes[src_idx];
Node* dst_node = collective_nodes[dst_idx];
VLOG(1) << "Adding control edge from node " << src_node->name()
VLOG(1) << "Adding control dependency from node " << src_node->name()
<< " instance " << instance_keys[src_idx] << " to node "
<< dst_node->name() << " instance " << instance_keys[dst_idx];
graph->AddControlEdge(src_node, dst_node);
(*dependency_edges)[src_node].insert(dst_node);
auto& src_paths = all_paths[src_node];
src_paths.insert(dst_node);
for (Node* downstream_node : all_paths[dst_node]) {
src_paths.insert(downstream_node);
}
}
}
}
// Prune dependency edges so that if there are edges a -> b, b -> c, and a ->
// c, then remove a -> c. This dependency would be handled naturally during
// op scheduling.
for (int i = 0; i < collective_nodes.size(); ++i) {
Node* node = collective_nodes[i];
auto& neighbor_set = (*dependency_edges)[node];
std::vector<Node*> neighbor_list(neighbor_set.begin(), neighbor_set.end());
// For all n1, n2 in `neighbor_list` if there is a path from n1 -> n2 then
// eliminate n2 from `neighbor_set` and `neighbor_list`. We remove from
// `neighbor_list` by replacing with a `nullptr`, hence the `nullptr` checks
// below.
for (int j = 0; j < neighbor_list.size(); ++j) {
Node* n1 = neighbor_list[j];
if (n1 == nullptr) continue;
auto& n1_paths = all_paths[n1];
for (int k = 0; k < neighbor_list.size(); ++k) {
Node* n2 = neighbor_list[k];
if (j == k || n2 == nullptr) continue;
if (n1_paths.find(n2) != n1_paths.end()) {
neighbor_set.erase(n2);
neighbor_list[k] = nullptr;
}
}
}
}
@@ -92,4 +138,65 @@ Status OrderCollectives(Graph* graph) {
return Status::OK();
}
// Insert control dependencies defined by `dependency_edges` in `graph`. If
// `order_type` is `kEdges`, insert explicit control edges; else if `order_type`
// is `kAttrs`, encode dependencies as an attribute on the collective node.
Status InsertControlDependencies(
Graph* graph, GraphCollectiveOrder order_type,
const absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>&
dependency_edges) {
if (order_type == GraphCollectiveOrder::kEdges) {
for (const auto& pair : dependency_edges) {
Node* src_node = pair.first;
for (Node* dst_node : pair.second) {
graph->AddControlEdge(src_node, dst_node);
}
}
} else if (order_type == GraphCollectiveOrder::kAttrs) {
// `wait_for` is the inverse of `dependency_edges`, i.e. `wait_for[node]`
// contains the list of instance keys for which `node` must wait.
absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> wait_for;
for (const auto& pair : dependency_edges) {
int32 src_instance;
TF_RETURN_IF_ERROR(
GetNodeAttr(pair.first->attrs(), "instance_key", &src_instance));
for (Node* dst_node : pair.second) {
wait_for[dst_node].insert(src_instance);
}
}
for (const auto& pair : wait_for) {
std::vector<int32> wait_for_list(pair.second.begin(), pair.second.end());
pair.first->ClearAttr("wait_for");
pair.first->AddAttr("wait_for", wait_for_list);
}
} else {
return errors::Internal("Unexpected GraphCollectiveOrder type ",
static_cast<int>(order_type));
}
return Status::OK();
}
} // namespace
Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type) {
// `instance_keys[i]` corresponds to `collective_nodes[i]`
std::vector<Node*> collective_nodes;
std::vector<int32> instance_keys;
// node -> set of collectives on which node depends.
absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> data_dependencies;
TF_RETURN_IF_ERROR(DiscoverDataDependencies(
graph, &collective_nodes, &instance_keys, &data_dependencies));
if (collective_nodes.empty()) return Status::OK();
absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> dependency_edges;
// For all pairs of collective nodes n1 and n2 on the same device, if n1 does
// not depend on n2 and n2 does not depend on n1, then they are potentially
// concurrent. Create an arbitrary, deterministic ordering between them.
TF_RETURN_IF_ERROR(CreateControlDependencies(
collective_nodes, instance_keys, &data_dependencies, &dependency_edges));
return InsertControlDependencies(graph, order_type, dependency_edges);
}
} // namespace tensorflow
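
The pruning loop in CreateControlDependencies is effectively a transitive reduction of the generated dependency edges. Below is a standalone sketch of the same idea, illustrative only, using plain std containers instead of Node* and absl hash maps, on the c2 -> c3 -> c4 chain exercised by the tests.

// Drop edge a -> n2 whenever another neighbor n1 of a already reaches n2.
#include <iostream>
#include <map>
#include <set>
#include <string>

int main() {
  // Dependency edges and reachability as CreateControlDependencies would
  // build them, including the redundant c2 -> c4 edge.
  std::map<std::string, std::set<std::string>> dependency_edges = {
      {"c2", {"c3", "c4"}}, {"c3", {"c4"}}};
  std::map<std::string, std::set<std::string>> all_paths = {
      {"c2", {"c3", "c4"}}, {"c3", {"c4"}}};

  for (auto& entry : dependency_edges) {
    std::set<std::string>& neighbors = entry.second;
    std::set<std::string> pruned = neighbors;
    for (const std::string& n1 : neighbors) {
      for (const std::string& n2 : neighbors) {
        if (n1 != n2 && all_paths[n1].count(n2) > 0) pruned.erase(n2);
      }
    }
    neighbors = pruned;
  }

  for (const auto& entry : dependency_edges) {
    for (const std::string& dst : entry.second) {
      std::cout << entry.first << " -> " << dst << "\n";
    }
  }
  // Prints "c2 -> c3" and "c3 -> c4"; the redundant c2 -> c4 edge is gone,
  // matching the expectation in the SimpleOrder2 test.
  return 0;
}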


@@ -19,11 +19,17 @@ limitations under the License.
namespace tensorflow {
// Introduces control edges between potentially concurrent CollectiveOps to make
// their execution order deterministic. This may be used to execute collectives
// in the same order across all workers in a distributed execution, if all
// workers are executing the same graph.
Status OrderCollectives(Graph* graph);
enum class GraphCollectiveOrder { kNone, kEdges, kAttrs };
// Introduces a deterministic execution order between potentially concurrent
// CollectiveOps. This may be used to execute collectives in the same order
// across all workers in a distributed execution, if all workers are executing
// the same graph.
// If `order_type` is `kEdges`, introduce the ordering in the form of explicit
// control edges between collective graph nodes. If `order_type` is `kAttrs`,
// add an attribute to the node which may be used by the collective executor to
// ensure the required ordering.
Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type);
} // namespace tensorflow
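
A hedged usage sketch of the updated entry point, not part of the commit (OrderAndDump is a hypothetical helper): order the graph in kAttrs mode, then serialize it so the resulting "wait_for" attributes can be inspected.

// Illustrative only. After this call, potentially concurrent CollectiveReduce
// nodes carry a "wait_for" attribute listing the instance keys they follow.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

Status OrderAndDump(Graph* graph, GraphDef* out) {
  TF_RETURN_IF_ERROR(OrderCollectives(graph, GraphCollectiveOrder::kAttrs));
  graph->ToGraphDef(out);
  return Status::OK();
}

}  // namespace tensorflow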


@@ -59,6 +59,23 @@ void VerifyGraph(const Graph& graph,
UnorderedElementsAreArray(expected_collective_control_edges));
}
// Verifies that the `wait_for` attribute on collective nodes matches
// `wait_for_map`.
void VerifyAttrs(
const Graph& graph,
const std::unordered_map<string, std::vector<int32>> wait_for_map) {
for (const Node* node : graph.nodes()) {
if (!node->IsCollective() ||
wait_for_map.find(node->name()) == wait_for_map.end()) {
continue;
}
std::vector<int32> wait_for_actual;
TF_EXPECT_OK(GetNodeAttr(node->attrs(), "wait_for", &wait_for_actual));
auto wait_for_expected = wait_for_map.at(node->name());
EXPECT_THAT(wait_for_actual, UnorderedElementsAreArray(wait_for_expected));
}
}
Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
const string& name, const string& device,
int instance_key) {
@@ -123,11 +140,17 @@ std::unique_ptr<Graph> InitGraph() {
// added after calling `OrderCollectives`: c2_0 -> c3_0 and c2_1 -> c3_1.
TEST(CollectiveOrderTest, SimpleOrder) {
std::unique_ptr<Graph> graph = InitGraph();
TF_EXPECT_OK(OrderCollectives(graph.get()));
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
VerifyGraph(*graph, {"c1_0", "c1_1", "c2_0", "c2_1", "c3_0", "c3_1"},
{{"c2_0", "c3_0"}, {"c2_1", "c3_1"}});
}
TEST(CollectiveOrderTest, SimpleOrderAttr) {
std::unique_ptr<Graph> graph = InitGraph();
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
VerifyAttrs(*graph, {{"c3_0", {2}}, {"c3_1", {2}}});
}
// Initialize the following graph:
//
// a
@@ -162,12 +185,50 @@ std::unique_ptr<Graph> InitGraph2() {
}
// Tests that in the graph created by `InitGraph2`, we add the following control
// edges after calling `OrderCollectives`: c2 -> c3, c3 -> c4, and c2 -> c4.
// edges after calling `OrderCollectives`: c2 -> c3, c3 -> c4. c2->c4 is
// pruned because it follows from the other two edges.
TEST(CollectiveOrderTest, SimpleOrder2) {
std::unique_ptr<Graph> graph = InitGraph2();
TF_EXPECT_OK(OrderCollectives(graph.get()));
VerifyGraph(*graph, {"c1", "c2", "c3", "c4"},
{{"c2", "c3"}, {"c3", "c4"}, {"c2", "c4"}});
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
VerifyGraph(*graph, {"c1", "c2", "c3", "c4"}, {{"c2", "c3"}, {"c3", "c4"}});
}
// Initialize the following graph:
//
// w x y z
// | | | |
// c1 c2 c3 c4
//
std::unique_ptr<Graph> InitGraphForPruning() {
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
Node* w = ops::SourceOp("TestParams",
builder.opts().WithName("w").WithDevice(dev0));
Node* x = ops::SourceOp("TestParams",
builder.opts().WithName("x").WithDevice(dev0));
Node* y = ops::SourceOp("TestParams",
builder.opts().WithName("y").WithDevice(dev0));
Node* z = ops::SourceOp("TestParams",
builder.opts().WithName("z").WithDevice(dev0));
CollectiveReduceNode(&builder, w, "c1", dev0, 1);
CollectiveReduceNode(&builder, x, "c2", dev0, 2);
CollectiveReduceNode(&builder, y, "c3", dev0, 3);
CollectiveReduceNode(&builder, z, "c4", dev0, 4);
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
Status s = GraphDefBuilderToGraph(builder, graph.get());
if (!s.ok()) {
LOG(FATAL) << "Error building graph " << s;
}
return graph;
}
// Tests that in the graph created by `InitGraphForPruning`, we only add c1 ->
// c2, c2 -> c3, c3 -> c4, and other edges are pruned away.
TEST(CollectiveOrderTest, Pruning) {
std::unique_ptr<Graph> graph = InitGraphForPruning();
TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
VerifyAttrs(*graph, {{"c4", {3}}, {"c3", {2}}, {"c2", {1}}});
}
} // namespace


@@ -28,6 +28,7 @@ REGISTER_OP("CollectiveReduce")
.Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
.Attr("final_op: {'Id', 'Div'}")
.Attr("subdiv_offsets: list(int)")
.Attr("wait_for: list(int) = []")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape);
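
For completeness, a hedged sketch, not part of this commit, of constructing a CollectiveReduce NodeDef with the new attr set explicitly; the node name, input name, and attribute values are arbitrary. Because "wait_for" defaults to [], existing graphs remain valid without it.

// Illustrative only: a CollectiveReduce NodeDef that waits on instance key 2.
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

Status BuildOrderedReduce(NodeDef* node_def) {
  return NodeDefBuilder("c3", "CollectiveReduce")
      .Input("input_tensor", 0, DT_FLOAT)
      .Attr("T", DT_FLOAT)
      .Attr("group_size", 2)
      .Attr("group_key", 1)
      .Attr("instance_key", 3)
      .Attr("merge_op", "Add")
      .Attr("final_op", "Id")
      .Attr("subdiv_offsets", std::vector<int32>{0})
      .Attr("wait_for", std::vector<int32>{2})
      .Finalize(node_def);
}

}  // namespace tensorflow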