1590 lines
60 KiB
C++
1590 lines
60 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
|
|
|
#include <algorithm>
|
|
#include <deque>
|
|
#include <stack>
|
|
#include <unordered_set>
|
|
#include <vector>
|
|
|
|
#include "absl/memory/memory.h"
|
|
#include "absl/strings/str_join.h"
|
|
#include "absl/types/optional.h"
|
|
#include "tensorflow/compiler/jit/union_find.h"
|
|
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
|
#include "tensorflow/core/common_runtime/function.h"
|
|
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
|
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
|
#include "tensorflow/core/framework/node_def_builder.h"
|
|
#include "tensorflow/core/framework/versions.pb.h"
|
|
#include "tensorflow/core/graph/algorithm.h"
|
|
#include "tensorflow/core/graph/control_flow.h"
|
|
#include "tensorflow/core/graph/node_builder.h"
|
|
#include "tensorflow/core/lib/core/errors.h"
|
|
#include "tensorflow/core/lib/hash/hash.h"
|
|
#include "tensorflow/core/lib/strings/strcat.h"
|
|
#include "tensorflow/core/util/dump_graph.h"
|
|
|
|
using xla::StatusOr;
|
|
|
|
namespace tensorflow {
|
|
namespace functionalize_cond {
|
|
|
|
// Strict weak ordering: lexicographic on (node id, output index, type).
bool AncestorNode::operator<(const AncestorNode& other) const {
  const int lhs_id = output_tensor.node->id();
  const int rhs_id = other.output_tensor.node->id();
  if (lhs_id != rhs_id) return lhs_id < rhs_id;
  if (output_tensor.index != other.output_tensor.index)
    return output_tensor.index < other.output_tensor.index;
  return type < other.type;
}
|
|
|
|
// Two ancestors are equal iff all three identifying fields match.
bool AncestorNode::operator==(const AncestorNode& other) const {
  if (output_tensor.node->id() != other.output_tensor.node->id()) return false;
  if (output_tensor.index != other.output_tensor.index) return false;
  return type == other.type;
}
|
|
|
|
// Folds node id, output index and ancestor type into a single hash value.
size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const {
  const std::hash<int> int_hash;
  size_t h = int_hash(ancestor.output_tensor.node->id());
  h = Hash64Combine(h, int_hash(ancestor.output_tensor.index));
  h = Hash64Combine(h, int_hash(static_cast<int>(ancestor.type)));
  return h;
}
|
|
|
|
typedef std::tuple<StateMap::CondId, StateMap::AncestorId, OutputTensor>
|
|
ClusterTuple;
|
|
|
|
struct ClusterTupleLessThan {
|
|
bool operator()(const ClusterTuple& a, const ClusterTuple& b) const {
|
|
if (std::tie(std::get<0>(a), std::get<1>(a)) <
|
|
std::tie(std::get<0>(b), std::get<1>(b))) {
|
|
return true;
|
|
} else if (std::tie(std::get<0>(a), std::get<1>(a)) ==
|
|
std::tie(std::get<0>(b), std::get<1>(b))) {
|
|
return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b));
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
};
|
|
|
|
// TODO(jpienaar): Move to OutputTensor.
|
|
string DebugString(const OutputTensor& tensor) {
|
|
return absl::StrCat(tensor.node->name(), ":", tensor.index);
|
|
}
|
|
|
|
// Human-readable name of a branch, used in debug and error messages.
string Branch_Name(BranchType b) {
  switch (b) {
    case BranchType::kElseBranch:
      return "else";
    case BranchType::kThenBranch:
      return "then";
    case BranchType::kBoth:
      return "both";
    case BranchType::kNeither:
      return "neither";
  }
  // All enumerators are handled above. This guards against falling off the
  // end of a non-void function (undefined behavior) if an out-of-range value
  // is ever cast to BranchType.
  return "unknown";
}
|
|
|
|
// Renders a cond state as "{s(<pred>,<branch>), ...}"; "d" marks a dead
// entry and the null/empty state renders as "{}".
string DebugString(StateMap::CondId cond_state) {
  if (cond_state == nullptr || cond_state->empty()) return "{}";
  using value_type = StateMap::CondState::value_type;
  auto format_entry = [](string* output, const value_type& pred_branch) {
    const OutputTensor& pred = pred_branch.first;
    const BranchType& branch = pred_branch.second;
    if (branch == BranchType::kNeither) {
      absl::StrAppend(output, "d");
    } else {
      absl::StrAppend(output, "s(", DebugString(pred), ",", Branch_Name(branch),
                      ")");
    }
  };
  return absl::StrCat("{", absl::StrJoin(*cond_state, ", ", format_entry),
                      "}");
}
|
|
|
|
// Returns the predicate of a switch.
|
|
Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
|
|
const Edge* pred_edge;
|
|
TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
|
|
// The predicate can be preceded by a identity node. Look through
|
|
// identity nodes to predicate.
|
|
while (pred_edge->src()->IsIdentity()) {
|
|
TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
|
|
}
|
|
*pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
|
|
return Status::OK();
|
|
}
|
|
|
|
// Returns the data input of a switch (input 0 of a Switch node).
Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
  const Edge* edge = nullptr;
  TF_RETURN_IF_ERROR(switch_node.input_edge(0, &edge));
  *val = OutputTensor(edge->src(), edge->src_output());
  return Status::OK();
}
|
|
|
|
bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
|
|
const OutputTensor& rhs) const {
|
|
return (lhs.node->id() < rhs.node->id()) ||
|
|
(lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
|
|
}
|
|
|
|
struct CondStateLess {
|
|
bool operator()(const StateMap::CondState::value_type& lhs,
|
|
const StateMap::CondState::value_type& rhs) const {
|
|
if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
|
|
return true;
|
|
if (lhs.first.node->id() == rhs.first.node->id() &&
|
|
lhs.first.index == rhs.first.index)
|
|
return lhs.second < rhs.second;
|
|
return false;
|
|
}
|
|
};
|
|
|
|
StateMap::StateMap(Graph* graph) {
  // Pre-size the dense per-node-id maps for every node currently in the
  // graph; nodes added later overflow into the added_node_* maps.
  node_to_condid_map_.resize(graph->num_node_ids());
  node_to_ancestorid_map_.resize(graph->num_node_ids());
  // Initialize the dead state (empty state is designated with a nullptr).
  dead_id_ = GetCondId(
      {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
}
|
|
|
|
// A state is "dead" iff it is the interned dead-state id created in the
// constructor; the empty state is canonically represented by nullptr.
bool StateMap::IsDead(StateMap::CondId id) const { return dead_id_ == id; }

bool StateMap::IsEmpty(StateMap::CondId id) const { return nullptr == id; }
|
|
|
|
size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
  if (map.empty()) return 0;
  // Hash each (predicate, branch) entry and fold the per-entry hashes
  // together in iteration order; the first entry seeds the accumulator.
  size_t h = 0;
  bool first = true;
  for (const auto& entry : map) {
    const size_t entry_hash = Hash64Combine(OutputTensor::Hash()(entry.first),
                                            hash<BranchType>()(entry.second));
    h = first ? entry_hash : Hash64Combine(h, entry_hash);
    first = false;
  }
  return h;
}
|
|
|
|
size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
  if (map.empty()) return 0;
  // Seed with the hash of the first element, then fold in the rest in
  // iteration order.
  auto it = map.begin();
  size_t h = AncestorNode::Hash()(*it);
  while (++it != map.end()) {
    h = Hash64Combine(h, AncestorNode::Hash()(*it));
  }
  return h;
}
|
|
|
|
// CondArgNode represents a input to the conditional and its corresponding
// switch nodes.
struct CondArgNode {
  explicit CondArgNode(Node* src, int src_output)
      : src(src), src_output(src_output) {}

  string ToString() const {
    return absl::StrCat("src=", src->name(), ":", src_output,
                        " switches=", NodesToString(switches));
  }

  // Source node/output slot feeding this conditional input.
  Node* src;
  int src_output;
  // The _Arg node created for this input in each branch body, indexed by
  // branch (see BuildArgumentNodes).
  std::array<Node*, 2> branch_copy;
  // All switch nodes in the original graph fed by (src, src_output).
  std::vector<Node*> switches;
};
using CondArgNodes = std::vector<CondArgNode>;
|
|
|
|
// Renders the argument list as "[<arg0>, <arg1>, ...]".
string DebugString(const CondArgNodes& nodes) {
  auto format_node = [](string* output, const CondArgNode& node) {
    absl::StrAppend(output, node.ToString());
  };
  string result = "[";
  absl::StrAppend(&result, absl::StrJoin(nodes, ", ", format_node), "]");
  return result;
}
|
|
|
|
StateMap::CondId StateMap::LookupCondId(const Node* node) const {
|
|
if (node->id() < node_to_condid_map_.size())
|
|
return node_to_condid_map_[node->id()];
|
|
return added_node_condid_mapping_.at(node->id());
|
|
}
|
|
|
|
StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
|
|
if (state.empty()) return nullptr;
|
|
return &*condstate_set_.insert(state).first;
|
|
}
|
|
|
|
void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
|
|
if (node->id() < node_to_condid_map_.size())
|
|
node_to_condid_map_[node->id()] = id;
|
|
else
|
|
added_node_condid_mapping_[node->id()] = id;
|
|
}
|
|
|
|
StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
|
|
if (node->id() < node_to_ancestorid_map_.size())
|
|
return node_to_ancestorid_map_[node->id()];
|
|
return added_node_ancestorid_mapping_.at(node->id());
|
|
}
|
|
|
|
StateMap::AncestorId StateMap::GetAncestorId(
|
|
const StateMap::AncestorState& state) {
|
|
if (state.empty()) return nullptr;
|
|
return &*ancestorstate_set_.insert(state).first;
|
|
}
|
|
|
|
void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
|
|
if (node->id() < node_to_ancestorid_map_.size())
|
|
node_to_ancestorid_map_[node->id()] = id;
|
|
else
|
|
added_node_ancestorid_mapping_[node->id()] = id;
|
|
}
|
|
|
|
// Marking a node dead assigns it the canonical dead CondId created in the
// constructor.
void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }

string StateMap::CondStateToString(const Node* node) const {
  return CondStateToString(LookupCondId(node));
}

string StateMap::CondStateToString(StateMap::CondId id) const {
  return DebugString(id);
}
|
|
|
|
// Renders the ancestor state as "{node:index,...}"; the null (empty) state
// renders as "{}".
string StateMap::AncestorStateToString(const Node* node) const {
  auto id = LookupAncestorId(node);
  if (id == nullptr) return "{}";
  auto format_ancestor = [](string* output, const AncestorNode& ancestor) {
    absl::StrAppend(output, ancestor.output_tensor.node->name(), ":",
                    ancestor.output_tensor.index);
  };
  return absl::StrCat("{", absl::StrJoin(*id, ",", format_ancestor), "}");
}
|
|
|
|
// Holds the per-graph state used while converting Switch/Merge control flow
// in 'graph' into functional If nodes registered in 'library'.
FunctionalizeCond::FunctionalizeCond(Graph* graph,
                                     FunctionLibraryDefinition* library)
    : state_map_(graph), library_(library), graph_(graph) {}
|
|
|
|
// Class representing the merge/switch nodes that will become a conditional.
class Conditional {
 public:
  Conditional(OutputTensor predicate, FunctionalizeCond* parent,
              StateMap* cond_state_map, const ShapeRefiner& refiner);

  // Adds merge node that is part of this conditional.
  Status AddMerge(Node* m);

  // Constructs an If node from the merge nodes.
  Status BuildAndReplace(
      Graph* graph, FunctionLibraryDefinition* library,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

 private:
  // Extracts the then/else bodies: creates new graphs with the nodes
  // corresponding to the nodes in the then/else branches as of this conditional
  // as function bodies.
  Status ExtractBodies(Graph* graph);

  // Builds the arguments that are the input to the If.
  Status BuildArgumentNodes();

  // Builds the If node for the extracted bodies with the given predicate.
  Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);

  // Adds input edges to If node.
  Status AddInputEdges(
      Graph* graph,
      const std::unordered_map<Node*, OutputTensor>& merge_to_replacement);

  // Adds output edges from If node.
  // Record new output tensor for all Merge nodes in 'merge_to_replacement'.
  Status AddOutputEdges(
      Graph* graph,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

  // Adds switch node that is part of this conditional.
  Status AddSwitch(Node* s);

  // Adds a switch node along the edge and rewire the edge to go via the switch.
  Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                Graph* graph);

  // Internal name of conditional. The name is based on the first merge node
  // added.
  string name() const;

  // The FunctionalizeCond instance that created this.
  FunctionalizeCond* parent_;

  // Mapping between nodes and their cond state.
  StateMap* state_map_;

  // The predicate of the conditional.
  OutputTensor predicate_;

  // Shape refiner of ops in the graph.
  const ShapeRefiner& refiner_;

  // The predicate of the switches of the conditional. This may be different
  // than predicate (which is initialized from the original graph) as the
  // predicate could be the output of a newly created If node.
  OutputTensor switch_predicate_;

  // Switch nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> switches_;

  // Merge nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> merges_;

  // Vector of control inputs from outside the conditional to a node inside.
  std::vector<Node*> external_control_inputs_;
  // Vector of control dependencies from a node inside the conditional to a
  // node outside of it.
  std::vector<Node*> external_control_outputs_;

  // Graphs corresponding to the then and else branch.
  std::array<std::unique_ptr<Graph>, 2> bodies_;

  // Maps from graph_ to the branch body's graph (indexed by node id).
  std::array<std::vector<Node*>, 2> node_maps_;

  // The argument nodes created for the switches.
  CondArgNodes cond_arg_nodes_;

  // The constructed If node.
  Node* if_node_ = nullptr;

  // Whether the merge nodes of this conditional have been replaced.
  bool replaced_ = false;
};
|
|
|
|
// A Conditional is built up incrementally: merge/switch nodes are added as
// they are discovered, then BuildAndReplace converts them into an If node.
Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
                         StateMap* cond_state_map, const ShapeRefiner& refiner)
    : parent_(parent),
      state_map_(cond_state_map),
      predicate_(predicate),
      refiner_(refiner) {}
|
|
|
|
// Registers a merge node as one of the outputs of this conditional. The set
// is ordered (resources last) so later processing is deterministic.
Status Conditional::AddMerge(Node* m) {
  merges_.insert(m);
  return Status::OK();
}
|
|
|
|
Status Conditional::AddSwitch(Node* s) {
  VLOG(5) << "Adding switch " << s->DebugString();
  OutputTensor predicate;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
  // The first switch added fixes the predicate; an If node has a single
  // predicate, so every later switch must agree with it.
  if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
  if (!(switch_predicate_ == predicate)) {
    return errors::InvalidArgument(
        "Merge nodes ", NodesToString(merges_),
        " directly dominated by switch nodes with different predicates (",
        DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
  }
  switches_.insert(s);
  parent_->AddSwitchId(s->id());
  return Status::OK();
}
|
|
|
|
// Creates the _Arg nodes of the branch function bodies and wires the copied
// switch consumers to them. One argument is created per distinct
// (src, src_output) feeding any switch of this conditional.
Status Conditional::BuildArgumentNodes() {
  VLOG(1) << "Build function arguments";
  // Hash for (node, output index) pairs, used to dedupe conditional inputs.
  struct Hash {
    size_t operator()(const std::pair<Node*, int>& item) const {
      return Hash64Combine(hash<Node*>()(item.first),
                           std::hash<int>()(item.second));
    }
  };

  // Group all switches fed by the same (src, src_output) under one argument.
  std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
  for (Node* switch_node : switches_) {
    const Edge* e;
    TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
    std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
    if (input_index.find(key) == input_index.end()) {
      input_index[key] = cond_arg_nodes_.size();
      cond_arg_nodes_.emplace_back(key.first, key.second);
    }
    cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
  }
  VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);

  int arg_count = 0;
  for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
    DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
    // Create one _Arg node per branch body for this input.
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      TF_RETURN_IF_ERROR(
          NodeBuilder(absl::StrCat("_Arg", arg_count),
                      FunctionLibraryDefinition::kArgOp)
              .Attr("T", dtype)
              .Attr("index", arg_count)
              .Finalize(bodies_[branch_index].get(),
                        &cond_arg_node.branch_copy[branch_index]));
    }
    // Rewire the copied consumers of each switch output to read from the
    // newly created _Arg node instead.
    for (Node* node : cond_arg_node.switches) {
      for (const Edge* e : node->out_edges()) {
        if (e->IsControlEdge()) continue;
        int branch_index = e->src_output();
        Node* src_copy = cond_arg_node.branch_copy[branch_index];
        Node* dst_copy = node_maps_[branch_index][e->dst()->id()];

        // The graph may contain dead switch nodes whose consumers were never
        // copied into this branch body; skip those edges.
        // (A TF_RET_CHECK(dst_copy != nullptr) used to follow this guard; it
        // was unreachable dead code and has been removed.)
        if (dst_copy == nullptr) continue;

        // If the input goes directly to a merge then the merge has
        // been replaced by a retval so the dst input is 0 instead of
        // dst_input.
        int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
        bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
      }
    }
    ++arg_count;
  }

  // Verify that all retvals have an input.
  // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
  // input.
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      bool has_input = false;
      for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
        if (!e->IsControlEdge()) {
          has_input = true;
          break;
        }
      }
      if (!has_input) {
        return errors::Internal(
            "Failed to functionalize control flow with merge ",
            FormatNodeForError(*m), " that doesn't have input on ",
            Branch_Name(branch), " branch.");
      }
    }
  }

  return Status::OK();
}
|
|
|
|
Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                           Graph* graph) {
  // Previously we had edge:
  //   src:src_output ---- edge ----> dst:dst_input
  // post this we have (in graph)
  //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input

  // TODO(jpienaar): One could keep a map caching the extra switch nodes added
  // to avoid adding another switch to feed a value for which a switch was
  // already added.
  Node* switch_node;
  Node* src = edge->src();
  int src_output = edge->src_output();
  TF_RETURN_IF_ERROR(
      NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
                  "Switch")
          .Input(src, src_output)
          .Input(const_cast<Node*>(predicate_.node), predicate_.index)
          .Finalize(graph, &switch_node));
  // The new switch inherits the control context (cond/ancestor state) of its
  // data input.
  state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
  state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));

  Node* dst = edge->dst();
  int dst_input = edge->dst_input();
  graph->RemoveEdge(edge);
  // Route dst through the branch-specific output of the new switch.
  graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
  return AddSwitch(switch_node);
}
|
|
|
|
// Extracts the then/else function bodies by walking backwards from the merge
// nodes' inputs, copying every node that shares this conditional's context
// into the corresponding branch graph. Merge nodes become _Retval nodes.
Status Conditional::ExtractBodies(Graph* graph) {
  VLOG(2) << "Extracting bodies for " << name();
  for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    bodies_[static_cast<int>(b)] =
        absl::make_unique<Graph>(graph->op_registry());
  }

  // For a switch source the branch is encoded in the output slot; otherwise
  // it is recovered from the node's cond state relative to our predicate.
  auto find_branch = [&](const Edge* e) {
    const auto& id = state_map_->LookupCondId(e->src());
    return IsSwitch(e->src()) ? BranchType(e->src_output())
                              : state_map_->FindBranchOf(id, predicate_);
  };

  // Seed a per-branch DFS stack with the (non-switch) inputs of the merges.
  std::array<std::vector<Node*>, 2> stacks;
  VLOG(5) << "Merges: " << NodesToString(merges_);
  for (Node* m : merges_) {
    VLOG(5) << "For merge: " << m->DebugString() << " "
            << state_map_->CondStateToString(m);
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      BranchType branch = find_branch(e);
      TF_RET_CHECK(branch == BranchType::kThenBranch ||
                   branch == BranchType::kElseBranch)
          << "Error: " << e->src()->name()
          << " is not on either then or else branch (" << Branch_Name(branch)
          << ") for predicate " << DebugString(predicate_) << " ["
          << DebugString(state_map_->LookupCondId(e->src())) << "].";
      Node* src = e->src();
      if (IsSwitch(src)) {
        // Switch node outputs and dependencies are handled separately.
        TF_RETURN_IF_ERROR(AddSwitch(src));
      } else {
        stacks[static_cast<int>(branch)].push_back(src);
      }
    }
  }

  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);
    auto output = bodies_[branch_index].get();
    auto& stack = stacks[branch_index];
    VLOG(5) << "In branch: " << Branch_Name(branch) << " "
            << NodesToString(stack);
    std::vector<bool> visited(graph->num_node_ids(), false);
    node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
    auto& node_map = node_maps_[branch_index];

    while (!stack.empty()) {
      Node* n = stack.back();
      stack.pop_back();

      if (visited.at(n->id())) continue;
      visited[n->id()] = true;

      // Verify output edges and record control edges exiting scope.
      for (const Edge* e : n->out_edges()) {
        Node* dst = e->dst();
        if (IsMerge(dst)) continue;
        Node* src = e->src();

        auto dst_id = state_map_->LookupCondId(dst);
        auto src_id = state_map_->LookupCondId(src);
        if (dst_id != src_id) {
          if (e->IsControlEdge()) {
            external_control_outputs_.push_back(e->src());
          } else {
            // Constants are treated specially to workaround the case of
            // non-dominated constant nodes.
            if (!IsConstant(src)) {
              // TODO(b/78882471): A node that feeds into two different
              // CondState is not necessarily an error so log a warning for now
              // but revisit to improve the testing to enable making this an
              // error.
              LOG(WARNING) << errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  " (detected during out edge testing)");
            }
          }
        }
      }

      // Copying incoming edges to dst node. Iterate over a copy of the edges
      // as they could be mutated during iteration.
      std::vector<const Edge*> in_edges(n->in_edges().begin(),
                                        n->in_edges().end());
      // Sort in_edges to make sure nodes are copied in a deterministic order.
      std::sort(
          in_edges.begin(), in_edges.end(), [](const Edge* a, const Edge* b) {
            int a_src_output = a->src_output(), b_src_output = b->src_output();
            StringPiece a_name(a->src()->name()), b_name(b->src()->name());
            return std::tie(a_src_output, a_name) <
                   std::tie(b_src_output, b_name);
          });
      for (const Edge* e : in_edges) {
        Node* src = e->src();
        // Skip src/dst node.
        if (!src->IsOp()) continue;

        Node* dst = e->dst();
        if (IsSwitch(src)) {
          // Switch node outputs and dependencies are handled separately.
          TF_RETURN_IF_ERROR(AddSwitch(src));
          continue;
        }

        // Verify input is from the same context.
        auto src_id = state_map_->LookupCondId(src);
        auto dst_id = state_map_->LookupCondId(dst);
        if (IsMerge(dst) || src_id == dst_id) {
          // TODO(jpienaar): The merge case can be more strict.
          if (node_map.at(src->id()) == nullptr) {
            node_map.at(src->id()) = output->CopyNode(src);
            stack.push_back(src);
          }
        } else if (e->IsControlEdge()) {
          // Here we have a control flow edge between src and dst that are not
          // in the same context. This is an external control dependency except
          // for one case: where the only difference between CondId of e->src()
          // and CondId of e->dst() is that e->src() has {PRED, kNeither} and
          // e->dst() has {PRED, kThenBranch/kElseBranch}. This happens in
          // gradients code for tf.cond(), where e->src() is a control pivot
          // node for a branch and e->dst() is a data node in that branch.
          bool is_external_control_input = true;
          if (!state_map_->IsEmpty(src_id) && !state_map_->IsEmpty(dst_id)) {
            std::vector<StateMap::CondState::value_type> diff;
            std::set_symmetric_difference(
                src_id->begin(), src_id->end(), dst_id->begin(), dst_id->end(),
                std::back_inserter(diff), CondStateLess());
            if (diff.size() == 2 && diff[0].first == diff[1].first &&
                (diff[0].second == BranchType::kNeither ||
                 diff[1].second == BranchType::kNeither)) {
              auto src_branch = src_id->find(diff[0].first);
              if (src_branch != src_id->end() &&
                  src_branch->second == BranchType::kNeither) {
                is_external_control_input = false;
              }
            }
          }
          if (is_external_control_input) {
            external_control_inputs_.push_back(src);
          }
        } else {
          // This shouldn't happen, this means we have an external data input
          // not entering via a switch node. Work around this by for
          // * constant nodes copy them;
          // * non-constant nodes, insert a switch along the edge;
          if (IsConstant(src)) {
            // Check if constant node was added already. It is possible to have
            // multiple uses of a constant node.
            if (node_map.at(src->id()) == nullptr) {
              node_map.at(src->id()) = output->CopyNode(src);
            }
          } else {
            StateMap::CondState state = *dst_id;
            state.erase(predicate_);
            if (state_map_->GetCondId(state) == src_id) {
              TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
              continue;
            } else {
              return errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  " (detected during in edge testing)");
            }
          }
        }

        // Copy the edge (and the dst node, if not yet copied) into the body.
        Node* src_copy = node_map.at(e->src()->id());
        int src_output = e->src_output();
        if (node_map.at(dst->id()) == nullptr) {
          node_map.at(dst->id()) = output->CopyNode(dst);
        }
        Node* dst_copy = node_map.at(e->dst()->id());
        if (e->IsControlEdge()) {
          // Skip control inputs from external context.
          if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
        } else {
          output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
        }
      }
    }
  }

  // Build return values from the merge nodes.
  int index = 0;
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      TF_ASSIGN_OR_RETURN(node_map[m->id()],
                          BuildRetvalNode(output, m->output_type(0), index));
    }
    ++index;

    // Connect the input to the merge_ with the retval, except if it is a
    // Switch node, which is handled separately.
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      int branch_index = static_cast<int>(find_branch(e));
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      Node* in = e->src();
      if (!IsSwitch(in)) {
        if (node_map.at(in->id()) == nullptr) {
          node_map[in->id()] = output->CopyNode(in);
        }
        output->AddEdge(node_map[in->id()], e->src_output(),
                        node_map.at(m->id()), 0);
      }
    }
  }
  return Status::OK();
}
|
|
|
|
// Serializes the extracted branch bodies as functions in 'library' and
// builds the NodeDef of the If node (predicate first, then the deduped data
// inputs; Tin/Tout/output_shapes attributes derived from the args/merges).
Status Conditional::BuildIfNode(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "Build cond function for " << name();
  NodeDebugInfo debug_info((*merges_.begin())->def());
  NodeDefBuilder builder(name(), "If", library, &debug_info);
  const string branch_name[] = {"else_branch", "then_branch"};
  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);

    NameAttrList body_name;
    body_name.set_name(library->UniqueFunctionName(
        absl::StrCat("_functionalize_if_", branch_name[branch_index], "_")));

    VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
            << "): "
            << DumpGraphToFile(
                   "functionalize_cond_body_" + branch_name[branch_index],
                   *bodies_[branch_index], nullptr);

    // Serialize the branch body into a function, register it in the library,
    // and reference it from the If node by name.
    FunctionDef body_fdef;
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
                                          body_name.name(), &body_fdef));
    TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
    builder.Attr(branch_name[branch_index], body_name);
  }

  VLOG(3) << "Build input type";
  std::vector<NodeDefBuilder::NodeOut> inputs;
  DataTypeVector in_arg_types;
  for (auto& kv : cond_arg_nodes_) {
    bool inserted = false;
    for (const Node* arg : kv.switches) {
      const Edge* in_edge;
      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
      if (in_edge->IsControlEdge()) {
        builder.ControlInput(in_edge->src()->name());
      } else {
        // Record each (src, src_output) data input only once even if it fed
        // multiple switches.
        if (!inserted) {
          DataType dtype = arg->input_type(0);
          inputs.emplace_back(NodeDefBuilder::NodeOut(
              in_edge->src()->name(), in_edge->src_output(), dtype));
          in_arg_types.push_back(dtype);
          inserted = true;
        }
      }
    }
  }
  builder.Attr("Tin", in_arg_types);

  DataTypeVector out_type;
  std::vector<PartialTensorShape> output_shapes;
  output_shapes.reserve(merges_.size());
  for (const Node* merge : merges_) {
    DataType dtype = merge->output_type(0);
    TensorShapeProto shape;
    // Propagate the inferred merge output shape (when the refiner has one)
    // onto the If node's output_shapes attribute; otherwise leave it unknown.
    // (Removed an unused local ShapeHandle that was declared here.)
    if (auto* shape_ctx = refiner_.GetContext(merge)) {
      shape_ctx->ShapeHandleToProto(shape_ctx->output(0), &shape);
    }
    out_type.push_back(dtype);
    output_shapes.push_back(shape);
  }
  builder.Attr("Tout", out_type);
  VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);
  builder.Attr("output_shapes", output_shapes);
  VLOG(3) << "Build output shapes: "
          << PartialTensorShapeUtils::PartialShapeListString(output_shapes);

  builder.Attr("Tcond", DT_BOOL);
  string outside_compilation;
  // Preserve the predicate's outside-compilation attribute (if present) so
  // the If node stays in the same cluster.
  if (GetNodeAttr(predicate_.node->def(), kXlaOutsideCompilationAttrName,
                  &outside_compilation)
          .ok()) {
    builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
  }
  builder.Device(predicate_.node->assigned_device_name());
  // Conditional should be the first input ...
  builder.Input(
      NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index,
                              predicate_.node->output_type(predicate_.index)));
  // ... followed by the other inputs.
  builder.Input(inputs);

  VLOG(3) << "Build If node";
  NodeDef if_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
  TF_ASSIGN_OR_RETURN(if_node_,
                      parent_->AddIfNode(if_def, *merges_.begin(), predicate_));

  return Status::OK();
}
|
|
|
|
// Wires the graph edges into the constructed If node: predicate first, then
// the conditional argument inputs, then external control dependencies.
Status Conditional::AddInputEdges(
    Graph* graph,
    const std::unordered_map<Node*, OutputTensor>& merge_to_replacement) {
  VLOG(2) << "AddInputEdges for " << if_node_->name();
  int index = 0;
  // Add predicate input.
  if (predicate_.node->IsMerge()) {
    // If the predicate is a Merge node, we should not use Merge output as
    // predicate. Instead, we should use the corresponding If output in
    // 'merge_to_replacement'. Otherwise, this Conditional's If node is still
    // connected to the predicate Merge node; and when we call
    // DeleteReachableAndDeadNodes(), the predicate Merge node and this
    // Conditional's If node will be removed.
    auto iter = merge_to_replacement.find(predicate_.node);
    if (iter == merge_to_replacement.end()) {
      return errors::Internal("Cannot find replacement for Merge node ",
                              predicate_.node->name());
    }
    graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++);
  } else {
    graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index,
                   if_node_, index++);
  }
  // Add function body inputs.
  for (auto& arg : cond_arg_nodes_) {
    if (arg.src_output == Graph::kControlSlot) {
      graph->AddControlEdge(arg.src, if_node_);
    } else {
      graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
    }
  }
  // Control dependencies from outside the conditional attach directly to the
  // If node.
  for (Node* n : external_control_inputs_) {
    graph->AddControlEdge(n, if_node_);
  }
  return Status::OK();
}
|
|
|
|
// Redirects all consumers of the merge nodes to the corresponding outputs of
// the If node and records the replacement tensors in 'merge_to_replacement'.
Status Conditional::AddOutputEdges(
    Graph* graph,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(2) << "AddOutputEdges for " << if_node_->name();
  int i = 0;
  for (Node* node : merges_) {
    // Replace merge 'node' with output 'i' of the If node (via the parent).
    TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
    // Iterate over a copy of the out edges: they are removed while iterating.
    std::vector<const Edge*> edges(node->out_edges().begin(),
                                   node->out_edges().end());
    for (const Edge* edge : edges) {
      Node* dst = edge->dst();
      int dst_input = edge->dst_input();
      // Only output 0 of a Merge is functionalized; output 1 (value_index)
      // is not supported.
      if (edge->src_output() > 0) {
        return errors::Unimplemented("Output of index (", edge->src_output(),
                                     ") of merge node ",
                                     FormatNodeForError(*node));
      }

      bool control_edge = edge->IsControlEdge();
      graph->RemoveEdge(edge);
      if (control_edge) {
        graph->AddControlEdge(if_node_, dst);
      } else {
        graph->AddEdge(if_node_, i, dst, dst_input);
      }
    }

    // Record corresponding output tensor in 'merge_to_replacement'.
    (*merge_to_replacement)[node] = OutputTensor{if_node_, i};

    ++i;
  }
  // Control dependencies that exit the conditional now hang off the If node.
  for (Node* n : external_control_outputs_) {
    graph->AddControlEdge(if_node_, n);
  }

  return Status::OK();
}
|
|
|
|
// Driver for turning this Conditional into a functional If node: extracts the
// branch bodies, builds argument nodes and the If node, rewires inputs and
// outputs, and re-propagates CondStates. Safe to call more than once; the
// `replaced_` flag makes repeat calls no-ops.
Status Conditional::BuildAndReplace(
    Graph* graph, FunctionLibraryDefinition* library,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(1) << "Build If and replace merge nodes "
          << NodesToString(this->merges_);
  if (replaced_) return Status::OK();

  // The phase order matters: bodies must be extracted before argument nodes
  // are built, and the If node must exist before edges can be attached.
  TF_RETURN_IF_ERROR(ExtractBodies(graph));
  TF_RETURN_IF_ERROR(BuildArgumentNodes());

  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Extracted bodies:";
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto output = bodies_[branch_index].get();
      LOG(INFO) << Branch_Name(branch) << ": "
                << DebugString(output->ToGraphDefDebug());
    }
  }

  TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
  TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement));
  TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement));
  TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));

  // Check that the if_node doesn't feed into itself.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
      "Converting to If failed.");

  replaced_ = true;
  return Status::OK();
}
|
|
|
|
// Returns a name for this conditional derived from the first merge node it
// replaces, suffixed with "_if".
string Conditional::name() const {
  CHECK(!merges_.empty());
  const Node* first_merge = *merges_.begin();
  return absl::StrCat(first_merge->name(), "_if");
}
|
|
|
|
// Adds an Identity node that takes the name of `replacee` and forwards output
// `port` of `if_node`. The Identity inherits `if_node`'s outside-compilation
// attribute (if present) and its Cond/Ancestor state.
Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
                                          int port) {
  NodeBuilder id_builder(replacee->name(), "Identity");
  id_builder.Input(if_node, port);
  string outside_compilation;
  // Copy the outside-compilation attribute from the If node, if set.
  if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttrName,
                  &outside_compilation)
          .ok()) {
    id_builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
  }
  Node* id;
  TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
  // The new Identity assumes the state of the If node it forwards.
  state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
  state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
  return Status::OK();
}
|
|
|
|
// Adds the If node described by `def` to the graph. The new node gets the
// CondState of `replacee` minus `predicate` (the If node now represents that
// predicate itself) and `replacee`'s AncestorState unchanged.
StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
                                             const Node* replacee,
                                             const OutputTensor& predicate) {
  Status status;
  Node* ret = graph_->AddNode(def, &status);
  TF_RETURN_IF_ERROR(status);
  VLOG(1) << "Adding If for " << replacee->name();
  StateMap::CondId id = state_map_.LookupCondId(replacee);
  if (id) {
    // Drop the functionalized predicate from the copied state before
    // interning it for the new node.
    StateMap::CondState state = *id;
    state.erase(predicate);
    state_map_.ResetCondId(ret, state_map_.GetCondId(state));
  } else {
    state_map_.ResetCondId(ret, nullptr);
  }

  state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));

  return ret;
}
|
|
|
|
// Recomputes the CondState of every node downstream of `replacee` (a newly
// inserted/rewired node), propagating changes to consumers until a fixed
// point is reached.
Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
  VLOG(2) << "Propagating update state for " << replacee->name() << " "
          << state_map_.CondStateToString(replacee);
  // Redo topological sort as the order could have changed.
  // TODO(jpienaar): The original topological order could also be updated
  // dynamically if needed.
  std::vector<Node*> rev_topo_order;
  GetPostOrder(*graph_, &rev_topo_order);

  // All the outputs of the new node could potentially be updated.
  std::unordered_set<Node*> changed;
  for (auto n : replacee->out_nodes())
    if (n->IsOp()) changed.insert(n);

  // Iterate through the changed/possible changed nodes in topological order.
  // The loop terminates early once the worklist drains.
  for (auto it = rev_topo_order.rbegin();
       it != rev_topo_order.rend() && !changed.empty(); ++it) {
    if (changed.find(*it) != changed.end()) {
      // Update the node state.
      Node* n = *it;
      StateMap::CondId old_state = state_map_.LookupCondId(n);
      state_map_.ResetCondId(n, nullptr);
      TF_RETURN_IF_ERROR(DetermineCondState(n));
      if (state_map_.LookupCondId(n) != old_state) {
        // State changed: this node's consumers may need updating too.
        for (auto out : n->out_nodes())
          if (out->IsOp()) changed.insert(out);
      }
      changed.erase(n);
    }
  }
  return Status::OK();
}
|
|
|
|
// Returns the most restrictive branch of two branches or neither. This is the
|
|
// meet operator of the BranchType lattice.
|
|
BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
|
|
if (lhs == rhs) return lhs;
|
|
if (lhs == BranchType::kNeither) return rhs;
|
|
if (rhs == BranchType::kNeither) return lhs;
|
|
if (lhs == BranchType::kBoth) return rhs;
|
|
if (rhs == BranchType::kBoth) return lhs;
|
|
return BranchType::kNeither;
|
|
}
|
|
|
|
// Looks up which branch `predicate` is constrained to in the CondState `id`.
// Returns kNeither when the state is empty or the predicate is absent.
BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
  if (!IsEmpty(id)) {
    const auto it = id->find(predicate);
    if (it != id->end()) return it->second;
  }
  return BranchType::kNeither;
}
|
|
|
|
// Joins two CondStates for a non-Merge node. The result carries the union of
// the predicates from both states; where a predicate appears in both, the
// branches must agree (with kNeither acting as a wildcard that defers to the
// other side), otherwise the inputs are incompatible and an error is returned.
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
    StateMap::CondId src, StateMap::CondId dst) {
  VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
          << "] and dst=" << DebugString(dst) << " [" << dst << "]";

  // An empty state joins to the other side; a dead state dominates the join.
  if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;

  // Nothing to do if the CondState is the same.
  if (src == dst) return src;

  StateMap::CondState both = *src;
  for (const auto& kv : *dst) {
    auto it = both.find(kv.first);
    if (it == both.end()) {
      both.insert(kv);
    } else {
      if (it->second != kv.second) {
        if (it->second == BranchType::kNeither) {
          // BranchType for 'src' is kNeither. Use the BranchType in 'dst'.
          it->second = kv.second;
        } else if (kv.second == BranchType::kNeither) {
          // BranchType for 'dst' is kNeither. Use the BranchType in 'src'.
          // No need to change it->second.
        } else {
          return errors::InvalidArgument(
              "Graph contains node with inputs predicated on incompatible "
              "predicates: ",
              DebugString(src), " and ", DebugString(dst));
        }
      }
    }
  }
  return state_map_.GetCondId(both);
}
|
|
|
|
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
    Node* merge, StateMap::CondId src, StateMap::CondId dst) {
  // Determine the flow state when joining two states for a merge
  // node. Combining the two states for a merge node is effectively performing a
  // disjunction of the states along the different input edges. For a merge that
  // can be transformed into a If the two inputs paths have to have a predicate
  // on which they differ (e.g., along one edge predicate `p` has to hold while
  // on another it should not). This function first determines this predicate
  // and then the resultant state is the common path between the two inputs
  // followed by s(p, both).
  VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
          << DebugString(dst);
  if (state_map_.IsEmpty(dst)) return src;
  if (state_map_.IsEmpty(src)) {
    return errors::Internal("Merge node ", merge->name(),
                            " has input that's not in any CondContext.");
  }

  // A dead input makes the join dead.
  if (state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst)) return dst;

  // `diff` collects the predicate entries on which the two input paths
  // disagree; `merged` is the common state shared by both paths.
  std::vector<StateMap::CondState::value_type> diff;
  StateMap::CondState merged;
  std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
                                dst->end(), std::back_inserter(diff),
                                CondStateLess());
  std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
                        std::inserter(merged, merged.begin()), CondStateLess());

  // Update mapping from merge node to predicate.
  if (diff.size() == 2) {
    // For a valid merge, the two paths differ on exactly one predicate, with
    // one path on the then branch and the other on the else branch.
    auto pred = diff[0].first;
    bool different_branches = (diff[0].second != diff[1].second) &&
                              (diff[0].second == BranchType::kThenBranch ||
                               diff[0].second == BranchType::kElseBranch) &&
                              (diff[1].second == BranchType::kThenBranch ||
                               diff[1].second == BranchType::kElseBranch);
    if (!(pred == diff[1].first) || !different_branches)
      return errors::InvalidArgument(
          "Unable to determine predicate for merge node");
    merge_to_predicate_[merge] = pred;
  } else {
    return errors::InvalidArgument(
        "Merge of two inputs that differ on more than one predicate ",
        DebugString(src), " and ", DebugString(dst));
  }

  return state_map_.GetCondId(merged);
}
|
|
|
|
// Returns the CondState that holds along edge `e`: the state of the source
// node, extended with the switch's predicate/branch when the source is a
// Switch node.
StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
  Node* src = e->src();
  StateMap::CondId id = state_map_.LookupCondId(e->src());

  // Dead nodes only propagate dead state.
  if (state_map_.IsDead(id)) return id;

  if (IsSwitch(src)) {
    StateMap::CondState state;
    if (id != nullptr) state = *id;
    OutputTensor predicate;
    TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
    if (e->IsControlEdge()) {
      // In gradients of tf.cond(), in each branch, we have a NoOp node as
      // control pivot. These NoOp nodes have control dependency from Switch
      // node. If we don't record this into CondState, branches might have
      // incorrect CondState (e.g. if the branch only has a Const data node).
      // We set it to kNeither because there is no way to tell whether it's
      // for true branch or false branch. This node's descendents might have
      // other incoming edges with defined BranchType, and we correctly handle
      // merging kNeither with other defined BranchType in StateAlongEdge().
      state[predicate] = BranchType::kNeither;
    } else {
      // The switch output port (0 or 1) identifies the branch taken.
      state[predicate] = BranchType(e->src_output());
    }
    return state_map_.GetCondId(state);
  }
  return id;
}
|
|
|
|
// Computes the CondState of Merge node `dst` by joining the states along all
// of its incoming edges, and verifies that it has exactly two data inputs.
Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
  // Only Merge nodes with two inputs are supported, but if this is a redundant
  // merge, then the dead edge may already have been removed (if due to a
  // switch) and so the input count would be incorrect.
  if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK();

  int data_inputs = 0;
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
            << state_map_.CondStateToString(src);
    // Skip source/sink pseudo-nodes.
    if (!src->IsOp()) continue;
    if (!e->IsControlEdge()) ++data_inputs;

    // Fold this edge's state into dst's accumulated state.
    StateMap::CondId prop = StateAlongEdge(e);
    auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                    FormatNodeForError(*dst));
    state_map_.ResetCondId(dst, id_or.ValueOrDie());
  }

  // Incomplete Merge nodes are not supported.
  if (data_inputs != 2) {
    return errors::Unimplemented(
        dst->name(), " only has ", data_inputs,
        " inputs, while only merge nodes with two inputs supported.");
  }
  return Status::OK();
}
|
|
|
|
// Computes the CondState of a non-Merge node by joining, edge by edge, the
// state propagated along each of its inputs with the state recorded so far.
Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
  // Handle non-merge join.
  for (const Edge* in_edge : dst->in_edges()) {
    VLOG(4) << "Processing forward flow for: " << in_edge->DebugString() << " "
            << state_map_.CondStateToString(dst);
    // Skip source/sink pseudo-nodes.
    if (!in_edge->src()->IsOp()) continue;

    // Join the state arriving along this edge with dst's current state.
    auto id_or = JoinCondStatesNonMerge(StateAlongEdge(in_edge),
                                        state_map_.LookupCondId(dst));
    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                    FormatNodeForError(*dst));
    state_map_.ResetCondId(dst, id_or.ValueOrDie());
  }
  return Status::OK();
}
|
|
|
|
Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
  // Handle redundant merge nodes. A merge node is considered redundant if
  // one input edge is dead while the other has a value.
  if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK();

  // Find the single live data input; it will replace the merge entirely.
  const Edge* non_dead_edge = nullptr;
  for (auto e : node->in_edges()) {
    if (e->IsControlEdge()) continue;
    Node* src = e->src();

    // Handle merge with dead state.
    const auto& src_id = state_map_.LookupCondId(src);
    if (!state_map_.IsDead(src_id)) {
      non_dead_edge = e;
      break;
    }
  }

  if (non_dead_edge == nullptr) {
    return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
                                   " has no non-dead inputs.");
  }
  state_map_.MarkDead(node);
  VLOG(5) << "removing redundant merge: " << node->name();
  // Reroute every consumer of the merge to the live input. Control consumers
  // keep a control edge; data consumers read the live input's output. The
  // while-loop form is required because AddEdge/RemoveEdge mutate
  // node->out_edges().
  while (!node->out_edges().empty()) {
    const Edge* oe = *node->out_edges().begin();
    Node* dst_node = oe->dst();
    int dst_port = oe->dst_input();
    graph_->RemoveEdge(oe);
    graph_->AddEdge(non_dead_edge->src(),
                    dst_port == Graph::kControlSlot
                        ? Graph::kControlSlot
                        : non_dead_edge->src_output(),
                    dst_node, dst_port);
  }
  return Status::OK();
}
|
|
|
|
Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
  // Handle redundant switch nodes. A switch node is considered redundant if
  // the predicate of the switch already holds on the current branch. E.g., if
  // p is the predicate of the switch but p is already known to hold on this
  // branch, then the switch can be removed and the dead state propagated
  // along one. The checking of predicate is based on the exact predicate
  // (rather than boolean equivalence) and aimed at redundant switches as
  // currently generated by gradient code.
  StateMap::CondId dst_id = state_map_.LookupCondId(node);
  if (state_map_.IsDead(dst_id)) return Status::OK();

  BranchType b;
  OutputTensor pred;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));

  // Determine if we are already on a branch where the switch predicate is
  // true/false. Consider both the data and predicate to determine if the
  // node is redundant (skipping over identity node).
  b = state_map_.FindBranchOf(dst_id, pred);
  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
    // The predicate itself is not in the state; also check the switch's data
    // input (input 0), following Identity chains back to its source.
    OutputTensor val;
    const Edge* e;
    TF_RETURN_IF_ERROR(node->input_edge(0, &e));
    val = OutputTensor(e->src(), e->src_output());
    while (IsIdentity(val.node)) {
      TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
      val = OutputTensor(e->src(), e->src_output());
    }
    b = state_map_.FindBranchOf(dst_id, val);
    if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
      return Status::OK();
  }

  // The switch is redundant: forward its data input directly to consumers.
  VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " "
          << DebugString(dst_id);
  const Edge* value_edge;
  TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
  Node* val_node = value_edge->src();
  int val_port = value_edge->src_output();
  // while-loop form because edge removal mutates node->out_edges().
  while (!node->out_edges().empty()) {
    auto e = *node->out_edges().begin();
    Node* dst_node = e->dst();
    int dst_input = e->dst_input();
    int switch_branch = e->src_output();
    graph_->RemoveEdge(e);
    if (switch_branch == Graph::kControlSlot) {
      // Control consumers: re-join their CondState without the switch.
      if (IsMerge(dst_node)) {
        auto id_or = JoinCondStatesMerge(dst_node, dst_id,
                                         state_map_.LookupCondId(dst_node));
        TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                        FormatNodeForError(*dst_node));
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      } else {
        auto id_or =
            JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
        TF_RETURN_IF_ERROR(id_or.status());
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      }
    } else if (BranchType(switch_branch) != b) {
      // Consumers on the branch that cannot be taken become dead; no
      // replacement edge is added for them.
      state_map_.MarkDead(dst_node);
      continue;
    }
    graph_->AddEdge(
        val_node,
        switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
        dst_node, dst_input);
  }
  return Status::OK();
}
|
|
|
|
// Walks the graph in topological order (the input is in reverse topological
// order, so it is iterated back to front), computing the CondState and
// AncestorState of every node and removing redundant Switch/Merge nodes as
// they are encountered.
Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
  // The state that is propagated along the given edge.
  for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
    Node* dst = *it;
    TF_RETURN_IF_ERROR(DetermineCondState(dst));
    TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
    if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
    if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));

    VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
            << " @ " << state_map_.AncestorStateToString(dst);
    if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
  }
  return Status::OK();
}
|
|
|
|
// Computes the AncestorState of `dst`: the set of Switch/Merge nodes (or
// switch predicates) that influence any of its inputs.
Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
  StateMap::AncestorId id = nullptr;
  StateMap::AncestorState state;

  // Folds the ancestor contribution of one input node `src` into `state`
  // (captured by reference) and returns the interned id of the accumulated
  // state so far.
  auto insert = [&](StateMap::AncestorId id, Node* src) {
    auto other_id = state_map_.LookupAncestorId(src);
    if (other_id != id && other_id != nullptr) {
      state.insert(other_id->begin(), other_id->end());
    }
    if (IsMerge(src)) {
      state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge});
    } else if (IsSwitch(src)) {
      OutputTensor pred;
      // For dead switch nodes, GetSwitchPredicate() will fail, and we use
      // the switch node directly as ancestor.
      if (GetSwitchPredicate(*src, &pred).ok()) {
        state.insert({pred, AncestorNode::AncestorNodeType::kPred});
      } else {
        state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch});
      }
    }
    return state_map_.GetAncestorId(state);
  };

  // Compute the union of all the switch/merge nodes that affects the input of
  // dst.
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    id = insert(id, src);
  }
  state_map_.ResetAncestorId(dst, id);
  return Status::OK();
}
|
|
|
|
void FunctionalizeCond::DeleteReachableAndDeadNodes(
    const std::vector<Node*>& merge_order) {
  // Delete all nodes that have been extracted or are reachable from
  // deleted/dead nodes. The input and outgoing edges should have already been
  // removed.
  // `delete_nodes` is a BFS worklist of node ids pending deletion.
  std::deque<int> delete_nodes;
  std::vector<bool> deleted(graph_->num_node_ids(), false);
  // Don't try to delete source or sink nodes.
  deleted[graph_->kSourceId] = true;
  deleted[graph_->kSinkId] = true;

  // All remaining Switch nodes are not reachable from a Merge node and
  // removed. This is to account for dead Switch nodes.
  for (int s_id : switch_ids_) {
    Node* s = graph_->FindNodeId(s_id);
    if (s == nullptr) continue;
    for (const Edge* e : s->out_edges()) {
      // Control outputs of switch nodes (which are unconditionally executed if
      // the switch is) are not removed as they need not be part of a
      // conditional.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    deleted[s_id] = true;
    graph_->RemoveNode(s);
  }

  // All merge nodes should have been transformed at this point and we remove
  // them from the graph here.
  for (Node* m : merge_order) {
    for (const Edge* e : m->out_edges()) {
      // Similar to control outputs of switch nodes don't remove control
      // outputs of merge nodes.
      // TODO(jpienaar): Check cases where output edges still exist here vs
      // being removed in AddOutputEdges.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    deleted[m->id()] = true;
    graph_->RemoveNode(m);
  }

  // Enqueue all the dead nodes.
  for (Node* n : graph_->nodes()) {
    if (state_map_.IsDead(state_map_.LookupCondId(n))) {
      delete_nodes.push_back(n->id());
    }
  }

  // BFS over the worklist, transitively deleting everything downstream of a
  // deleted node.
  while (!delete_nodes.empty()) {
    int d_id = delete_nodes.front();
    delete_nodes.pop_front();
    if (deleted[d_id]) continue;
    Node* d = graph_->FindNodeId(d_id);
    // Switch and Merge nodes could have been deleted already.
    if (d == nullptr) continue;
    for (const Edge* e : d->out_edges()) {
      delete_nodes.push_back(e->dst()->id());
    }
    deleted[d_id] = true;
    graph_->RemoveNode(d);
  }
}
|
|
|
|
// Reorders `merge_order` so merges guarded by more predicates (deeper
// CondStates) come first. The sort is stable, so merges of equal depth keep
// the relative order produced by the reverse traversal of the input.
void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
  // Pair each merge with its nesting depth (number of predicates in its
  // CondState), visiting the input back to front.
  using DepthNodePair = std::pair<int, Node*>;
  std::vector<DepthNodePair> by_depth;
  by_depth.reserve(merge_order->size());
  for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
    Node* merge = *it;
    StateMap::CondId cond_id = state_map_.LookupCondId(merge);
    const int depth = cond_id == nullptr ? 0 : cond_id->size();
    by_depth.emplace_back(depth, merge);
  }
  // Stable sort, deepest first.
  std::stable_sort(by_depth.begin(), by_depth.end(),
                   [](const DepthNodePair& lhs, const DepthNodePair& rhs) {
                     return lhs.first > rhs.first;
                   });
  merge_order->clear();
  for (const DepthNodePair& entry : by_depth) {
    merge_order->push_back(entry.second);
  }
}
|
|
|
|
Status FunctionalizeCond::FunctionalizeInternal() {
  // The general approach for converting a tf.cond (as lowered via switch/merge
  // nodes) to a functional if is as follows:
  // 1. Determine the topological order and collect all the switch and merge
  // nodes in the graph;
  // 2. Compute the predicates and dominance structure for all the nodes in the
  // graph - this includes which predicate must be true for a op to execute
  // (predicate values are considered directly rather than attempting to
  // determine deeper equivalence). We shall refer to this structure as the
  // CondState;
  // 3. Sort the merge nodes by nesting depth;
  // 4. Extract merge nodes together that have the same CondState and
  // AncestorState from the innermost to the outermost into IfOps;
  // Note: In the above only nodes that feed into a merge node will be
  // considered for functionalization.

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be no
  //   cycles at this point so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Record reverse topological for merge and switch nodes;
  std::vector<Node*> rev_topo_order;
  std::vector<Node*> merge_order;
  DFS(*graph_, nullptr, [&](Node* n) {
    if (IsSwitch(n)) {
      AddSwitchId(n->id());
    }
    if (IsMerge(n)) {
      merge_order.push_back(n);
    }
    if (n->IsOp()) {
      rev_topo_order.push_back(n);
    }
  });

  // No merges to functionalize.
  if (merge_order.empty()) {
    // No merges mean no switch values consumed (as only considering values
    // fetchable as output of merge);
    DeleteReachableAndDeadNodes(merge_order);
    return Status::OK();
  }

  // Step 2: compute CondState/AncestorState for every node.
  TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
  if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");

  // Determine the shapes of the ops in the graph. Shape inference failures
  // are non-fatal; only a warning is logged.
  ShapeRefiner shape_refiner{graph_->versions().producer(),
                             graph_->op_registry()};
  std::vector<Node*> nodes;
  GetReversePostOrder(*graph_, &nodes);
  for (auto node : nodes) {
    if (!shape_refiner.AddNode(node).ok()) {
      LOG(WARNING) << "Couldn't deduce shape for " << node->name();
    }
  }

  // Sort the merge nodes from innermost outwards.
  SortMergeNodes(&merge_order);

  // Cluster merge nodes by (CondId, AncestorId, predicate) in order of
  // nesting. (CondId, AncestorId) is not enough, e.g.
  //   pred1 = array_ops.placeholder(dtypes.bool, name='pred1')
  //   pred2 = array_ops.placeholder(dtypes.bool, name='pred2')
  //   cond1 = control_flow_ops.cond(pred1, ...)
  //   cond2 = control_flow_ops.cond(pred2, ...)
  //   cond3 = control_flow_ops.cond(pred1, use cond1 and cond2)
  //   cond4 = control_flow_ops.cond(pred2, use cond1 and cond2)
  // cond3 and cond4 have the same (CondId, AncestorId), but they should not
  // be merged into one "If" node (because they have different predicates).
  std::deque<std::vector<Node*>> merge_clusters;
  std::map<ClusterTuple, int, ClusterTupleLessThan> merge_cluster_index;
  for (Node* merge : merge_order) {
    auto cond_id = state_map_.LookupCondId(merge);
    if (state_map_.IsDead(cond_id)) continue;

    auto predicate = merge_to_predicate_.find(merge);
    if (predicate == merge_to_predicate_.end()) {
      return errors::Internal("Cannot find predicate for Merge node ",
                              merge->name());
    }

    ClusterTuple key = std::make_tuple(
        cond_id, state_map_.LookupAncestorId(merge), predicate->second);
    auto idx = merge_cluster_index.find(key);
    if (idx == merge_cluster_index.end()) {
      merge_cluster_index[key] = merge_clusters.size();
      merge_clusters.push_back({merge});
    } else {
      merge_clusters[idx->second].emplace_back(merge);
    }
  }

  // Extract the conditionals from inner most to outer most. Extracting from
  // innermost to outermost enables the extraction pass to stop once it
  // encounters a Switch node instead of having to keep track of Switch/Merge
  // nodes seen.
  for (const auto& cluster : merge_clusters) {
    // Construct a Conditional with the predicate of the merge.
    Conditional cond(merge_to_predicate_.at(cluster.front()), this, &state_map_,
                     shape_refiner);
    for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
    TF_RETURN_IF_ERROR(
        cond.BuildAndReplace(graph_, library_, &merge_to_replacement_));

    if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
  }

  // Clean up the now-replaced switch/merge nodes and anything dead.
  DeleteReachableAndDeadNodes(merge_order);

  return Status::OK();
}
|
|
|
|
void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
|
|
const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";
|
|
|
|
for (Node* n : graph_->nodes()) {
|
|
n->ClearAttr(kCondGroupDebugAttr);
|
|
n->AddAttr(kCondGroupDebugAttr,
|
|
absl::StrCat(state_map_.CondStateToString(n), "_",
|
|
state_map_.AncestorStateToString(n)));
|
|
}
|
|
LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
|
|
<< DumpGraphToFile(absl::StrCat("functionalize_cond_", name),
|
|
*graph_, library_);
|
|
}
|
|
|
|
// Records the id of a Switch node encountered during traversal; recorded ids
// are consulted later (e.g. in DeleteReachableAndDeadNodes).
void FunctionalizeCond::AddSwitchId(int switch_id) {
  switch_ids_.emplace_back(switch_id);
}
|
|
|
|
// Static entry point: runs cond functionalization over `graph`, adding any
// extracted branch functions to `library`.
Status FunctionalizeCond::Functionalize(Graph* graph,
                                        FunctionLibraryDefinition* library) {
  VLOG(1) << "FunctionalizeCond::Functionalize";
  FunctionalizeCond functionalize_cond(graph, library);
  return functionalize_cond.FunctionalizeInternal();
}
|
|
|
|
} // namespace functionalize_cond
|
|
|
|
// Public entry point: transforms all tf.cond-style switch/merge constructs in
// `graph` into functional If nodes.
Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) {
  // FunctionalizeControlFlow is invoked for every function, so the loops'
  // bodies and conditionals that were extracted into functions will be handled
  // in successive invocations.
  return functionalize_cond::FunctionalizeCond::Functionalize(graph, library);
}
|
|
|
|
} // namespace tensorflow
|