Enable a soft ordering for collectives which adds dependencies as node attrs.
Before this change, OrderCollectives added control edges between collective ops to make their execution sequential. This change introduces a technique that encodes the ordering information as a node attribute instead. In future changes, the collective executor will parse these attributes and enforce a fine-grained ordering between collective ops.

Additionally, this change:
* restricts ordering to CollectiveReduce nodes, and
* fixes a bug so that this ordering is invoked from both DirectSession and MasterSession.

PiperOrigin-RevId: 228801501
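As a rough sketch of the two encodings (the helper below is illustrative and not part of this change; it only reuses Graph/Node attribute APIs that the change itself relies on, and assumes `src`/`dst` are CollectiveReduce nodes and `src_instance` is src's instance_key):

// Hedged sketch: two ways to record "dst must run after src".
#include <vector>
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

void RecordOrdering(Graph* graph, Node* src, Node* dst, int32 src_instance,
                    bool use_attrs) {
  if (!use_attrs) {
    // kEdges: make the dependency explicit in the graph structure.
    graph->AddControlEdge(src, dst);
    return;
  }
  // kAttrs: append src's instance_key to dst's `wait_for` list attribute.
  // If the attribute is not set yet, `wait_for` simply stays empty.
  std::vector<int32> wait_for;
  GetNodeAttr(dst->attrs(), "wait_for", &wait_for).IgnoreError();
  wait_for.push_back(src_instance);
  dst->ClearAttr("wait_for");
  dst->AddAttr("wait_for", wait_for);
}

}  // namespace tensorflow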
commit e2760cb89f
parent 40874676a6
@@ -2819,6 +2819,7 @@ tf_cuda_library(
         ":protos_all_cc",
         "//third_party/eigen3",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
@@ -35,6 +35,19 @@ string BuildGraphOptions::DebugString() const {
   if (collective_graph_key != kNoCollectiveGraphKey) {
     strings::StrAppend(&rv, "\ncollective_graph_key: ", collective_graph_key);
   }
+  string collective_order_str;
+  switch (collective_order) {
+    case GraphCollectiveOrder::kNone:
+      collective_order_str = "none";
+      break;
+    case GraphCollectiveOrder::kEdges:
+      collective_order_str = "edges";
+      break;
+    case GraphCollectiveOrder::kAttrs:
+      collective_order_str = "attrs";
+      break;
+  }
+  strings::StrAppend(&rv, "\ncollective_order: ", collective_order_str);
   return rv;
 }
@@ -18,6 +18,7 @@ limitations under the License.

 #include <vector>

+#include "tensorflow/core/graph/collective_order.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/config.pb.h"

@@ -34,6 +35,11 @@ struct BuildGraphOptions {
   static const int64 kNoCollectiveGraphKey = 0;
   int64 collective_graph_key = kNoCollectiveGraphKey;

+  // If not `kNone`, order all CollectiveReduce operations statically and
+  // deterministically.  If `kEdges`, encode dependencies as explicit control
+  // edges, if `kAttrs` encode as attribute on collective op.
+  GraphCollectiveOrder collective_order = GraphCollectiveOrder::kNone;
+
   string DebugString() const;
 };
@@ -45,7 +45,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/collective_order.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_partition.h"
@@ -1192,6 +1191,10 @@ Status DirectSession::CreateExecutors(
   options.use_function_convention = !run_state_args->is_partial_run;
   options.collective_graph_key =
       callable_options.run_options().experimental().collective_graph_key();
+  if (options_.config.experimental()
+          .collective_deterministic_sequential_execution()) {
+    options.collective_order = GraphCollectiveOrder::kEdges;
+  }

   std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
   std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1523,12 +1526,6 @@ Status DirectSession::CreateGraphs(
     CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
   }

-  // Make collective execution order deterministic if needed.
-  if (options_.config.experimental()
-          .collective_deterministic_sequential_execution()) {
-    TF_RETURN_IF_ERROR(OrderCollectives(&client_graph->graph));
-  }
-
   // Partition the graph across devices.
   PartitionOptions popts;
   popts.node_to_loc = [](const Node* node) {
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/collective_order.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/subgraph.h"
@@ -819,6 +820,12 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
     }
   }

+  // Make collective execution order deterministic if needed.
+  if (options.collective_order != GraphCollectiveOrder::kNone) {
+    TF_RETURN_IF_ERROR(
+        OrderCollectives(optimized_graph.get(), options.collective_order));
+  }
+
   // Copy the extracted graph in order to make its node ids dense,
   // since the local CostModel used to record its stats is sized by
   // the largest node id.
@@ -292,8 +292,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
     if (tot >= 0.1 * 1048576.0) {
       bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
     }
-    return strings::StrCat(bytes, stats.node_name(), " = ",
-                           details.type_string, details.detail_text);
+    return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
+                           details.detail_text);
   }

   // Send/Recv nodes that are the result of client-added
@@ -1081,16 +1081,17 @@ void CopyAndSortStrings(size_t size,
 }  // namespace

 void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
+                            const ConfigProto& config,
                             BuildGraphOptions* opts) {
   CallableOptions* callable_opts = &opts->callable_options;
-  CopyAndSortStrings(req.num_feeds(),
-                     [&req](size_t i) { return req.feed_name(i); },
+  CopyAndSortStrings(
+      req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
       callable_opts->mutable_feed());
-  CopyAndSortStrings(req.num_fetches(),
-                     [&req](size_t i) { return req.fetch_name(i); },
+  CopyAndSortStrings(
+      req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
       callable_opts->mutable_fetch());
-  CopyAndSortStrings(req.num_targets(),
-                     [&req](size_t i) { return req.target_name(i); },
+  CopyAndSortStrings(
+      req.num_targets(), [&req](size_t i) { return req.target_name(i); },
       callable_opts->mutable_target());

   if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
@@ -1100,18 +1101,22 @@ void BuildBuildGraphOptions(const RunStepRequestWrapper& req,

   opts->collective_graph_key =
       req.options().experimental().collective_graph_key();
+  if (config.experimental().collective_deterministic_sequential_execution()) {
+    opts->collective_order = GraphCollectiveOrder::kEdges;
+  }
 }

 void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
                             BuildGraphOptions* opts) {
   CallableOptions* callable_opts = &opts->callable_options;
-  CopyAndSortStrings(req.feed_size(), [&req](size_t i) { return req.feed(i); },
+  CopyAndSortStrings(
+      req.feed_size(), [&req](size_t i) { return req.feed(i); },
       callable_opts->mutable_feed());
-  CopyAndSortStrings(req.fetch_size(),
-                     [&req](size_t i) { return req.fetch(i); },
+  CopyAndSortStrings(
+      req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
       callable_opts->mutable_fetch());
-  CopyAndSortStrings(req.target_size(),
-                     [&req](size_t i) { return req.target(i); },
+  CopyAndSortStrings(
+      req.target_size(), [&req](size_t i) { return req.target(i); },
       callable_opts->mutable_target());

   // TODO(cais): Add TFDBG support to partial runs.
@@ -1852,7 +1857,7 @@ Status MasterSession::DoRunWithLocalExecution(

   // Prepare.
   BuildGraphOptions bgopts;
-  BuildBuildGraphOptions(req, &bgopts);
+  BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
   ReffedClientGraph* rcg = nullptr;
   int64 count;
   TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
@@ -14,55 +14,70 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/graph/collective_order.h"

+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/graph/algorithm.h"

 namespace tensorflow {
+namespace {

-Status OrderCollectives(Graph* graph) {
-  // `instance_keys[i]` corresponds to `collective_nodes[i]`
-  std::vector<Node*> collective_nodes;
-  std::vector<int32> instance_keys;
-  // node -> set of collectives on which node depends.
-  std::unordered_map<Node*, std::unordered_set<int32>> node_dependencies;
+// Find all CollectiveReduce nodes and the existing data dependencies between
+// them.
+Status DiscoverDataDependencies(
+    const Graph* graph, std::vector<Node*>* collective_nodes,
+    std::vector<int32>* instance_keys,
+    absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies) {
   Status s;

   // Algorithm: do Reverse DFS starting at sink.  `node_leave` is called when
-  // all parents of `node` have been visited.  At that point, the collectives
-  // on which this node depends on are up to date.  For this node's children,
-  // add all these collectives.  Also, if this node is collective, add as a
-  // dependency for the children.
-  auto node_leave = [&collective_nodes, &instance_keys, &node_dependencies,
+  // all parents of `node` have been visited.  At that point,
+  // `data_dependencies[node]` is a list containing `instance_key` of every
+  // `CollectiveReduce` on which `node` has a data dependency.
+  // For this node's children, add all these instance keys.  Also, if this node
+  // is collective, add as a dependency for the children.
+  auto node_leave = [collective_nodes, instance_keys, data_dependencies,
                      &s](Node* node) {
     int32 instance_key;
-    if (node->IsCollective()) {
+    bool enter_node =
+        node->IsCollective() && node->type_string() == "CollectiveReduce";
+    if (enter_node) {
       Status get_attr_status =
           GetNodeAttr(node->attrs(), "instance_key", &instance_key);
       s.Update(get_attr_status);
-      collective_nodes.push_back(node);
-      instance_keys.push_back(instance_key);
+      collective_nodes->push_back(node);
+      instance_keys->push_back(instance_key);
       VLOG(2) << "collective node " << node->DebugString();
     }
-    const auto& node_deps = node_dependencies[node];
+    const auto& node_deps = (*data_dependencies)[node];
     for (const Edge* out_edge : node->out_edges()) {
-      auto& child_deps = node_dependencies[out_edge->dst()];
+      auto& child_deps = (*data_dependencies)[out_edge->dst()];
       child_deps.insert(node_deps.begin(), node_deps.end());
-      if (node->IsCollective() && s.ok()) {
+      if (enter_node && s.ok()) {
        child_deps.insert(instance_key);
      }
    }
  };
  ReverseDFS(*graph, nullptr, node_leave);
-  if (!s.ok()) return s;
+  return s;
+}

-  // For all pairs of collective nodes n1 and n2 on the same device, if n1 does
-  // not depend on n2 and n2 does not depend on n1, then they are potentially
-  // concurrent.  Add an arbitrary, deterministic control edge between them.
+// Given a list of `collective_nodes` and `data_dependencies` between the
+// collective nodes, create control dependencies between concurrent collectives
+// and store in `dependency_edges`.
+// If there exists an edge a -> b then `dependency_edges[a]` contains `b`
+Status CreateControlDependencies(
+    const std::vector<Node*>& collective_nodes,
+    const std::vector<int32>& instance_keys,
+    absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies,
+    absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>* dependency_edges) {
+  // If there exists some path a -> ... -> b then `all_paths[a]` contains `b`
+  absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> all_paths;
   for (int i = 0; i < collective_nodes.size() - 1; i++) {
-    if (!collective_nodes[i]->IsCollective()) {
+    if (!collective_nodes[i]->IsCollective() ||
+        collective_nodes[i]->type_string() != "CollectiveReduce") {
       return errors::Internal("Unexpected node ",
                               collective_nodes[i]->DebugString());
     }
-    const auto& deps_i = node_dependencies[collective_nodes[i]];
+    const auto& deps_i = (*data_dependencies)[collective_nodes[i]];
     for (int j = i + 1; j < collective_nodes.size(); j++) {
       if (collective_nodes[i]->requested_device() !=
           collective_nodes[j]->requested_device()) {
@@ -74,17 +89,48 @@ Status OrderCollectives(Graph* graph) {
                                 " on 2 nodes with the same device ",
                                 collective_nodes[i]->requested_device());
       }
-      const auto& deps_j = node_dependencies[collective_nodes[j]];
+      const auto& deps_j = (*data_dependencies)[collective_nodes[j]];
       if (deps_i.find(instance_keys[j]) == deps_i.end() &&
           deps_j.find(instance_keys[i]) == deps_j.end()) {
         int src_idx = instance_keys[i] < instance_keys[j] ? i : j;
         int dst_idx = instance_keys[i] < instance_keys[j] ? j : i;
         Node* src_node = collective_nodes[src_idx];
         Node* dst_node = collective_nodes[dst_idx];
-        VLOG(1) << "Adding control edge from node " << src_node->name()
+        VLOG(1) << "Adding control dependency from node " << src_node->name()
                 << " instance " << instance_keys[src_idx] << " to node "
                 << dst_node->name() << " instance " << instance_keys[dst_idx];
-        graph->AddControlEdge(src_node, dst_node);
+        (*dependency_edges)[src_node].insert(dst_node);
+        auto& src_paths = all_paths[src_node];
+        src_paths.insert(dst_node);
+        for (Node* downstream_node : all_paths[dst_node]) {
+          src_paths.insert(downstream_node);
+        }
       }
     }
   }
+
+  // Prune dependency edges so that if there are edges a -> b, b -> c, and a ->
+  // c, then remove a -> c.  This dependency would be handled naturally during
+  // op scheduling.
+  for (int i = 0; i < collective_nodes.size(); ++i) {
+    Node* node = collective_nodes[i];
+    auto& neighbor_set = (*dependency_edges)[node];
+    std::vector<Node*> neighbor_list(neighbor_set.begin(), neighbor_set.end());
+    // For all n1, n2 in `neighbor_list` if there is a path from n1 -> n2 then
+    // eliminate n2 from `neighbor_set` and `neighbor_list`.  We remove from
+    // `neighbor_list` by replacing with a `nullptr`, hence the `nullptr` checks
+    // below.
+    for (int j = 0; j < neighbor_list.size(); ++j) {
+      Node* n1 = neighbor_list[j];
+      if (n1 == nullptr) continue;
+      auto& n1_paths = all_paths[n1];
+      for (int k = 0; k < neighbor_list.size(); ++k) {
+        Node* n2 = neighbor_list[k];
+        if (j == k || n2 == nullptr) continue;
+        if (n1_paths.find(n2) != n1_paths.end()) {
+          neighbor_set.erase(n2);
+          neighbor_list[k] = nullptr;
+        }
+      }
+    }
+  }
@@ -92,4 +138,65 @@ Status OrderCollectives(Graph* graph) {
   return Status::OK();
 }

+// Insert control dependencies defined by `dependency_edges` in `graph`.  If
+// `order_type` is `kEdges`, insert explicit control edges, else if `order_type`
+// is `kAttrs`, encode dependencies as an attribute on collective node.
+Status InsertControlDependencies(
+    Graph* graph, GraphCollectiveOrder order_type,
+    const absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>&
+        dependency_edges) {
+  if (order_type == GraphCollectiveOrder::kEdges) {
+    for (const auto& pair : dependency_edges) {
+      Node* src_node = pair.first;
+      for (Node* dst_node : pair.second) {
+        graph->AddControlEdge(src_node, dst_node);
+      }
+    }
+  } else if (order_type == GraphCollectiveOrder::kAttrs) {
+    // `wait_for` is the inverse of `dependency_edges`, i.e. `wait_for[node]`
+    // contains the list of instance keys for which `node` must wait.
+    absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> wait_for;
+    for (const auto& pair : dependency_edges) {
+      int32 src_instance;
+      TF_RETURN_IF_ERROR(
+          GetNodeAttr(pair.first->attrs(), "instance_key", &src_instance));
+      for (Node* dst_node : pair.second) {
+        wait_for[dst_node].insert(src_instance);
+      }
+    }
+    for (const auto& pair : wait_for) {
+      std::vector<int32> wait_for_list(pair.second.begin(), pair.second.end());
+      pair.first->ClearAttr("wait_for");
+      pair.first->AddAttr("wait_for", wait_for_list);
+    }
+  } else {
+    return errors::Internal("Unexpected GraphCollectiveOrder type ",
+                            static_cast<int>(order_type));
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type) {
+  // `instance_keys[i]` corresponds to `collective_nodes[i]`
+  std::vector<Node*> collective_nodes;
+  std::vector<int32> instance_keys;
+  // node -> set of collectives on which node depends.
+  absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> data_dependencies;
+  TF_RETURN_IF_ERROR(DiscoverDataDependencies(
+      graph, &collective_nodes, &instance_keys, &data_dependencies));
+
+  if (collective_nodes.empty()) return Status::OK();
+
+  absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> dependency_edges;
+  // For all pairs of collective nodes n1 and n2 on the same device, if n1 does
+  // not depend on n2 and n2 does not depend on n1, then they are potentially
+  // concurrent.  Create an arbitrary, deterministic ordering between them.
+  TF_RETURN_IF_ERROR(CreateControlDependencies(
+      collective_nodes, instance_keys, &data_dependencies, &dependency_edges));
+
+  return InsertControlDependencies(graph, order_type, dependency_edges);
+}
+
 }  // namespace tensorflow
@@ -19,11 +19,17 @@ limitations under the License.

 namespace tensorflow {

-// Introduces control edges between potentially concurrent CollectiveOps to make
-// their execution order deterministic.  This may be used to execute collectives
-// in the same order across all workers in a distributed execution, if all
-// workers are executing the same graph.
-Status OrderCollectives(Graph* graph);
+enum class GraphCollectiveOrder { kNone, kEdges, kAttrs };
+
+// Introduces a deterministic execution order between potentially concurrent
+// CollectiveOps.  This may be used to execute collectives in the same order
+// across all workers in a distributed execution, if all workers are executing
+// the same graph.
+// If `order_type` is `kEdges`, introduce the ordering in the form of explicit
+// control edges between collective graph nodes.  If `order_type` is `kAttrs`,
+// add an attribute to the node which may be used by collective executor to
+// ensure the required ordering.
+Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type);

 }  // namespace tensorflow
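As a minimal usage sketch of the new entry point (the wrapper name below is hypothetical; the unit tests that follow exercise the same call directly):

// Hedged sketch, assuming the caller already holds a mutable Graph.
#include "tensorflow/core/graph/collective_order.h"

namespace tensorflow {

Status MakeCollectivesDeterministic(Graph* graph, bool use_attrs) {
  // kAttrs records the ordering as a `wait_for` attribute on the downstream
  // CollectiveReduce node; kEdges adds explicit control edges instead.
  return OrderCollectives(graph, use_attrs ? GraphCollectiveOrder::kAttrs
                                           : GraphCollectiveOrder::kEdges);
}

}  // namespace tensorflow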
@@ -59,6 +59,23 @@ void VerifyGraph(const Graph& graph,
               UnorderedElementsAreArray(expected_collective_control_edges));
 }

+// Verifies that the `wait_for` attribute on collective nodes matches
+// `wait_for_map`.
+void VerifyAttrs(
+    const Graph& graph,
+    const std::unordered_map<string, std::vector<int32>> wait_for_map) {
+  for (const Node* node : graph.nodes()) {
+    if (node->IsCollective() ||
+        wait_for_map.find(node->name()) == wait_for_map.end()) {
+      continue;
+    }
+    std::vector<int32> wait_for_actual;
+    TF_EXPECT_OK(GetNodeAttr(node->attrs(), "wait_for", &wait_for_actual));
+    auto wait_for_expected = wait_for_map.at(node->name());
+    EXPECT_THAT(wait_for_actual, UnorderedElementsAreArray(wait_for_expected));
+  }
+}
+
 Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
                            const string& name, const string& device,
                            int instance_key) {
@@ -123,11 +140,17 @@ std::unique_ptr<Graph> InitGraph() {
 // added after calling `OrderCollectives`: c2_0 -> c3_0 and c2_1 -> c3_1.
 TEST(CollectiveOrderTest, SimpleOrder) {
   std::unique_ptr<Graph> graph = InitGraph();
-  TF_EXPECT_OK(OrderCollectives(graph.get()));
+  TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
   VerifyGraph(*graph, {"c1_0", "c1_1", "c2_0", "c2_1", "c3_0", "c3_1"},
               {{"c2_0", "c3_0"}, {"c2_1", "c3_1"}});
 }

+TEST(CollectiveOrderTest, SimpleOrderAttr) {
+  std::unique_ptr<Graph> graph = InitGraph();
+  TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
+  VerifyAttrs(*graph, {{"c3_0", {2}}, {"c3_1", {2}}});
+}
+
 // Initialize the following graph:
 //
 //         a
@@ -162,12 +185,50 @@ std::unique_ptr<Graph> InitGraph2() {
 }

 // Tests that in the graph created by `InitGraph2`, we add the following control
-// edges after calling `OrderCollectives`: c2 -> c3, c3 -> c4, and c2 -> c4.
+// edges after calling `OrderCollectives`: c2 -> c3, c3 -> c4.  c2 -> c4 is
+// pruned because it follows from the other two edges.
 TEST(CollectiveOrderTest, SimpleOrder2) {
   std::unique_ptr<Graph> graph = InitGraph2();
-  TF_EXPECT_OK(OrderCollectives(graph.get()));
-  VerifyGraph(*graph, {"c1", "c2", "c3", "c4"},
-              {{"c2", "c3"}, {"c3", "c4"}, {"c2", "c4"}});
+  TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
+  VerifyGraph(*graph, {"c1", "c2", "c3", "c4"}, {{"c2", "c3"}, {"c3", "c4"}});
 }

+// Initialize the following graph:
+//
+//   w   x   y   z
+//   |   |   |   |
+//  c1  c2  c3  c4
+//
+std::unique_ptr<Graph> InitGraphForPruning() {
+  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+  const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
+  Node* w = ops::SourceOp("TestParams",
+                          builder.opts().WithName("w").WithDevice(dev0));
+  Node* x = ops::SourceOp("TestParams",
+                          builder.opts().WithName("x").WithDevice(dev0));
+  Node* y = ops::SourceOp("TestParams",
+                          builder.opts().WithName("y").WithDevice(dev0));
+  Node* z = ops::SourceOp("TestParams",
+                          builder.opts().WithName("z").WithDevice(dev0));
+  CollectiveReduceNode(&builder, w, "c1", dev0, 1);
+  CollectiveReduceNode(&builder, x, "c2", dev0, 2);
+  CollectiveReduceNode(&builder, y, "c3", dev0, 3);
+  CollectiveReduceNode(&builder, z, "c4", dev0, 4);
+
+  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+  Status s = GraphDefBuilderToGraph(builder, graph.get());
+  if (!s.ok()) {
+    LOG(FATAL) << "Error building graph " << s;
+  }
+  return graph;
+}
+
+// Tests that in the graph created by `InitGraphForPruning`, we only add c1 ->
+// c2, c2 -> c3, c3 -> c4, and other edges are pruned away.
+TEST(CollectiveOrderTest, Pruning) {
+  std::unique_ptr<Graph> graph = InitGraphForPruning();
+  TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
+  VerifyAttrs(*graph, {{"c4", {3}}, {"c3", {2}}, {"c2", {1}}});
+}
+
 }  // namespace
@ -28,6 +28,7 @@ REGISTER_OP("CollectiveReduce")
|
||||
.Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
|
||||
.Attr("final_op: {'Id', 'Div'}")
|
||||
.Attr("subdiv_offsets: list(int)")
|
||||
.Attr("wait_for: list(int) = []")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
|
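The registration above only adds the `wait_for` attribute with an empty default; parsing it is left to future changes, as the commit message notes. A hedged sketch of how executor-side code might eventually read it back (the class and member names are hypothetical, not from this change):

// Hedged sketch: reading the ordering attribute in an OpKernel constructor.
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

class WaitForAwareKernel : public OpKernel {
 public:
  explicit WaitForAwareKernel(OpKernelConstruction* c) : OpKernel(c) {
    // `wait_for` defaults to [] per the registration above, so this succeeds
    // even when no ordering was encoded into the graph.
    OP_REQUIRES_OK(c, c->GetAttr("wait_for", &wait_for_));
  }
  void Compute(OpKernelContext* ctx) override {}  // sketch only

 private:
  std::vector<int32> wait_for_;  // instance_keys this op would wait for
};

}  // namespace tensorflow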