parent 348367a88e
commit cb94438312
@@ -4078,7 +4078,6 @@ tf_cuda_cc_test(
         ":testlib",
         "//third_party/eigen3",
         "//tensorflow/cc:cc_ops",
-        "//tensorflow/core/kernels:collective_ops",
         "//tensorflow/core/kernels:control_flow_ops",
         "//tensorflow/core/kernels:cwise_op",
         "//tensorflow/core/kernels:dense_update_ops",
@@ -4120,7 +4119,6 @@ tf_cc_test(
         "//tensorflow/cc:cc_ops",
         # Link with support for TensorFlow Debugger (tfdbg).
         "//tensorflow/core/debug",
-        "//tensorflow/core/kernels:collective_ops",
        "//tensorflow/core/kernels:control_flow_ops",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:dense_update_ops",

@@ -451,22 +451,8 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
   RunState run_state(step_id, &devices_);
   run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
 #ifndef __ANDROID__
-  // Set up for collectives if ExecutorsAndKeys declares a key.
-  if (executors_and_keys->collective_graph_key !=
-      BuildGraphOptions::kNoCollectiveGraphKey) {
-    if (run_options.experimental().collective_graph_key() !=
-        BuildGraphOptions::kNoCollectiveGraphKey) {
-      // If a collective_graph_key was specified in run_options, ensure that it
-      // matches what came out of GraphExecutionState::BuildGraph().
-      if (run_options.experimental().collective_graph_key() !=
-          executors_and_keys->collective_graph_key) {
-        return errors::Internal(
-            "collective_graph_key in RunOptions ",
-            run_options.experimental().collective_graph_key(),
-            " should match collective_graph_key from optimized graph ",
-            executors_and_keys->collective_graph_key);
-      }
-    }
+  // Set up for collectives if the RunOption declares a key.
+  if (run_options.experimental().collective_graph_key() > 0) {
     if (!collective_executor_mgr_) {
       std::unique_ptr<DeviceResolverInterface> drl(
           new DeviceResolverLocal(device_mgr_.get()));
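Note on the hunk above: the removed block validated, rather than invented, the collective key at run time. Collectives were enabled whenever the compiled ExecutorsAndKeys carried a key, and a key arriving through RunOptions was only cross-checked against it. A minimal sketch of that contract, using plain integers as stand-ins for the TensorFlow types and assuming 0 plays the role of BuildGraphOptions::kNoCollectiveGraphKey:

#include <cstdint>
#include <string>

// Hedged sketch, not TF's actual API: kNoKey (assumed 0) means "no key set".
// The graph-derived key is authoritative; a RunOptions key is only checked
// for consistency, mirroring the errors::Internal branch removed above.
constexpr int64_t kNoKey = 0;

std::string ValidateCollectiveKey(int64_t graph_key, int64_t run_options_key) {
  if (graph_key == kNoKey) return "";  // graph has no collectives to set up
  if (run_options_key != kNoKey && run_options_key != graph_key) {
    return "collective_graph_key in RunOptions should match "
           "collective_graph_key from optimized graph";
  }
  return "";  // keys agree (or none supplied): collectives proceed
}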
@@ -692,13 +678,10 @@ Status DirectSession::Run(const RunOptions& run_options,
   // Check if we already have an executor for these arguments.
   ExecutorsAndKeys* executors_and_keys;
   RunStateArgs run_state_args(run_options.debug_options());
-  run_state_args.collective_graph_key =
-      run_options.experimental().collective_graph_key();
 
   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
                                           target_nodes, &executors_and_keys,
                                           &run_state_args));
-  collective_graph_key_ = executors_and_keys->collective_graph_key;
 
   // Configure a call frame for the step, which we use to feed and
   // fetch values to and from the executors.
@@ -1133,8 +1116,6 @@ Status DirectSession::CreateExecutors(
   BuildGraphOptions options;
   options.callable_options = callable_options;
   options.use_function_convention = !run_state_args->is_partial_run;
-  options.collective_graph_key =
-      callable_options.run_options().experimental().collective_graph_key();
 
   std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
   std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1142,9 +1123,9 @@ Status DirectSession::CreateExecutors(
   ek->callable_options = callable_options;
 
   std::unordered_map<string, std::unique_ptr<Graph>> graphs;
-  TF_RETURN_IF_ERROR(CreateGraphs(
-      options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
-      &ek->output_types, &ek->collective_graph_key));
+  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
+                                  run_state_args, &ek->input_types,
+                                  &ek->output_types));
 
   if (run_state_args->is_partial_run) {
     ek->graph = std::move(run_state_args->graph);
@@ -1372,9 +1353,6 @@ Status DirectSession::GetOrCreateExecutors(
   }
   *callable_options.mutable_run_options()->mutable_debug_options() =
       run_state_args->debug_options;
-  callable_options.mutable_run_options()
-      ->mutable_experimental()
-      ->set_collective_graph_key(run_state_args->collective_graph_key);
   std::unique_ptr<ExecutorsAndKeys> ek;
   std::unique_ptr<FunctionInfo> func_info;
   TF_RETURN_IF_ERROR(
@@ -1401,7 +1379,7 @@ Status DirectSession::CreateGraphs(
     std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
     RunStateArgs* run_state_args, DataTypeVector* input_types,
-    DataTypeVector* output_types, int64* collective_graph_key) {
+    DataTypeVector* output_types) {
   mutex_lock l(graph_def_lock_);
   std::unique_ptr<ClientGraph> client_graph;
 
@@ -1425,7 +1403,6 @@ Status DirectSession::CreateGraphs(
     TF_RETURN_IF_ERROR(
         execution_state->BuildGraph(subgraph_options, &client_graph));
   }
-  *collective_graph_key = client_graph->collective_graph_key;
 
   if (subgraph_options.callable_options.feed_size() !=
       client_graph->feed_types.size()) {

@@ -117,9 +117,6 @@ class DirectSession : public Session {
   ::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
 
  private:
-  // For access to collective_graph_key_.
-  friend class DirectSessionCollectiveTest;
-
   // We create one executor and its dependent library runtime for
   // every partition.
   struct PerPartitionExecutorsAndLib {
@@ -153,8 +150,6 @@ class DirectSession : public Session {
     DataTypeVector output_types;
 
     CallableOptions callable_options;
-
-    int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
   };
 
   // A FunctionInfo object is created for every unique set of feeds/fetches.
@@ -208,7 +203,6 @@ class DirectSession : public Session {
     string handle;
     std::unique_ptr<Graph> graph;
     const DebugOptions& debug_options;
-    int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
   };
 
   // Initializes the base execution state given the 'graph',
@@ -240,7 +234,7 @@ class DirectSession : public Session {
       std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
       std::unique_ptr<FunctionLibraryDefinition>* flib_def,
       RunStateArgs* run_state_args, DataTypeVector* input_types,
-      DataTypeVector* output_types, int64* collective_graph_key);
+      DataTypeVector* output_types);
 
   ::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
                                    CallFrameInterface* call_frame,
@@ -397,9 +391,6 @@ class DirectSession : public Session {
 
   Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
 
-  // For testing collective graph key generation.
-  int64 collective_graph_key_ = -1;
-
   TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
 
   // EXPERIMENTAL: debugger (tfdbg) related

@@ -2218,118 +2218,4 @@ BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
 BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
 
 } // namespace
-
-class DirectSessionCollectiveTest : public ::testing::Test {
- public:
-  // Creates a graph with CollectiveOps inside functions and runs it. Returns
-  // the generated collective_graph_key.
-  Status RunGraphWithCollectiveFunctions(bool add_unused_function,
-                                         int64* collective_graph_key) {
-    GraphDef g = CreateGraph(add_unused_function);
-    const Tensor t1 =
-        test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
-    const Tensor t2 =
-        test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
-    auto session = CreateSession();
-    TF_RETURN_IF_ERROR(session->Create(g));
-    std::vector<Tensor> outputs;
-    TF_RETURN_IF_ERROR(
-        session->Run({{"input1:0", t1}, {"input2:0", t2}}, {},
-                     {"collective_call1:0", "collective_call2:0"}, &outputs));
-    DirectSession* direct_session = static_cast<DirectSession*>(session.get());
-    *collective_graph_key = direct_session->collective_graph_key_;
-    return Status::OK();
-  }
-
- private:
-  // Creates a function with name `function_name` and a single CollectiveReduce
-  // node with instance key set as `instance_key`.
-  FunctionDef CollectiveFunction(const string& function_name,
-                                 int instance_key) {
-    return FunctionDefHelper::Define(
-        // Function name
-        function_name,
-        // In def
-        {"arg:float"},
-        // Out def
-        {"reduce:float"},
-        // Attr def
-        {},
-        // Node def
-        {{
-            {"reduce"},
-            "CollectiveReduce",
-            {"arg"},
-            {{"group_size", 2},
-             {"group_key", 1},
-             {"instance_key", instance_key},
-             {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
-             {"merge_op", "Add"},
-             {"final_op", "Div"},
-             {"T", DT_FLOAT}},
-        }});
-  }
-
-  // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
-  // CPU1, with instance_key 1, and appropriate placeholder inputs. If
-  // `add_unused_function` is true, adds another CollectiveFunction with
-  // instance_key 2 that is not invoked in the graph.
-  GraphDef CreateGraph(bool add_unused_function) {
-    GraphDef g;
-    FunctionDef collective_function =
-        CollectiveFunction("CollectiveFunction1", 1);
-    FunctionDefLibrary* lib = g.mutable_library();
-    *lib->add_function() = collective_function;
-    if (add_unused_function) {
-      FunctionDef unused_function =
-          CollectiveFunction("CollectiveFunction2", 2);
-      *lib->add_function() = unused_function;
-    }
-
-    // Inputs.
-    AttrValue dtype_attr;
-    SetAttrValue(DT_FLOAT, &dtype_attr);
-    NodeDef input1;
-    input1.set_name("input1");
-    input1.set_op("Placeholder");
-    input1.mutable_attr()->insert({"dtype", dtype_attr});
-    NodeDef input2;
-    input2.set_name("input2");
-    input2.set_op("Placeholder");
-    input2.mutable_attr()->insert({"dtype", dtype_attr});
-
-    // CollectiveReduce on CPU0 with instance_key 1.
-    NodeDef collective_call1;
-    collective_call1.set_name("collective_call1");
-    collective_call1.set_op("CollectiveFunction1");
-    collective_call1.add_input("input1");
-    collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0");
-    // CollectiveReduce on CPU1 with instance_key 1.
-    NodeDef collective_call2;
-    collective_call2.set_name("collective_call2");
-    collective_call2.set_op("CollectiveFunction1");
-    collective_call2.add_input("input2");
-    collective_call2.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
-
-    *g.add_node() = input1;
-    *g.add_node() = input2;
-    *g.add_node() = collective_call1;
-    *g.add_node() = collective_call2;
-
-    return g;
-  }
-};
-
-#ifndef GOOGLE_CUDA
-// TODO(ayushd): enable this test for GPU builds.
-TEST_F(DirectSessionCollectiveTest,
-       TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
-  int64 key1;
-  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key1));
-  int64 key2;
-  TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
-  ASSERT_EQ(key1, key2);
-}
-#endif
-
 } // namespace tensorflow

@@ -16,7 +16,6 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/graph_execution_state.h"
 
 #include <memory>
-#include <set>
 #include <string>
 #include <unordered_set>
 #include <utility>
@@ -728,50 +727,12 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
       OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
 
-  int64 collective_graph_key = options.collective_graph_key;
-  if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
-    // BuildGraphOptions does not specify a collective_graph_key. Check all
-    // nodes in the Graph and FunctionLibraryDefinition for collective ops and
-    // if found, initialize a collective_graph_key as a hash of the ordered set
-    // of instance keys.
-    std::set<int32> instance_key_set;
-    for (Node* node : optimized_graph->nodes()) {
-      if (node->IsCollective()) {
-        int32 instance_key;
-        TF_RETURN_IF_ERROR(
-            GetNodeAttr(node->attrs(), "instance_key", &instance_key));
-        instance_key_set.emplace(instance_key);
-      } else {
-        const FunctionDef* fdef = optimized_flib->Find(node->def().op());
-        if (fdef != nullptr) {
-          for (const NodeDef& ndef : fdef->node_def()) {
-            if (ndef.op() == "CollectiveReduce" ||
-                ndef.op() == "CollectiveBcastSend" ||
-                ndef.op() == "CollectiveBcastRecv") {
-              int32 instance_key;
-              TF_RETURN_IF_ERROR(
-                  GetNodeAttr(ndef, "instance_key", &instance_key));
-              instance_key_set.emplace(instance_key);
-            }
-          }
-        }
-      }
-    }
-    if (!instance_key_set.empty()) {
-      uint64 hash = 0x8774aa605c729c72ULL;
-      for (int32 instance_key : instance_key_set) {
-        hash = Hash64Combine(instance_key, hash);
-      }
-      collective_graph_key = hash;
-    }
-  }
-
   // Copy the extracted graph in order to make its node ids dense,
   // since the local CostModel used to record its stats is sized by
   // the largest node id.
   std::unique_ptr<ClientGraph> dense_copy(
       new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
-                      rewrite_metadata.fetch_types, collective_graph_key));
+                      rewrite_metadata.fetch_types));
   CopyGraph(*optimized_graph, &dense_copy->graph);
 
   // TODO(vrv): We should check invariants of the graph here.
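The block removed above is the core of this rollback: when BuildGraphOptions carried no key, BuildGraph derived one by hashing the ordered set of collective instance keys reachable from the graph, including collectives inside called functions but not uncalled library functions. A minimal standalone sketch of that derivation, with a boost-style mixer standing in for TensorFlow's Hash64Combine and 0 assumed for kNoCollectiveGraphKey:

#include <cstdint>
#include <set>

// Stand-in for tensorflow::Hash64Combine (tensorflow/core/lib/hash/hash.h);
// the real mixer differs, but any 64-bit combine illustrates the scheme.
uint64_t Hash64Combine(uint64_t a, uint64_t b) {
  return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
}

// Hash the ascending set of collective instance keys so the same set of
// collectives always yields the same graph key, independent of node order.
int64_t DeriveCollectiveGraphKey(const std::set<int32_t>& instance_keys) {
  if (instance_keys.empty()) return 0;    // assumed kNoCollectiveGraphKey
  uint64_t hash = 0x8774aa605c729c72ULL;  // seed taken from the removed code
  for (int32_t instance_key : instance_keys) {  // std::set iterates ascending
    hash = Hash64Combine(instance_key, hash);
  }
  return static_cast<int64_t>(hash);
}

Because uncalled library functions were never scanned, the key was stable under the addition of unused functions, which is exactly the invariant the removed DirectSessionCollectiveTest asserted (key1 == key2).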
@@ -50,20 +50,17 @@ struct GraphExecutionStateOptions {
 // BuildGraphOptions.
 struct ClientGraph {
   explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
-                       DataTypeVector feed_types, DataTypeVector fetch_types,
-                       int64 collective_graph_key)
+                       DataTypeVector feed_types, DataTypeVector fetch_types)
       : flib_def(std::move(flib)),
         graph(flib_def.get()),
         feed_types(std::move(feed_types)),
-        fetch_types(std::move(fetch_types)),
-        collective_graph_key(collective_graph_key) {}
+        fetch_types(std::move(fetch_types)) {}
   // Each client-graph gets its own function library since optimization passes
   // post rewrite for execution might want to introduce new functions.
   std::unique_ptr<FunctionLibraryDefinition> flib_def;
   Graph graph;
   DataTypeVector feed_types;
   DataTypeVector fetch_types;
-  int64 collective_graph_key;
 };
 
 // GraphExecutionState is responsible for generating an

@@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
     *c->req.mutable_graph_options() = session_opts_.config.graph_options();
     *c->req.mutable_debug_options() =
         callable_opts_.run_options().debug_options();
-    c->req.set_collective_graph_key(client_graph()->collective_graph_key);
+    c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
     VLOG(2) << "Register " << c->req.graph_def().DebugString();
     auto cb = [c, &done](const Status& s) {
       c->status = s;
@@ -1111,6 +1111,10 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
     h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
   }
 
+  if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
+    h = Hash64Combine(opts.collective_graph_key, h);
+  }
+
   return h;
 }
 
@@ -1784,10 +1788,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
   Status s = run_status;
   if (s.ok()) {
     pss->end_micros = Env::Default()->NowMicros();
-    if (rcg->client_graph()->collective_graph_key !=
+    if (rcg->build_graph_options().collective_graph_key !=
         BuildGraphOptions::kNoCollectiveGraphKey) {
       env_->collective_executor_mgr->RetireStepId(
-          rcg->client_graph()->collective_graph_key, step_id);
+          rcg->build_graph_options().collective_graph_key, step_id);
     }
     // Schedule post-processing and cleanup to be done asynchronously.
     rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
@@ -1846,7 +1850,7 @@ Status MasterSession::DoRunWithLocalExecution(
 
   // Keeps the highest 8 bits 0x01: we reserve some bits of the
   // step_id for future use.
-  uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
+  uint64 step_id = NewStepId(bgopts.collective_graph_key);
   TRACEPRINTF("stepid %llu", step_id);
 
   std::unique_ptr<ProfileHandler> ph;
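The comment in this hunk pins the highest 8 bits of the step id to 0x01. A small illustrative sketch of such a tag-plus-counter layout; the masking below is an assumption for illustration only, since TF's real NewStepId also consults the collective executor manager when a collective graph key is set:

#include <cstdint>

// Illustrative layout: the top byte is a fixed 0x01 tag and the low 56 bits
// carry the per-step counter, leaving the tag space free for future uses.
uint64_t MakeTaggedStepId(uint64_t counter) {
  constexpr int kTagShift = 56;
  constexpr uint64_t kLowMask = (uint64_t{1} << kTagShift) - 1;  // low 56 bits
  return (uint64_t{0x01} << kTagShift) | (counter & kLowMask);
}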
@@ -1910,7 +1914,8 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
   // Prepare.
   int64 count = rcg->get_and_increment_execution_count();
 
-  const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
+  const uint64 step_id =
+      NewStepId(rcg->build_graph_options().collective_graph_key);
   TRACEPRINTF("stepid %llu", step_id);
 
   const RunOptions& run_options = rcg->callable_options().run_options();

@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
 
 class CollectiveOpTest(test.TestCase):
 
-  def _testCollectiveReduce(self, t0, t1, expected, set_graph_key):
+  def _testCollectiveReduce(self, t0, t1, expected):
     group_key = 1
     instance_key = 1
     with self.test_session(
@@ -43,8 +43,7 @@ class CollectiveOpTest(test.TestCase):
       colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
                                           'Add', 'Div')
       run_options = config_pb2.RunOptions()
-      if set_graph_key:
-        run_options.experimental.collective_graph_key = 1
+      run_options.experimental.collective_graph_key = 1
       results = sess.run([colred0, colred1], options=run_options)
       self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
       self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
@@ -52,15 +51,10 @@ class CollectiveOpTest(test.TestCase):
   def testCollectiveReduce(self):
     self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
                                [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
-                               [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], True)
-
-  def testCollectiveAutoGraphKey(self):
-    self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
-                               [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
-                               [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], False)
+                               [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
 
   def testCollectiveReduceScalar(self):
-    self._testCollectiveReduce(0.1, 0.3, 0.2, True)
+    self._testCollectiveReduce(0.1, 0.3, 0.2)
 
   def _testCollectiveBroadcast(self, t0):
     group_key = 1