Set collective_graph_key when there're v2 collective ops present

Otherwise there won't be a collective executor.

PiperOrigin-RevId: 354133200
Change-Id: Icb99d7570b9a380c18cdf18c754836a496d4ffc1
This commit is contained in:
Ran Chen 2021-01-27 11:29:49 -08:00 committed by TensorFlower Gardener
parent 3db793ee03
commit 06ab7dc374

View File

@ -62,6 +62,13 @@ limitations under the License.
namespace tensorflow {
namespace {
bool IsCollectiveV2(const string& op) {
return op == "CollectiveReduceV2" || op == "CollectiveGatherV2" ||
op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2";
}
} // namespace
GraphExecutionState::GraphExecutionState(
std::unique_ptr<GraphDef>&& graph_def,
std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
@ -898,12 +905,15 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
// if found, initialize a collective_graph_key as a hash of the ordered set
// of instance keys.
std::set<int32> instance_key_set;
bool has_collective_v2 = false;
for (Node* node : optimized_graph->nodes()) {
if (node->IsCollective()) {
int32 instance_key;
TF_RETURN_IF_ERROR(
GetNodeAttr(node->attrs(), "instance_key", &instance_key));
instance_key_set.emplace(instance_key);
} else if (IsCollectiveV2(node->type_string())) {
has_collective_v2 = true;
} else {
const FunctionDef* fdef = optimized_flib->Find(node->def().op());
if (fdef != nullptr) {
@ -916,6 +926,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
TF_RETURN_IF_ERROR(
GetNodeAttr(ndef, "instance_key", &instance_key));
instance_key_set.emplace(instance_key);
} else if (IsCollectiveV2(ndef.op())) {
has_collective_v2 = true;
}
}
}
@ -927,6 +939,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
hash = Hash64Combine(instance_key, hash);
}
collective_graph_key = hash;
} else if (has_collective_v2) {
collective_graph_key = 0x8774aa605c729c72ULL;
}
}