Enable a deterministic and sequential execution order of potentially concurrent
collective instances. Before this change, collective op instances that were concurrent could execute in a random, overlapping order across devices and workers. This change introduces a mechanism to enforce the same execution order of collective instances across devices and workers. It first determines dependencies between collective op nodes in the graph. For each pair of collective instances that do not depend on each other, this change adds an arbitrary but deterministic control edge between the nodes. This ordering only kicks in if `collective_deterministic_sequential_execution` is true in ConfigProto. PiperOrigin-RevId: 226397542
This commit is contained in:
parent
795f9cbf91
commit
d5216948d1
@ -2772,6 +2772,7 @@ cc_library(
|
|||||||
# in this library.
|
# in this library.
|
||||||
GRAPH_HDRS = [
|
GRAPH_HDRS = [
|
||||||
"graph/algorithm.h",
|
"graph/algorithm.h",
|
||||||
|
"graph/collective_order.h",
|
||||||
"graph/colors.h",
|
"graph/colors.h",
|
||||||
"graph/control_flow.h",
|
"graph/control_flow.h",
|
||||||
"graph/costmodel.h",
|
"graph/costmodel.h",
|
||||||
@ -2798,6 +2799,7 @@ tf_cuda_library(
|
|||||||
name = "graph",
|
name = "graph",
|
||||||
srcs = [
|
srcs = [
|
||||||
"graph/algorithm.cc",
|
"graph/algorithm.cc",
|
||||||
|
"graph/collective_order.cc",
|
||||||
"graph/colors.cc",
|
"graph/colors.cc",
|
||||||
"graph/control_flow.cc",
|
"graph/control_flow.cc",
|
||||||
"graph/costmodel.cc",
|
"graph/costmodel.cc",
|
||||||
@ -3842,6 +3844,27 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_tests(
|
||||||
|
name = "collective_order_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"graph/collective_order_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":core",
|
||||||
|
":core_cpu",
|
||||||
|
":core_cpu_internal",
|
||||||
|
":framework",
|
||||||
|
":framework_internal",
|
||||||
|
":lib",
|
||||||
|
":lib_internal",
|
||||||
|
":ops",
|
||||||
|
":protos_all_cc",
|
||||||
|
":test",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_tests_gpu(
|
tf_cc_tests_gpu(
|
||||||
name = "ring_reducer_test",
|
name = "ring_reducer_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
@ -45,6 +45,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
|
#include "tensorflow/core/graph/collective_order.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/graph/graph_partition.h"
|
#include "tensorflow/core/graph/graph_partition.h"
|
||||||
@ -1519,6 +1520,12 @@ Status DirectSession::CreateGraphs(
|
|||||||
CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
|
CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make collective execution order deterministic if needed.
|
||||||
|
if (options_.config.experimental()
|
||||||
|
.collective_deterministic_sequential_execution()) {
|
||||||
|
TF_RETURN_IF_ERROR(OrderCollectives(&client_graph->graph));
|
||||||
|
}
|
||||||
|
|
||||||
// Partition the graph across devices.
|
// Partition the graph across devices.
|
||||||
PartitionOptions popts;
|
PartitionOptions popts;
|
||||||
popts.node_to_loc = [](const Node* node) {
|
popts.node_to_loc = [](const Node* node) {
|
||||||
|
95
tensorflow/core/graph/collective_order.cc
Normal file
95
tensorflow/core/graph/collective_order.cc
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/graph/collective_order.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
Status OrderCollectives(Graph* graph) {
|
||||||
|
// `instance_keys[i]` corresponds to `collective_nodes[i]`
|
||||||
|
std::vector<Node*> collective_nodes;
|
||||||
|
std::vector<int32> instance_keys;
|
||||||
|
// node -> set of collectives on which node depends.
|
||||||
|
std::unordered_map<Node*, std::unordered_set<int32>> node_dependencies;
|
||||||
|
Status s;
|
||||||
|
|
||||||
|
// Algorithm: do Reverse DFS starting at sink. `node_leave` is called when
|
||||||
|
// all parents of `node` have been visited. At that point, the collectives
|
||||||
|
// on which this node depends on are up to date. For this node's children,
|
||||||
|
// add all these collectives. Also, if this node is collective, add as a
|
||||||
|
// dependency for the children.
|
||||||
|
auto node_leave = [&collective_nodes, &instance_keys, &node_dependencies,
|
||||||
|
&s](Node* node) {
|
||||||
|
int32 instance_key;
|
||||||
|
if (node->IsCollective()) {
|
||||||
|
Status get_attr_status =
|
||||||
|
GetNodeAttr(node->attrs(), "instance_key", &instance_key);
|
||||||
|
s.Update(get_attr_status);
|
||||||
|
collective_nodes.push_back(node);
|
||||||
|
instance_keys.push_back(instance_key);
|
||||||
|
VLOG(2) << "collective node " << node->DebugString();
|
||||||
|
}
|
||||||
|
const auto& node_deps = node_dependencies[node];
|
||||||
|
for (const Edge* out_edge : node->out_edges()) {
|
||||||
|
auto& child_deps = node_dependencies[out_edge->dst()];
|
||||||
|
child_deps.insert(node_deps.begin(), node_deps.end());
|
||||||
|
if (node->IsCollective() && s.ok()) {
|
||||||
|
child_deps.insert(instance_key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
ReverseDFS(*graph, nullptr, node_leave);
|
||||||
|
if (!s.ok()) return s;
|
||||||
|
|
||||||
|
// For all pairs of collective nodes n1 and n2 on the same device, if n1 does
|
||||||
|
// not depend on n2 and n2 does not depend on n1, then they are potentially
|
||||||
|
// concurrent. Add an arbitrary, deterministic control edge between them.
|
||||||
|
for (int i = 0; i < collective_nodes.size() - 1; i++) {
|
||||||
|
if (!collective_nodes[i]->IsCollective()) {
|
||||||
|
return errors::Internal("Unexpected node ",
|
||||||
|
collective_nodes[i]->DebugString());
|
||||||
|
}
|
||||||
|
const auto& deps_i = node_dependencies[collective_nodes[i]];
|
||||||
|
for (int j = i + 1; j < collective_nodes.size(); j++) {
|
||||||
|
if (collective_nodes[i]->requested_device() !=
|
||||||
|
collective_nodes[j]->requested_device()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (instance_keys[i] == instance_keys[j]) {
|
||||||
|
return errors::Internal("Unexpected same instance_key ",
|
||||||
|
instance_keys[i],
|
||||||
|
" on 2 nodes with the same device ",
|
||||||
|
collective_nodes[i]->requested_device());
|
||||||
|
}
|
||||||
|
const auto& deps_j = node_dependencies[collective_nodes[j]];
|
||||||
|
if (deps_i.find(instance_keys[j]) == deps_i.end() &&
|
||||||
|
deps_j.find(instance_keys[i]) == deps_j.end()) {
|
||||||
|
int src_idx = instance_keys[i] < instance_keys[j] ? i : j;
|
||||||
|
int dst_idx = instance_keys[i] < instance_keys[j] ? j : i;
|
||||||
|
Node* src_node = collective_nodes[src_idx];
|
||||||
|
Node* dst_node = collective_nodes[dst_idx];
|
||||||
|
VLOG(1) << "Adding control edge from node " << src_node->name()
|
||||||
|
<< " instance " << instance_keys[src_idx] << " to node "
|
||||||
|
<< dst_node->name() << " instance " << instance_keys[dst_idx];
|
||||||
|
graph->AddControlEdge(src_node, dst_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
30
tensorflow/core/graph/collective_order.h
Normal file
30
tensorflow/core/graph/collective_order.h
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_GRAPH_COLLECTIVE_ORDER_H_
|
||||||
|
#define TENSORFLOW_CORE_GRAPH_COLLECTIVE_ORDER_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Introduces control edges between potentially concurrent CollectiveOps to make
|
||||||
|
// their execution order deterministic. This may be used to execute collectives
|
||||||
|
// in the same order across all workers in a distributed execution, if all
|
||||||
|
// workers are executing the same graph.
|
||||||
|
Status OrderCollectives(Graph* graph);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_GRAPH_COLLECTIVE_ORDER_H_
|
174
tensorflow/core/graph/collective_order_test.cc
Normal file
174
tensorflow/core/graph/collective_order_test.cc
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/graph/collective_order.h"
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||||
|
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::UnorderedElementsAreArray;
|
||||||
|
|
||||||
|
REGISTER_OP("TestParams").Output("o: float");
|
||||||
|
|
||||||
|
// Verifies that the list of collective nodes in `graph` matches
|
||||||
|
// `expected_collective_nodes`, and that the list of control edges between these
|
||||||
|
// collective nodes matches `expected_collective_control_edges`.
|
||||||
|
void VerifyGraph(const Graph& graph,
|
||||||
|
const std::vector<string>& expected_collective_nodes,
|
||||||
|
const std::vector<std::pair<string, string>>&
|
||||||
|
expected_collective_control_edges) {
|
||||||
|
std::vector<string> actual_collective_nodes;
|
||||||
|
std::vector<std::pair<string, string>> actual_collective_control_edges;
|
||||||
|
for (const Node* src : graph.nodes()) {
|
||||||
|
if (!src->IsCollective()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
actual_collective_nodes.push_back(src->name());
|
||||||
|
for (const Edge* edge : src->out_edges()) {
|
||||||
|
VLOG(2) << "collective edge " << edge->src()->name() << " -> "
|
||||||
|
<< edge->dst()->name();
|
||||||
|
// Add all control edges found except those to `_SINK`.
|
||||||
|
if (!edge->IsControlEdge() || edge->dst()->name() == "_SINK") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
actual_collective_control_edges.emplace_back(src->name(),
|
||||||
|
edge->dst()->name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_THAT(actual_collective_nodes,
|
||||||
|
UnorderedElementsAreArray(expected_collective_nodes));
|
||||||
|
EXPECT_THAT(actual_collective_control_edges,
|
||||||
|
UnorderedElementsAreArray(expected_collective_control_edges));
|
||||||
|
}
|
||||||
|
|
||||||
|
Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
|
||||||
|
const string& name, const string& device,
|
||||||
|
int instance_key) {
|
||||||
|
Node* collective_node =
|
||||||
|
ops::UnaryOp("CollectiveReduce", input,
|
||||||
|
builder->opts()
|
||||||
|
.WithName(name)
|
||||||
|
.WithDevice(device)
|
||||||
|
.WithAttr("T", DT_FLOAT)
|
||||||
|
.WithAttr("group_size", 2)
|
||||||
|
.WithAttr("group_key", 1)
|
||||||
|
.WithAttr("instance_key", instance_key)
|
||||||
|
.WithAttr("merge_op", "Add")
|
||||||
|
.WithAttr("final_op", "Id")
|
||||||
|
.WithAttr("subdiv_offsets", {1}));
|
||||||
|
return collective_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the following graph:
|
||||||
|
//
|
||||||
|
// (cpu0) (cpu1)
|
||||||
|
// a b
|
||||||
|
// | |
|
||||||
|
// c1 c1
|
||||||
|
// | |
|
||||||
|
// id id
|
||||||
|
// / \ / \
|
||||||
|
// c2 c3 c2 c3
|
||||||
|
//
|
||||||
|
// Here ci denotes a collective node with `instance_key` i. `a` and `b` are
|
||||||
|
// inputs, `id` is identity node.
|
||||||
|
std::unique_ptr<Graph> InitGraph() {
|
||||||
|
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||||
|
const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||||
|
const string dev1 = "/job:localhost/replica:0/task:0/device:CPU:1";
|
||||||
|
Node* a = ops::SourceOp("TestParams",
|
||||||
|
builder.opts().WithName("a").WithDevice(dev0));
|
||||||
|
Node* b = ops::SourceOp("TestParams",
|
||||||
|
builder.opts().WithName("b").WithDevice(dev1));
|
||||||
|
Node* c1_0 = CollectiveReduceNode(&builder, a, "c1_0", dev0, 1);
|
||||||
|
Node* c1_1 = CollectiveReduceNode(&builder, b, "c1_1", dev1, 1);
|
||||||
|
Node* id0 = ops::UnaryOp(
|
||||||
|
"Identity", c1_0,
|
||||||
|
builder.opts().WithName("id0").WithDevice(dev0).WithAttr("T", DT_FLOAT));
|
||||||
|
Node* id1 = ops::UnaryOp(
|
||||||
|
"Identity", c1_1,
|
||||||
|
builder.opts().WithName("id1").WithDevice(dev1).WithAttr("T", DT_FLOAT));
|
||||||
|
CollectiveReduceNode(&builder, id0, "c2_0", dev0, 2);
|
||||||
|
CollectiveReduceNode(&builder, id1, "c2_1", dev1, 2);
|
||||||
|
CollectiveReduceNode(&builder, id0, "c3_0", dev0, 3);
|
||||||
|
CollectiveReduceNode(&builder, id1, "c3_1", dev1, 3);
|
||||||
|
|
||||||
|
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||||
|
Status s = GraphDefBuilderToGraph(builder, graph.get());
|
||||||
|
if (!s.ok()) {
|
||||||
|
LOG(FATAL) << "Error building graph " << s;
|
||||||
|
}
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that in the graph created by `InitGraph`, exactly 2 control edges are
|
||||||
|
// added after calling `OrderCollectives`: c2_0 -> c3_0 and c2_1 -> c3_1.
|
||||||
|
TEST(CollectiveOrderTest, SimpleOrder) {
|
||||||
|
std::unique_ptr<Graph> graph = InitGraph();
|
||||||
|
TF_EXPECT_OK(OrderCollectives(graph.get()));
|
||||||
|
VerifyGraph(*graph, {"c1_0", "c1_1", "c2_0", "c2_1", "c3_0", "c3_1"},
|
||||||
|
{{"c2_0", "c3_0"}, {"c2_1", "c3_1"}});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the following graph:
|
||||||
|
//
|
||||||
|
// a
|
||||||
|
// |
|
||||||
|
// c1
|
||||||
|
// / \
|
||||||
|
// c4 id
|
||||||
|
// / \
|
||||||
|
// c2 c3
|
||||||
|
//
|
||||||
|
// Here ci denotes a collective node with `instance_key` i. `a` is an input,
|
||||||
|
// `id` is identity node.
|
||||||
|
std::unique_ptr<Graph> InitGraph2() {
|
||||||
|
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||||
|
const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||||
|
Node* a = ops::SourceOp("TestParams",
|
||||||
|
builder.opts().WithName("a").WithDevice(dev0));
|
||||||
|
Node* c1 = CollectiveReduceNode(&builder, a, "c1", dev0, 1);
|
||||||
|
CollectiveReduceNode(&builder, c1, "c4", dev0, 4);
|
||||||
|
Node* id = ops::UnaryOp(
|
||||||
|
"Identity", c1,
|
||||||
|
builder.opts().WithName("id").WithDevice(dev0).WithAttr("T", DT_FLOAT));
|
||||||
|
CollectiveReduceNode(&builder, id, "c2", dev0, 2);
|
||||||
|
CollectiveReduceNode(&builder, id, "c3", dev0, 3);
|
||||||
|
|
||||||
|
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||||
|
Status s = GraphDefBuilderToGraph(builder, graph.get());
|
||||||
|
if (!s.ok()) {
|
||||||
|
LOG(FATAL) << "Error building graph " << s;
|
||||||
|
}
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that in the graph created by `InitGraph2`, we add the following control
|
||||||
|
// edges after calling `OrderCollectives`: c2 -> c3, c3 -> c4, and c2 -> c4.
|
||||||
|
TEST(CollectiveOrderTest, SimpleOrder2) {
|
||||||
|
std::unique_ptr<Graph> graph = InitGraph2();
|
||||||
|
TF_EXPECT_OK(OrderCollectives(graph.get()));
|
||||||
|
VerifyGraph(*graph, {"c1", "c2", "c3", "c4"},
|
||||||
|
{{"c2", "c3"}, {"c3", "c4"}, {"c2", "c4"}});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -425,6 +425,10 @@ message ConfigProto {
|
|||||||
// use NUMA affinity where applicable. One consequence will be the
|
// use NUMA affinity where applicable. One consequence will be the
|
||||||
// existence of as many CPU devices as there are available NUMA nodes.
|
// existence of as many CPU devices as there are available NUMA nodes.
|
||||||
bool use_numa_affinity = 5;
|
bool use_numa_affinity = 5;
|
||||||
|
|
||||||
|
// If true, make collective op execution order sequential and deterministic
|
||||||
|
// for potentially concurrent collective instances.
|
||||||
|
bool collective_deterministic_sequential_execution = 6;
|
||||||
};
|
};
|
||||||
|
|
||||||
Experimental experimental = 16;
|
Experimental experimental = 16;
|
||||||
|
@ -50,6 +50,24 @@ class CollectiveOpTest(test.TestCase):
|
|||||||
self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
|
self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
|
||||||
self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
|
self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
def _testMultipleConcurrentCollectiveReduce(self, t0, t1, expected):
|
||||||
|
group_key = 1
|
||||||
|
group_size = 2
|
||||||
|
num_instances = 2
|
||||||
|
all_reduces = []
|
||||||
|
config = config_pb2.ConfigProto(device_count={'CPU': group_size})
|
||||||
|
config.experimental.collective_deterministic_sequential_execution = True
|
||||||
|
with self.session(config=config) as sess:
|
||||||
|
for cpu in range(group_size):
|
||||||
|
with ops.device('/CPU:%d' % cpu):
|
||||||
|
in_tensor = constant_op.constant(t0 if cpu == 0 else t1)
|
||||||
|
for instance in range(num_instances):
|
||||||
|
all_reduces.append(collective_ops.all_reduce(
|
||||||
|
in_tensor, group_size, group_key, instance, 'Add', 'Div'))
|
||||||
|
results = sess.run(all_reduces)
|
||||||
|
for i in range(group_size * num_instances):
|
||||||
|
self.assertAllClose(results[i], expected, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testCollectiveReduce(self):
|
def testCollectiveReduce(self):
|
||||||
self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
|
self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
|
||||||
@ -62,6 +80,13 @@ class CollectiveOpTest(test.TestCase):
|
|||||||
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
|
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
|
||||||
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
|
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testCollectiveMultipleConcurrentReduce(self):
|
||||||
|
self._testMultipleConcurrentCollectiveReduce(
|
||||||
|
[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
|
||||||
|
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
|
||||||
|
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testCollectiveReduceScalar(self):
|
def testCollectiveReduceScalar(self):
|
||||||
self._testCollectiveReduce(0.1, 0.3, 0.2, True)
|
self._testCollectiveReduce(0.1, 0.3, 0.2, True)
|
||||||
|
@ -26,6 +26,12 @@ tf_proto {
|
|||||||
label: LABEL_OPTIONAL
|
label: LABEL_OPTIONAL
|
||||||
type: TYPE_BOOL
|
type: TYPE_BOOL
|
||||||
}
|
}
|
||||||
|
field {
|
||||||
|
name: "collective_deterministic_sequential_execution"
|
||||||
|
number: 6
|
||||||
|
label: LABEL_OPTIONAL
|
||||||
|
type: TYPE_BOOL
|
||||||
|
}
|
||||||
reserved_range {
|
reserved_range {
|
||||||
start: 2
|
start: 2
|
||||||
end: 3
|
end: 3
|
||||||
|
@ -149,6 +149,12 @@ tf_proto {
|
|||||||
label: LABEL_OPTIONAL
|
label: LABEL_OPTIONAL
|
||||||
type: TYPE_BOOL
|
type: TYPE_BOOL
|
||||||
}
|
}
|
||||||
|
field {
|
||||||
|
name: "collective_deterministic_sequential_execution"
|
||||||
|
number: 6
|
||||||
|
label: LABEL_OPTIONAL
|
||||||
|
type: TYPE_BOOL
|
||||||
|
}
|
||||||
reserved_range {
|
reserved_range {
|
||||||
start: 2
|
start: 2
|
||||||
end: 3
|
end: 3
|
||||||
|
Loading…
Reference in New Issue
Block a user