Set collective_graph_key when there're v2 collective ops present
Otherwise there won't be a collective executor. PiperOrigin-RevId: 354133200 Change-Id: Icb99d7570b9a380c18cdf18c754836a496d4ffc1
This commit is contained in:
parent
3db793ee03
commit
06ab7dc374
@ -62,6 +62,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
bool IsCollectiveV2(const string& op) {
|
||||
return op == "CollectiveReduceV2" || op == "CollectiveGatherV2" ||
|
||||
op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
GraphExecutionState::GraphExecutionState(
|
||||
std::unique_ptr<GraphDef>&& graph_def,
|
||||
std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
|
||||
@ -898,12 +905,15 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
|
||||
// if found, initialize a collective_graph_key as a hash of the ordered set
|
||||
// of instance keys.
|
||||
std::set<int32> instance_key_set;
|
||||
bool has_collective_v2 = false;
|
||||
for (Node* node : optimized_graph->nodes()) {
|
||||
if (node->IsCollective()) {
|
||||
int32 instance_key;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(node->attrs(), "instance_key", &instance_key));
|
||||
instance_key_set.emplace(instance_key);
|
||||
} else if (IsCollectiveV2(node->type_string())) {
|
||||
has_collective_v2 = true;
|
||||
} else {
|
||||
const FunctionDef* fdef = optimized_flib->Find(node->def().op());
|
||||
if (fdef != nullptr) {
|
||||
@ -916,6 +926,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(ndef, "instance_key", &instance_key));
|
||||
instance_key_set.emplace(instance_key);
|
||||
} else if (IsCollectiveV2(ndef.op())) {
|
||||
has_collective_v2 = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -927,6 +939,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
|
||||
hash = Hash64Combine(instance_key, hash);
|
||||
}
|
||||
collective_graph_key = hash;
|
||||
} else if (has_collective_v2) {
|
||||
collective_graph_key = 0x8774aa605c729c72ULL;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user