parent 348367a88e
commit cb94438312
@@ -4078,7 +4078,6 @@ tf_cuda_cc_test(
         ":testlib",
         "//third_party/eigen3",
         "//tensorflow/cc:cc_ops",
-        "//tensorflow/core/kernels:collective_ops",
        "//tensorflow/core/kernels:control_flow_ops",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:dense_update_ops",
@@ -4120,7 +4119,6 @@ tf_cc_test(
         "//tensorflow/cc:cc_ops",
         # Link with support for TensorFlow Debugger (tfdbg).
         "//tensorflow/core/debug",
-        "//tensorflow/core/kernels:collective_ops",
        "//tensorflow/core/kernels:control_flow_ops",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:dense_update_ops",
@@ -451,22 +451,8 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
   RunState run_state(step_id, &devices_);
   run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
 #ifndef __ANDROID__
-  // Set up for collectives if ExecutorsAndKeys declares a key.
-  if (executors_and_keys->collective_graph_key !=
-      BuildGraphOptions::kNoCollectiveGraphKey) {
-    if (run_options.experimental().collective_graph_key() !=
-        BuildGraphOptions::kNoCollectiveGraphKey) {
-      // If a collective_graph_key was specified in run_options, ensure that it
-      // matches what came out of GraphExecutionState::BuildGraph().
-      if (run_options.experimental().collective_graph_key() !=
-          executors_and_keys->collective_graph_key) {
-        return errors::Internal(
-            "collective_graph_key in RunOptions ",
-            run_options.experimental().collective_graph_key(),
-            " should match collective_graph_key from optimized graph ",
-            executors_and_keys->collective_graph_key);
-      }
-    }
+  // Set up for collectives if the RunOptions declares a key.
+  if (run_options.experimental().collective_graph_key() > 0) {
     if (!collective_executor_mgr_) {
       std::unique_ptr<DeviceResolverInterface> drl(
           new DeviceResolverLocal(device_mgr_.get()));
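Note: after this hunk, the RunOptions field is the sole trigger for collective setup; no key is derived from the graph or cross-checked at run time. A minimal client-side sketch of what this implies; the session pointer, fetch name, and key value 1 are illustrative, not from this commit:

#include "tensorflow/core/public/session.h"

// Hedged sketch: opting in to collectives after this revert. The key must
// be set explicitly on RunOptions; nothing is inferred from the graph.
tensorflow::Status RunWithCollectives(tensorflow::Session* session) {
  tensorflow::RunOptions run_options;
  run_options.mutable_experimental()->set_collective_graph_key(1);
  std::vector<tensorflow::Tensor> outputs;
  tensorflow::RunMetadata run_metadata;
  return session->Run(run_options, /*inputs=*/{},
                      /*output_tensor_names=*/{"col_reduce:0"},
                      /*target_node_names=*/{}, &outputs, &run_metadata);
}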
@@ -692,13 +678,10 @@ Status DirectSession::Run(const RunOptions& run_options,
   // Check if we already have an executor for these arguments.
   ExecutorsAndKeys* executors_and_keys;
   RunStateArgs run_state_args(run_options.debug_options());
-  run_state_args.collective_graph_key =
-      run_options.experimental().collective_graph_key();
 
   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
                                           target_nodes, &executors_and_keys,
                                           &run_state_args));
-  collective_graph_key_ = executors_and_keys->collective_graph_key;
 
   // Configure a call frame for the step, which we use to feed and
   // fetch values to and from the executors.
@@ -1133,8 +1116,6 @@ Status DirectSession::CreateExecutors(
   BuildGraphOptions options;
   options.callable_options = callable_options;
   options.use_function_convention = !run_state_args->is_partial_run;
-  options.collective_graph_key =
-      callable_options.run_options().experimental().collective_graph_key();
 
   std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
   std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1142,9 +1123,9 @@ Status DirectSession::CreateExecutors(
   ek->callable_options = callable_options;
 
   std::unordered_map<string, std::unique_ptr<Graph>> graphs;
-  TF_RETURN_IF_ERROR(CreateGraphs(
-      options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
-      &ek->output_types, &ek->collective_graph_key));
+  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
+                                  run_state_args, &ek->input_types,
+                                  &ek->output_types));
 
   if (run_state_args->is_partial_run) {
     ek->graph = std::move(run_state_args->graph);
@@ -1372,9 +1353,6 @@ Status DirectSession::GetOrCreateExecutors(
   }
   *callable_options.mutable_run_options()->mutable_debug_options() =
       run_state_args->debug_options;
-  callable_options.mutable_run_options()
-      ->mutable_experimental()
-      ->set_collective_graph_key(run_state_args->collective_graph_key);
   std::unique_ptr<ExecutorsAndKeys> ek;
   std::unique_ptr<FunctionInfo> func_info;
   TF_RETURN_IF_ERROR(
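With the copy from RunStateArgs removed here, a key set on CallableOptions only matters when its run_options reach RunInternal at step time. A hedged sketch of that path; feed/fetch names and the key value are placeholders:

// Hedged sketch: the MakeCallable path after this change. The key set on
// CallableOptions::run_options() is read by RunInternal at step time (see
// the first direct_session.cc hunk); it is no longer copied into
// BuildGraphOptions when the executors are created.
tensorflow::Status MakeCollectiveCallable(
    tensorflow::Session* session, tensorflow::Session::CallableHandle* handle) {
  tensorflow::CallableOptions callable_options;
  callable_options.add_feed("input:0");
  callable_options.add_fetch("col_reduce:0");
  callable_options.mutable_run_options()
      ->mutable_experimental()
      ->set_collective_graph_key(1);
  return session->MakeCallable(callable_options, handle);
}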
@@ -1401,7 +1379,7 @@ Status DirectSession::CreateGraphs(
     std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
     RunStateArgs* run_state_args, DataTypeVector* input_types,
-    DataTypeVector* output_types, int64* collective_graph_key) {
+    DataTypeVector* output_types) {
   mutex_lock l(graph_def_lock_);
   std::unique_ptr<ClientGraph> client_graph;
 
@@ -1425,7 +1403,6 @@ Status DirectSession::CreateGraphs(
     TF_RETURN_IF_ERROR(
         execution_state->BuildGraph(subgraph_options, &client_graph));
   }
-  *collective_graph_key = client_graph->collective_graph_key;
 
   if (subgraph_options.callable_options.feed_size() !=
       client_graph->feed_types.size()) {
@@ -117,9 +117,6 @@ class DirectSession : public Session {
   ::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
 
  private:
-  // For access to collective_graph_key_.
-  friend class DirectSessionCollectiveTest;
-
   // We create one executor and its dependent library runtime for
   // every partition.
   struct PerPartitionExecutorsAndLib {
@@ -153,8 +150,6 @@ class DirectSession : public Session {
     DataTypeVector output_types;
 
     CallableOptions callable_options;
-
-    int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
   };
 
   // A FunctionInfo object is created for every unique set of feeds/fetches.
@@ -208,7 +203,6 @@ class DirectSession : public Session {
     string handle;
     std::unique_ptr<Graph> graph;
     const DebugOptions& debug_options;
-    int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
   };
 
   // Initializes the base execution state given the 'graph',
@@ -240,7 +234,7 @@ class DirectSession : public Session {
       std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
       std::unique_ptr<FunctionLibraryDefinition>* flib_def,
       RunStateArgs* run_state_args, DataTypeVector* input_types,
-      DataTypeVector* output_types, int64* collective_graph_key);
+      DataTypeVector* output_types);
 
   ::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
                                    CallFrameInterface* call_frame,
@@ -397,9 +391,6 @@ class DirectSession : public Session {
 
   Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
 
-  // For testing collective graph key generation.
-  int64 collective_graph_key_ = -1;
-
   TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
 
   // EXPERIMENTAL: debugger (tfdbg) related
@@ -2218,118 +2218,4 @@ BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
 BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
 
 }  // namespace
-
-class DirectSessionCollectiveTest : public ::testing::Test {
- public:
-  // Creates a graph with CollectiveOps inside functions and runs it. Returns
-  // the generated collective_graph_key.
-  Status RunGraphWithCollectiveFunctions(bool add_unused_function,
-                                         int64* collective_graph_key) {
-    GraphDef g = CreateGraph(add_unused_function);
-    const Tensor t1 =
-        test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
-    const Tensor t2 =
-        test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
-    auto session = CreateSession();
-    TF_RETURN_IF_ERROR(session->Create(g));
-    std::vector<Tensor> outputs;
-    TF_RETURN_IF_ERROR(
-        session->Run({{"input1:0", t1}, {"input2:0", t2}}, {},
-                     {"collective_call1:0", "collective_call2:0"}, &outputs));
-    DirectSession* direct_session = static_cast<DirectSession*>(session.get());
-    *collective_graph_key = direct_session->collective_graph_key_;
-    return Status::OK();
-  }
-
- private:
-  // Creates a function with name `function_name` and a single CollectiveReduce
-  // node with instance key set as `instance_key`.
-  FunctionDef CollectiveFunction(const string& function_name,
-                                 int instance_key) {
-    return FunctionDefHelper::Define(
-        // Function name
-        function_name,
-        // In def
-        {"arg:float"},
-        // Out def
-        {"reduce:float"},
-        // Attr def
-        {},
-        // Node def
-        {{
-            {"reduce"},
-            "CollectiveReduce",
-            {"arg"},
-            {{"group_size", 2},
-             {"group_key", 1},
-             {"instance_key", instance_key},
-             {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
-             {"merge_op", "Add"},
-             {"final_op", "Div"},
-             {"T", DT_FLOAT}},
-        }});
-  }
-
-  // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
-  // CPU1, with instance_key 1, and appropriate placeholder inputs. If
-  // `add_unused_function` is true, adds another CollectiveFunction with
-  // instance_key 2 that is not invoked in the graph.
-  GraphDef CreateGraph(bool add_unused_function) {
-    GraphDef g;
-    FunctionDef collective_function =
-        CollectiveFunction("CollectiveFunction1", 1);
-    FunctionDefLibrary* lib = g.mutable_library();
-    *lib->add_function() = collective_function;
-    if (add_unused_function) {
-      FunctionDef unused_function =
-          CollectiveFunction("CollectiveFunction2", 2);
-      *lib->add_function() = unused_function;
-    }
-
-    // Inputs.
-    AttrValue dtype_attr;
-    SetAttrValue(DT_FLOAT, &dtype_attr);
-    NodeDef input1;
-    input1.set_name("input1");
-    input1.set_op("Placeholder");
-    input1.mutable_attr()->insert({"dtype", dtype_attr});
-    NodeDef input2;
-    input2.set_name("input2");
-    input2.set_op("Placeholder");
-    input2.mutable_attr()->insert({"dtype", dtype_attr});
-
-    // CollectiveReduce on CPU0 with instance_key 1.
-    NodeDef collective_call1;
-    collective_call1.set_name("collective_call1");
-    collective_call1.set_op("CollectiveFunction1");
-    collective_call1.add_input("input1");
-    collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0");
-    // CollectiveReduce on CPU1 with instance_key 1.
-    NodeDef collective_call2;
-    collective_call2.set_name("collective_call2");
-    collective_call2.set_op("CollectiveFunction1");
-    collective_call2.add_input("input2");
-    collective_call2.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
-
-    *g.add_node() = input1;
-    *g.add_node() = input2;
-    *g.add_node() = collective_call1;
-    *g.add_node() = collective_call2;
-
-    return g;
-  }
-};
-
-#ifndef GOOGLE_CUDA
-// TODO(ayushd): enable this test for GPU builds.
-TEST_F(DirectSessionCollectiveTest,
-       TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
-  int64 key1;
-  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key1));
-  int64 key2;
-  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
-  ASSERT_EQ(key1, key2);
-}
-#endif
-
 }  // namespace tensorflow
@@ -16,7 +16,6 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/graph_execution_state.h"
 
 #include <memory>
-#include <set>
 #include <string>
 #include <unordered_set>
 #include <utility>
@@ -728,50 +727,12 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
       OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
 
-  int64 collective_graph_key = options.collective_graph_key;
-  if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
-    // BuildGraphOptions does not specify a collective_graph_key. Check all
-    // nodes in the Graph and FunctionLibraryDefinition for collective ops and
-    // if found, initialize a collective_graph_key as a hash of the ordered set
-    // of instance keys.
-    std::set<int32> instance_key_set;
-    for (Node* node : optimized_graph->nodes()) {
-      if (node->IsCollective()) {
-        int32 instance_key;
-        TF_RETURN_IF_ERROR(
-            GetNodeAttr(node->attrs(), "instance_key", &instance_key));
-        instance_key_set.emplace(instance_key);
-      } else {
-        const FunctionDef* fdef = optimized_flib->Find(node->def().op());
-        if (fdef != nullptr) {
-          for (const NodeDef& ndef : fdef->node_def()) {
-            if (ndef.op() == "CollectiveReduce" ||
-                ndef.op() == "CollectiveBcastSend" ||
-                ndef.op() == "CollectiveBcastRecv") {
-              int32 instance_key;
-              TF_RETURN_IF_ERROR(
-                  GetNodeAttr(ndef, "instance_key", &instance_key));
-              instance_key_set.emplace(instance_key);
-            }
-          }
-        }
-      }
-    }
-    if (!instance_key_set.empty()) {
-      uint64 hash = 0x8774aa605c729c72ULL;
-      for (int32 instance_key : instance_key_set) {
-        hash = Hash64Combine(instance_key, hash);
-      }
-      collective_graph_key = hash;
-    }
-  }
-
   // Copy the extracted graph in order to make its node ids dense,
   // since the local CostModel used to record its stats is sized by
   // the largest node id.
   std::unique_ptr<ClientGraph> dense_copy(
       new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
-                      rewrite_metadata.fetch_types, collective_graph_key));
+                      rewrite_metadata.fetch_types));
   CopyGraph(*optimized_graph, &dense_copy->graph);
 
   // TODO(vrv): We should check invariants of the graph here.
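For reference, the derivation deleted above reduces to a seeded Hash64Combine fold over the ordered set of collective instance keys. A self-contained restatement; the function name is mine, while the seed and combine step come from the removed lines:

#include <set>
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/types.h"

// Hedged restatement of the removed key derivation. Using std::set keeps
// the iteration order deterministic, so the same set of instance keys
// always produces the same collective_graph_key.
tensorflow::uint64 DeriveCollectiveGraphKey(
    const std::set<tensorflow::int32>& instance_keys) {
  tensorflow::uint64 hash = 0x8774aa605c729c72ULL;  // seed from the removed code
  for (tensorflow::int32 instance_key : instance_keys) {
    hash = tensorflow::Hash64Combine(instance_key, hash);
  }
  return hash;
}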
@@ -50,20 +50,17 @@ struct GraphExecutionStateOptions {
 // BuildGraphOptions.
 struct ClientGraph {
   explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
-                       DataTypeVector feed_types, DataTypeVector fetch_types,
-                       int64 collective_graph_key)
+                       DataTypeVector feed_types, DataTypeVector fetch_types)
       : flib_def(std::move(flib)),
         graph(flib_def.get()),
         feed_types(std::move(feed_types)),
-        fetch_types(std::move(fetch_types)),
-        collective_graph_key(collective_graph_key) {}
+        fetch_types(std::move(fetch_types)) {}
   // Each client-graph gets its own function library since optimization passes
   // post rewrite for execution might want to introduce new functions.
   std::unique_ptr<FunctionLibraryDefinition> flib_def;
   Graph graph;
   DataTypeVector feed_types;
   DataTypeVector fetch_types;
-  int64 collective_graph_key;
 };
 
 // GraphExecutionState is responsible for generating an
@@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
     *c->req.mutable_graph_options() = session_opts_.config.graph_options();
     *c->req.mutable_debug_options() =
         callable_opts_.run_options().debug_options();
-    c->req.set_collective_graph_key(client_graph()->collective_graph_key);
+    c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
     VLOG(2) << "Register " << c->req.graph_def().DebugString();
     auto cb = [c, &done](const Status& s) {
       c->status = s;
@@ -1111,6 +1111,10 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
     h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
   }
 
+  if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
+    h = Hash64Combine(opts.collective_graph_key, h);
+  }
+
   return h;
 }
 
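The hunk above folds the key into HashBuildGraphOptions so that, as far as this diff shows, option sets differing only in collective_graph_key no longer collide when graphs are registered. A toy illustration of the combine step; names and values are mine:

#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

// Two BuildGraphOptions hashes that agree on everything else diverge once
// distinct collective_graph_keys are combined in.
void CollectiveKeyChangesHash(tensorflow::uint64 base_hash) {
  tensorflow::uint64 h1 = tensorflow::Hash64Combine(/*key=*/1, base_hash);
  tensorflow::uint64 h2 = tensorflow::Hash64Combine(/*key=*/2, base_hash);
  CHECK_NE(h1, h2);  // distinct with overwhelming probability
}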
@@ -1784,10 +1788,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
   Status s = run_status;
   if (s.ok()) {
     pss->end_micros = Env::Default()->NowMicros();
-    if (rcg->client_graph()->collective_graph_key !=
+    if (rcg->build_graph_options().collective_graph_key !=
         BuildGraphOptions::kNoCollectiveGraphKey) {
       env_->collective_executor_mgr->RetireStepId(
-          rcg->client_graph()->collective_graph_key, step_id);
+          rcg->build_graph_options().collective_graph_key, step_id);
     }
     // Schedule post-processing and cleanup to be done asynchronously.
     rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
@@ -1846,7 +1850,7 @@ Status MasterSession::DoRunWithLocalExecution(
 
   // Keeps the highest 8 bits 0x01: we reserve some bits of the
   // step_id for future use.
-  uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
+  uint64 step_id = NewStepId(bgopts.collective_graph_key);
   TRACEPRINTF("stepid %llu", step_id);
 
   std::unique_ptr<ProfileHandler> ph;
@@ -1910,7 +1914,8 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
   // Prepare.
   int64 count = rcg->get_and_increment_execution_count();
 
-  const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
+  const uint64 step_id =
+      NewStepId(rcg->build_graph_options().collective_graph_key);
   TRACEPRINTF("stepid %llu", step_id);
 
   const RunOptions& run_options = rcg->callable_options().run_options();
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
 
 class CollectiveOpTest(test.TestCase):
 
-  def _testCollectiveReduce(self, t0, t1, expected, set_graph_key):
+  def _testCollectiveReduce(self, t0, t1, expected):
     group_key = 1
     instance_key = 1
     with self.test_session(
@@ -43,7 +43,6 @@ class CollectiveOpTest(test.TestCase):
       colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
                                           'Add', 'Div')
       run_options = config_pb2.RunOptions()
-      if set_graph_key:
-        run_options.experimental.collective_graph_key = 1
+      run_options.experimental.collective_graph_key = 1
       results = sess.run([colred0, colred1], options=run_options)
       self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
@@ -52,15 +51,10 @@ class CollectiveOpTest(test.TestCase):
   def testCollectiveReduce(self):
     self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
                                [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
-                               [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], True)
-
-  def testCollectiveAutoGraphKey(self):
-    self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
-                               [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
-                               [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
+                               [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
 
   def testCollectiveReduceScalar(self):
-    self._testCollectiveReduce(0.1, 0.3, 0.2, True)
+    self._testCollectiveReduce(0.1, 0.3, 0.2)
 
   def _testCollectiveBroadcast(self, t0):
     group_key = 1