Automated rollback of commit 73a3477356

PiperOrigin-RevId: 211037202
This commit is contained in:
Ayush Dubey 2018-08-30 22:35:08 -07:00 committed by TensorFlower Gardener
parent 348367a88e
commit cb94438312
8 changed files with 24 additions and 215 deletions

View File

@ -4078,7 +4078,6 @@ tf_cuda_cc_test(
":testlib", ":testlib",
"//third_party/eigen3", "//third_party/eigen3",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:dense_update_ops",
@ -4120,7 +4119,6 @@ tf_cc_test(
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
# Link with support for TensorFlow Debugger (tfdbg). # Link with support for TensorFlow Debugger (tfdbg).
"//tensorflow/core/debug", "//tensorflow/core/debug",
"//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:dense_update_ops",

View File

@ -451,22 +451,8 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
RunState run_state(step_id, &devices_); RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
#ifndef __ANDROID__ #ifndef __ANDROID__
// Set up for collectives if ExecutorsAndKeys declares a key. // Set up for collectives if the RunOption declares a key.
if (executors_and_keys->collective_graph_key != if (run_options.experimental().collective_graph_key() > 0) {
BuildGraphOptions::kNoCollectiveGraphKey) {
if (run_options.experimental().collective_graph_key() !=
BuildGraphOptions::kNoCollectiveGraphKey) {
// If a collective_graph_key was specified in run_options, ensure that it
// matches what came out of GraphExecutionState::BuildGraph().
if (run_options.experimental().collective_graph_key() !=
executors_and_keys->collective_graph_key) {
return errors::Internal(
"collective_graph_key in RunOptions ",
run_options.experimental().collective_graph_key(),
" should match collective_graph_key from optimized graph ",
executors_and_keys->collective_graph_key);
}
}
if (!collective_executor_mgr_) { if (!collective_executor_mgr_) {
std::unique_ptr<DeviceResolverInterface> drl( std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get())); new DeviceResolverLocal(device_mgr_.get()));
@ -692,13 +678,10 @@ Status DirectSession::Run(const RunOptions& run_options,
// Check if we already have an executor for these arguments. // Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys; ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args(run_options.debug_options()); RunStateArgs run_state_args(run_options.debug_options());
run_state_args.collective_graph_key =
run_options.experimental().collective_graph_key();
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
target_nodes, &executors_and_keys, target_nodes, &executors_and_keys,
&run_state_args)); &run_state_args));
collective_graph_key_ = executors_and_keys->collective_graph_key;
// Configure a call frame for the step, which we use to feed and // Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors. // fetch values to and from the executors.
@ -1133,8 +1116,6 @@ Status DirectSession::CreateExecutors(
BuildGraphOptions options; BuildGraphOptions options;
options.callable_options = callable_options; options.callable_options = callable_options;
options.use_function_convention = !run_state_args->is_partial_run; options.use_function_convention = !run_state_args->is_partial_run;
options.collective_graph_key =
callable_options.run_options().experimental().collective_graph_key();
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo); std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys); std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@ -1142,9 +1123,9 @@ Status DirectSession::CreateExecutors(
ek->callable_options = callable_options; ek->callable_options = callable_options;
std::unordered_map<string, std::unique_ptr<Graph>> graphs; std::unordered_map<string, std::unique_ptr<Graph>> graphs;
TF_RETURN_IF_ERROR(CreateGraphs( TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types, run_state_args, &ek->input_types,
&ek->output_types, &ek->collective_graph_key)); &ek->output_types));
if (run_state_args->is_partial_run) { if (run_state_args->is_partial_run) {
ek->graph = std::move(run_state_args->graph); ek->graph = std::move(run_state_args->graph);
@ -1372,9 +1353,6 @@ Status DirectSession::GetOrCreateExecutors(
} }
*callable_options.mutable_run_options()->mutable_debug_options() = *callable_options.mutable_run_options()->mutable_debug_options() =
run_state_args->debug_options; run_state_args->debug_options;
callable_options.mutable_run_options()
->mutable_experimental()
->set_collective_graph_key(run_state_args->collective_graph_key);
std::unique_ptr<ExecutorsAndKeys> ek; std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info; std::unique_ptr<FunctionInfo> func_info;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
@ -1401,7 +1379,7 @@ Status DirectSession::CreateGraphs(
std::unordered_map<string, std::unique_ptr<Graph>>* outputs, std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def, std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types, RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types, int64* collective_graph_key) { DataTypeVector* output_types) {
mutex_lock l(graph_def_lock_); mutex_lock l(graph_def_lock_);
std::unique_ptr<ClientGraph> client_graph; std::unique_ptr<ClientGraph> client_graph;
@ -1425,7 +1403,6 @@ Status DirectSession::CreateGraphs(
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
execution_state->BuildGraph(subgraph_options, &client_graph)); execution_state->BuildGraph(subgraph_options, &client_graph));
} }
*collective_graph_key = client_graph->collective_graph_key;
if (subgraph_options.callable_options.feed_size() != if (subgraph_options.callable_options.feed_size() !=
client_graph->feed_types.size()) { client_graph->feed_types.size()) {

View File

@ -117,9 +117,6 @@ class DirectSession : public Session {
::tensorflow::Status ReleaseCallable(CallableHandle handle) override; ::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
private: private:
// For access to collective_graph_key_.
friend class DirectSessionCollectiveTest;
// We create one executor and its dependent library runtime for // We create one executor and its dependent library runtime for
// every partition. // every partition.
struct PerPartitionExecutorsAndLib { struct PerPartitionExecutorsAndLib {
@ -153,8 +150,6 @@ class DirectSession : public Session {
DataTypeVector output_types; DataTypeVector output_types;
CallableOptions callable_options; CallableOptions callable_options;
int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
}; };
// A FunctionInfo object is created for every unique set of feeds/fetches. // A FunctionInfo object is created for every unique set of feeds/fetches.
@ -208,7 +203,6 @@ class DirectSession : public Session {
string handle; string handle;
std::unique_ptr<Graph> graph; std::unique_ptr<Graph> graph;
const DebugOptions& debug_options; const DebugOptions& debug_options;
int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
}; };
// Initializes the base execution state given the 'graph', // Initializes the base execution state given the 'graph',
@ -240,7 +234,7 @@ class DirectSession : public Session {
std::unordered_map<string, std::unique_ptr<Graph>>* outputs, std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def, std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types, RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types, int64* collective_graph_key); DataTypeVector* output_types);
::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options, ::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
CallFrameInterface* call_frame, CallFrameInterface* call_frame,
@ -397,9 +391,6 @@ class DirectSession : public Session {
Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr; Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
// For testing collective graph key generation.
int64 collective_graph_key_ = -1;
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession); TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
// EXPERIMENTAL: debugger (tfdbg) related // EXPERIMENTAL: debugger (tfdbg) related

View File

@ -2218,118 +2218,4 @@ BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10); BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
} // namespace } // namespace
class DirectSessionCollectiveTest : public ::testing::Test {
public:
// Creates a graph with CollectiveOps inside functions and runs it. Returns
// the generated collective_graph_key.
Status RunGraphWithCollectiveFunctions(bool add_unused_function,
int64* collective_graph_key) {
GraphDef g = CreateGraph(add_unused_function);
const Tensor t1 =
test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
const Tensor t2 =
test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
auto session = CreateSession();
TF_RETURN_IF_ERROR(session->Create(g));
std::vector<Tensor> outputs;
TF_RETURN_IF_ERROR(
session->Run({{"input1:0", t1}, {"input2:0", t2}}, {},
{"collective_call1:0", "collective_call2:0"}, &outputs));
DirectSession* direct_session = static_cast<DirectSession*>(session.get());
*collective_graph_key = direct_session->collective_graph_key_;
return Status::OK();
}
private:
// Creates a function with name `function_name` and a single CollectiveReduce
// node with instance key set as `instance_key`.
FunctionDef CollectiveFunction(const string& function_name,
int instance_key) {
return FunctionDefHelper::Define(
// Function name
function_name,
// In def
{"arg:float"},
// Out def
{"reduce:float"},
// Attr def
{},
// Node def
{{
{"reduce"},
"CollectiveReduce",
{"arg"},
{{"group_size", 2},
{"group_key", 1},
{"instance_key", instance_key},
{"subdiv_offsets", gtl::ArraySlice<int32>({0})},
{"merge_op", "Add"},
{"final_op", "Div"},
{"T", DT_FLOAT}},
}});
}
// Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
// CPU1, with instance_key 1, and appropriate placeholder inputs. If
// `add_unused_function` is true, adds another CollectiveFunction with
// instance_key 2 that is not invoked in the graph.
GraphDef CreateGraph(bool add_unused_function) {
GraphDef g;
FunctionDef collective_function =
CollectiveFunction("CollectiveFunction1", 1);
FunctionDefLibrary* lib = g.mutable_library();
*lib->add_function() = collective_function;
if (add_unused_function) {
FunctionDef unused_function =
CollectiveFunction("CollectiveFunction2", 2);
*lib->add_function() = unused_function;
}
// Inputs.
AttrValue dtype_attr;
SetAttrValue(DT_FLOAT, &dtype_attr);
NodeDef input1;
input1.set_name("input1");
input1.set_op("Placeholder");
input1.mutable_attr()->insert({"dtype", dtype_attr});
NodeDef input2;
input2.set_name("input2");
input2.set_op("Placeholder");
input2.mutable_attr()->insert({"dtype", dtype_attr});
// CollectiveReduce on CPU0 with instance_key 1.
NodeDef collective_call1;
collective_call1.set_name("collective_call1");
collective_call1.set_op("CollectiveFunction1");
collective_call1.add_input("input1");
collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0");
// CollectiveReduce on CPU1 with instance_key 1.
NodeDef collective_call2;
collective_call2.set_name("collective_call2");
collective_call2.set_op("CollectiveFunction1");
collective_call2.add_input("input2");
collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
*g.add_node() = input1;
*g.add_node() = input2;
*g.add_node() = collective_call1;
*g.add_node() = collective_call2;
return g;
}
};
#ifndef GOOGLE_CUDA
// TODO(ayushd): enable this test for GPU builds.
TEST_F(DirectSessionCollectiveTest,
TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
int64 key1;
TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key1));
int64 key2;
TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
ASSERT_EQ(key1, key2);
}
#endif
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_execution_state.h" #include "tensorflow/core/common_runtime/graph_execution_state.h"
#include <memory> #include <memory>
#include <set>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
@ -728,50 +727,12 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
int64 collective_graph_key = options.collective_graph_key;
if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
// BuildGraphOptions does not specify a collective_graph_key. Check all
// nodes in the Graph and FunctionLibraryDefinition for collective ops and
// if found, initialize a collective_graph_key as a hash of the ordered set
// of instance keys.
std::set<int32> instance_key_set;
for (Node* node : optimized_graph->nodes()) {
if (node->IsCollective()) {
int32 instance_key;
TF_RETURN_IF_ERROR(
GetNodeAttr(node->attrs(), "instance_key", &instance_key));
instance_key_set.emplace(instance_key);
} else {
const FunctionDef* fdef = optimized_flib->Find(node->def().op());
if (fdef != nullptr) {
for (const NodeDef& ndef : fdef->node_def()) {
if (ndef.op() == "CollectiveReduce" ||
ndef.op() == "CollectiveBcastSend" ||
ndef.op() == "CollectiveBcastRecv") {
int32 instance_key;
TF_RETURN_IF_ERROR(
GetNodeAttr(ndef, "instance_key", &instance_key));
instance_key_set.emplace(instance_key);
}
}
}
}
}
if (!instance_key_set.empty()) {
uint64 hash = 0x8774aa605c729c72ULL;
for (int32 instance_key : instance_key_set) {
hash = Hash64Combine(instance_key, hash);
}
collective_graph_key = hash;
}
}
// Copy the extracted graph in order to make its node ids dense, // Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by // since the local CostModel used to record its stats is sized by
// the largest node id. // the largest node id.
std::unique_ptr<ClientGraph> dense_copy( std::unique_ptr<ClientGraph> dense_copy(
new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types, new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
rewrite_metadata.fetch_types, collective_graph_key)); rewrite_metadata.fetch_types));
CopyGraph(*optimized_graph, &dense_copy->graph); CopyGraph(*optimized_graph, &dense_copy->graph);
// TODO(vrv): We should check invariants of the graph here. // TODO(vrv): We should check invariants of the graph here.

View File

@ -50,20 +50,17 @@ struct GraphExecutionStateOptions {
// BuildGraphOptions. // BuildGraphOptions.
struct ClientGraph { struct ClientGraph {
explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib, explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
DataTypeVector feed_types, DataTypeVector fetch_types, DataTypeVector feed_types, DataTypeVector fetch_types)
int64 collective_graph_key)
: flib_def(std::move(flib)), : flib_def(std::move(flib)),
graph(flib_def.get()), graph(flib_def.get()),
feed_types(std::move(feed_types)), feed_types(std::move(feed_types)),
fetch_types(std::move(fetch_types)), fetch_types(std::move(fetch_types)) {}
collective_graph_key(collective_graph_key) {}
// Each client-graph gets its own function library since optimization passes // Each client-graph gets its own function library since optimization passes
// post rewrite for execution might want to introduce new functions. // post rewrite for execution might want to introduce new functions.
std::unique_ptr<FunctionLibraryDefinition> flib_def; std::unique_ptr<FunctionLibraryDefinition> flib_def;
Graph graph; Graph graph;
DataTypeVector feed_types; DataTypeVector feed_types;
DataTypeVector fetch_types; DataTypeVector fetch_types;
int64 collective_graph_key;
}; };
// GraphExecutionState is responsible for generating an // GraphExecutionState is responsible for generating an

View File

@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options(); *c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() = *c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options(); callable_opts_.run_options().debug_options();
c->req.set_collective_graph_key(client_graph()->collective_graph_key); c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
VLOG(2) << "Register " << c->req.graph_def().DebugString(); VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) { auto cb = [c, &done](const Status& s) {
c->status = s; c->status = s;
@ -1111,6 +1111,10 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
h = Hash64(watch_summary.c_str(), watch_summary.size(), h); h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
} }
if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
h = Hash64Combine(opts.collective_graph_key, h);
}
return h; return h;
} }
@ -1784,10 +1788,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status; Status s = run_status;
if (s.ok()) { if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros(); pss->end_micros = Env::Default()->NowMicros();
if (rcg->client_graph()->collective_graph_key != if (rcg->build_graph_options().collective_graph_key !=
BuildGraphOptions::kNoCollectiveGraphKey) { BuildGraphOptions::kNoCollectiveGraphKey) {
env_->collective_executor_mgr->RetireStepId( env_->collective_executor_mgr->RetireStepId(
rcg->client_graph()->collective_graph_key, step_id); rcg->build_graph_options().collective_graph_key, step_id);
} }
// Schedule post-processing and cleanup to be done asynchronously. // Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata); rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
@ -1846,7 +1850,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the // Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use. // step_id for future use.
uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key); uint64 step_id = NewStepId(bgopts.collective_graph_key);
TRACEPRINTF("stepid %llu", step_id); TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph; std::unique_ptr<ProfileHandler> ph;
@ -1910,7 +1914,8 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare. // Prepare.
int64 count = rcg->get_and_increment_execution_count(); int64 count = rcg->get_and_increment_execution_count();
const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key); const uint64 step_id =
NewStepId(rcg->build_graph_options().collective_graph_key);
TRACEPRINTF("stepid %llu", step_id); TRACEPRINTF("stepid %llu", step_id);
const RunOptions& run_options = rcg->callable_options().run_options(); const RunOptions& run_options = rcg->callable_options().run_options();

View File

@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class CollectiveOpTest(test.TestCase): class CollectiveOpTest(test.TestCase):
def _testCollectiveReduce(self, t0, t1, expected, set_graph_key): def _testCollectiveReduce(self, t0, t1, expected):
group_key = 1 group_key = 1
instance_key = 1 instance_key = 1
with self.test_session( with self.test_session(
@ -43,7 +43,6 @@ class CollectiveOpTest(test.TestCase):
colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key, colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
'Add', 'Div') 'Add', 'Div')
run_options = config_pb2.RunOptions() run_options = config_pb2.RunOptions()
if set_graph_key:
run_options.experimental.collective_graph_key = 1 run_options.experimental.collective_graph_key = 1
results = sess.run([colred0, colred1], options=run_options) results = sess.run([colred0, colred1], options=run_options)
self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5) self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
@ -52,15 +51,10 @@ class CollectiveOpTest(test.TestCase):
def testCollectiveReduce(self): def testCollectiveReduce(self):
self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1], self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3], [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], True) [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
def testCollectiveAutoGraphKey(self):
self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
def testCollectiveReduceScalar(self): def testCollectiveReduceScalar(self):
self._testCollectiveReduce(0.1, 0.3, 0.2, True) self._testCollectiveReduce(0.1, 0.3, 0.2)
def _testCollectiveBroadcast(self, t0): def _testCollectiveBroadcast(self, t0):
group_key = 1 group_key = 1