Don't cluster nodes that have inputs with mismatching deadness

TensorFlow allows nodes to have some live inputs and some dead inputs.  The
executor does not execute these nodes but instead propagates a dead signal to
all their outputs (i.e. these nodes are treated as fully dead).

This is a problem for auto-clustering because it means auto-clustering can kill
nodes that used to be alive.  For instance, say that before clustering we have a
graph like

digraph {
  Alive0 -> P
  Alive1 -> Q
  Dead -> R
  P -> X
  Q -> X
  Q -> Y
  R -> Y
}

and we cluster P, Q, R, X and Y into a single XLA cluster.

Then after clustering, both X and Y are dead: the cluster is a single node as
far as the executor is concerned, and that node won't get scheduled if any of
its inputs (here Alive0, Alive1 and Dead) are dead.

This CL introduces a static analysis pass that our auto-clustering code can use
to ensure nodes that have inputs with mismatching deadness (like "Y" in the
example graph) are not included in XLA clusters.

PiperOrigin-RevId: 205143316
Sanjoy Das 2018-07-18 15:03:02 -07:00 committed by TensorFlower Gardener
parent a186bcdcb0
commit 6619dd5fdc
9 changed files with 1184 additions and 22 deletions
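For reference, the clustering passes changed below consume the new analysis
through just two entry points: DeadnessAnalysis::Run and
HasInputsWithMismatchingDeadness.  A minimal sketch of that usage follows; the
helper name GetDeadnessSafeCandidates is hypothetical, and only the two
DeadnessAnalysis calls and the Merge exclusion mirror this CL.

#include <memory>
#include <vector>

#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Collects the nodes that remain eligible for clustering after the deadness
// check (hypothetical helper, for illustration only).
Status GetDeadnessSafeCandidates(const Graph& graph,
                                 std::vector<Node*>* candidates) {
  // Run the analysis once over the whole graph.
  std::unique_ptr<DeadnessAnalysis> deadness;
  TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));

  candidates->clear();
  for (Node* node : graph.op_nodes()) {
    // Merge nodes mix live and dead inputs by design, and the analysis
    // requires non-Merge nodes, so exclude them outright; for everything
    // else ask the analysis.
    if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
      continue;  // never pull this node into an XLA cluster
    }
    candidates->push_back(node);
  }
  return Status::OK();
}

}  // namespace tensorflow

Merge is special-cased because a Merge node is expected to have inputs with
mismatching deadness by design, and the analysis only accepts non-Merge nodes.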

tensorflow/compiler/jit/BUILD

@@ -304,11 +304,13 @@ cc_library(
name = "compilation_passes",
srcs = [
"build_xla_launch_ops_pass.cc",
"deadness_analysis.cc",
"encapsulate_subgraphs_pass.cc",
"mark_for_compilation_pass.cc",
],
hdrs = [
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"mark_for_compilation_pass.h",
],
@@ -325,6 +327,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -377,6 +380,7 @@ tf_cc_test(
name = "compilation_passes_test",
size = "small",
srcs = [
"deadness_analysis_test.cc",
"encapsulate_subgraphs_pass_test.cc",
"mark_for_compilation_pass_test.cc",
],
@@ -387,6 +391,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -458,6 +463,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":common",
":compilation_passes",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",

tensorflow/compiler/jit/deadness_analysis.cc

@@ -0,0 +1,546 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
// ALGORITHM OVERVIEW
//
// We map every output produced by each node in the TensorFlow graph (including
// control dependence) into an instance of the Predicate class. Instances of
// Predicate denote logical formulas and mapping a node `n` to a predicate
// `pred` implies that `n` is executed whenever `pred` is true. Then we can
// deduce mismatching liveness in the inputs to a node by comparing the
// predicates those inputs are mapped to.
//
// Loops are handled pessimistically -- we map Merge nodes with backedges to
// uninterpreted symbols (the same kind we use to represent Switch and _Recv).
// Predicate equality has to hold over all possible assignments to these
// uninterpreted symbols.
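//
// Worked example (illustrative, not part of the implementation): let S and T
// be Switch nodes whose data inputs and predicate tensors P and Q are always
// live, and let A be an Add fed by S's false output and T's false output.
// Roughly, the mapping is:
//
//   S:0 (false output)  ->  ~P        S:1 (true output)  ->  P
//   T:0 (false output)  ->  ~Q
//   A                   ->  ~P & ~Q
//
// A's two inputs map to the distinct predicates ~P and ~Q, so A is reported
// as having inputs with mismatching deadness, whereas a node fed twice from
// S:0 would not be.  A Merge of S:0 and T:0 would instead map to (~P | ~Q).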
namespace tensorflow {
namespace {
// Represents a logical predicate, used as described in the algorithm overview
// above.
class Predicate {
public:
enum class Kind { kAnd, kOr, kNot, kSymbol };
virtual string ToString() const = 0;
virtual bool operator==(const Predicate& other) const = 0;
virtual bool operator!=(const Predicate& other) const {
return !(*this == other);
}
int64 hash() const { return hash_; }
virtual Kind kind() const = 0;
virtual ~Predicate() {}
protected:
explicit Predicate(int64 hash) : hash_(hash) {}
private:
const int64 hash_;
};
int64 HashPredicateSequence(Predicate::Kind kind,
gtl::ArraySlice<Predicate*> preds) {
int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind);
for (Predicate* pred : preds) {
hash = Hash64Combine(hash, pred->hash());
}
return hash;
}
bool PredicateSequenceEqual(gtl::ArraySlice<Predicate*> lhs,
gtl::ArraySlice<Predicate*> rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (int64 i = 0; i < lhs.size(); i++) {
if (*lhs[i] != *rhs[i]) {
return false;
}
}
return true;
}
// Represents a logical conjunction of a set of predicates.
class AndPredicate : public Predicate {
public:
explicit AndPredicate(std::vector<Predicate*> operands)
: Predicate(HashPredicateSequence(Kind::kAnd, operands)),
operands_(std::move(operands)) {}
string ToString() const override {
if (operands().empty()) {
return "#true";
}
std::vector<string> operands_str;
std::transform(operands().begin(), operands().end(),
std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); });
return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
}
bool operator==(const Predicate& other) const override {
return other.kind() == Kind::kAnd &&
PredicateSequenceEqual(
dynamic_cast<const AndPredicate&>(other).operands(), operands());
}
Kind kind() const override { return Kind::kAnd; }
const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
return operands_;
}
private:
std::vector<Predicate*> operands_;
};
// Represents a logical disjunction of a set of predicates.
class OrPredicate : public Predicate {
public:
explicit OrPredicate(std::vector<Predicate*> operands)
: Predicate(HashPredicateSequence(Kind::kOr, operands)),
operands_(std::move(operands)) {}
string ToString() const override {
if (operands().empty()) {
return "#false";
}
std::vector<string> operands_str;
std::transform(operands().begin(), operands().end(),
std::back_inserter(operands_str),
[](Predicate* pred) { return pred->ToString(); });
return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
}
bool operator==(const Predicate& other) const override {
return other.kind() == Kind::kOr &&
PredicateSequenceEqual(
dynamic_cast<const OrPredicate&>(other).operands(), operands());
}
Kind kind() const override { return Kind::kOr; }
const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
return operands_;
}
private:
std::vector<Predicate*> operands_;
};
// Represents a logical negation of a set of predicates.
class NotPredicate : public Predicate {
public:
explicit NotPredicate(Predicate* operand)
: Predicate(HashPredicateSequence(Kind::kNot, {operand})),
operand_(operand) {}
string ToString() const override {
return strings::StrCat("~", operand()->ToString());
}
bool operator==(const Predicate& other) const override {
return other.kind() == Kind::kNot &&
*dynamic_cast<const NotPredicate&>(other).operand() == *operand();
}
Kind kind() const override { return Kind::kNot; }
Predicate* operand() const { return operand_; }
private:
Predicate* operand_;
};
// Represents an uninterpreted symbol in a logical predicate.
//
// Two predicates are equivalent iff they are equivalent for all assignments to
// the symbols contained in them.
class SymbolPredicate : public Predicate {
public:
explicit SymbolPredicate(TensorId tensor_id, bool must_be_true)
: Predicate(Hash(tensor_id, must_be_true)),
tensor_id_(std::move(tensor_id)),
must_be_true_(must_be_true) {}
string ToString() const override { return tensor_id_.ToString(); }
bool operator==(const Predicate& other) const override {
return other.kind() == Kind::kSymbol &&
must_be_true() ==
dynamic_cast<const SymbolPredicate&>(other).must_be_true() &&
dynamic_cast<const SymbolPredicate&>(other).tensor_id() ==
tensor_id();
}
Kind kind() const override { return Kind::kSymbol; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition
// "tensor_id() is live and evaluates to true".
//
// If `must_be_true()` is false then this SymbolPredicate represents the
// proposition "tensor_id() is live (and may evaluate to any value)".
TensorId tensor_id() const { return tensor_id_; }
bool must_be_true() const { return must_be_true_; }
private:
TensorId tensor_id_;
bool must_be_true_;
static int64 Hash(const TensorId tensor_id, bool must_be_true) {
return Hash64Combine(
::tensorflow::hash<bool>()(must_be_true),
Hash64Combine(::tensorflow::hash<Predicate::Kind>()(Kind::kSymbol),
TensorId::Hasher{}(tensor_id)));
}
};
// Creates and owns Predicate instances. Simplifies predicates as it creates
// them.
class PredicateFactory {
public:
Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
return MakeAndOrImpl(operands, /*is_and=*/true);
}
Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
return MakeAndOrImpl(operands, /*is_and=*/false);
}
Predicate* MakeNotPredicate(Predicate* pred) {
return Make<NotPredicate>(pred);
}
Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) {
return Make<SymbolPredicate>(tensor_id, must_be_true);
}
Predicate* MakeTrue() { return MakeAndPredicate({}); }
Predicate* MakeFalse() { return MakeOrPredicate({}); }
private:
template <typename PredicateT, typename... Args>
Predicate* Make(Args... args) {
std::unique_ptr<PredicateT> pred(
new PredicateT(std::forward<Args>(args)...));
predicate_storage_.emplace_back(std::move(pred));
return predicate_storage_.back().get();
}
Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);
struct PredicatePtrHash {
size_t operator()(const Predicate* pred) const { return pred->hash(); }
};
struct PredicatePtrEq {
size_t operator()(const Predicate* a, const Predicate* b) const {
return *a == *b;
}
};
using PredicateSet =
gtl::FlatSet<Predicate*, PredicatePtrHash, PredicatePtrEq>;
std::vector<std::unique_ptr<Predicate>> predicate_storage_;
};
// Common code to create AndPredicate or OrPredicate instances.
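//
// The following local simplifications are applied (shown schematically, with
// & standing for kAnd and | for kOr):
//
//   A & A & B      =>  A & B        duplicate operands are dropped
//   (A & B) & C    =>  A & B & C    nested operands of the same kind are
//                                   inlined into the parent
//   And({X})       =>  X            single-operand conjunctions/disjunctions
//                                   collapse to the operand itself
//   A & ~A & ...   =>  #false       (dually, A | ~A | ...  =>  #true)
//
// Operands are finally sorted by hash so that structurally equal predicates
// built in different operand orders (e.g. A & B vs. B & A) compare equal.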
Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
PredicateSet simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
// Simplify A&A => A and A|A => A.
if (!simplified_ops_set.insert(op).second) {
continue;
}
if (op->kind() == pred_kind) {
// "Inline" the operands of an inner And/Or into the parent And/Or.
gtl::ArraySlice<Predicate*> operands =
is_and ? dynamic_cast<AndPredicate*>(op)->operands()
: dynamic_cast<OrPredicate*>(op)->operands();
for (Predicate* subop : operands) {
if (simplified_ops_set.insert(subop).second) {
simplified_ops.push_back(subop);
}
}
} else {
simplified_ops.push_back(op);
}
}
if (simplified_ops.size() == 1) {
return simplified_ops[0];
}
// Simplify "A&~A=>False" and "A|~A=>True".
PredicateSet negated_ops;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kNot) {
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
}
}
for (Predicate* op : simplified_ops) {
if (negated_ops.count(op)) {
return is_and ? MakeFalse() : MakeTrue();
}
}
std::stable_sort(
simplified_ops.begin(), simplified_ops.end(),
[](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
return is_and ? Make<AndPredicate>(std::move(simplified_ops))
: Make<OrPredicate>(std::move(simplified_ops));
}
class DeadnessAnalysisImpl : public DeadnessAnalysis {
public:
explicit DeadnessAnalysisImpl(const Graph* graph)
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate();
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
private:
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind);
void SetPred(Node* n, int output_idx, Predicate* pred) {
CHECK(
predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second);
}
void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred) {
for (int output_idx : output_idxs) {
SetPred(n, output_idx, pred);
}
}
Status HandleSwitch(Node* n);
Status HandleMerge(Node* n);
Status HandleRecv(Node* n);
Status HandleGeneric(Node* n);
const Graph& graph_;
gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
PredicateFactory predicate_factory_;
bool vlog_;
};
TensorId InputEdgeToTensorId(const Edge* e) {
return TensorId(e->src()->name(), e->src_output());
}
std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) {
std::vector<Predicate*> incoming_preds;
for (const Edge* in_edge : n->in_edges()) {
bool should_process =
edge_kind == EdgeKind::kDataAndControl ||
(in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
(!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);
if (should_process) {
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
CHECK(it != predicate_map_.end());
incoming_preds.push_back(it->second);
}
}
return incoming_preds;
}
Status DeadnessAnalysisImpl::HandleSwitch(Node* n) {
std::vector<Predicate*> input_preds =
GetIncomingPreds(n, EdgeKind::kDataAndControl);
const Edge* pred_edge;
TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
Predicate* true_switch = predicate_factory_.MakeSymbolPredicate(
TensorId(pred_edge->src()->name(), pred_edge->src_output()),
/*must_be_true=*/true);
Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
// Output 0 is alive iff all inputs are alive and the condition is false.
input_preds.push_back(false_switch);
SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds));
input_preds.pop_back();
// Output 1 is alive iff all inputs are alive and the condition is true.
input_preds.push_back(true_switch);
SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds));
input_preds.pop_back();
// Control is alive iff all inputs are alive.
SetPred(n, Graph::kControlSlot,
predicate_factory_.MakeAndPredicate(input_preds));
return Status::OK();
}
Status DeadnessAnalysisImpl::HandleMerge(Node* n) {
// Merge ignores deadness of its control inputs. A merge that isn't the
// target of a backedge is alive iff any of its data inputs is alive. We treat
// the liveness of a merge that is the target of a backedge symbolically.
bool has_backedge = std::any_of(
n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) {
return !e->IsControlEdge() && e->src()->IsNextIteration();
});
Predicate* input_data_pred =
has_backedge ? predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false)
: predicate_factory_.MakeOrPredicate(
GetIncomingPreds(n, EdgeKind::kDataOnly));
SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred);
return Status::OK();
}
Status DeadnessAnalysisImpl::HandleRecv(Node* n) {
// In addition to being alive or dead based on the inputs, a _Recv can also
// acquire a dead signal from a _Send.
std::vector<Predicate*> input_preds =
GetIncomingPreds(n, EdgeKind::kDataAndControl);
input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false));
SetPred(n, {0, Graph::kControlSlot},
predicate_factory_.MakeAndPredicate(input_preds));
return Status::OK();
}
Status DeadnessAnalysisImpl::HandleGeneric(Node* n) {
// Generally nodes are alive iff all their inputs are alive.
Predicate* pred = predicate_factory_.MakeAndPredicate(
GetIncomingPreds(n, EdgeKind::kDataAndControl));
for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
SetPred(n, output_idx, pred);
}
SetPred(n, Graph::kControlSlot, pred);
return Status::OK();
}
Status DeadnessAnalysisImpl::Populate() {
std::vector<Node*> rpo;
GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{},
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
// This is an abstract interpretation over the deadness propagation semantics
// of the graph executor.
for (Node* n : rpo) {
if (n->IsSwitch()) {
TF_RETURN_IF_ERROR(HandleSwitch(n));
} else if (n->IsMerge()) {
TF_RETURN_IF_ERROR(HandleMerge(n));
} else if (n->IsControlTrigger()) {
SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue());
} else if (n->IsRecv() || n->IsHostRecv()) {
TF_RETURN_IF_ERROR(HandleRecv(n));
} else {
TF_RETURN_IF_ERROR(HandleGeneric(n));
}
}
return Status::OK();
}
bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
CHECK(!node.IsMerge());
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
}
Predicate* pred = nullptr;
for (const Edge* edge : node.in_edges()) {
auto it = predicate_map_.find(InputEdgeToTensorId(edge));
CHECK(it != predicate_map_.end());
if (vlog_) {
VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": "
<< it->second->ToString();
}
// Today we just compare the predicates for equality (with some
// canonicalization/simplification happening before) but we could be more
// sophisticated here if need be.
if (pred != nullptr && *pred != *it->second) {
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
<< ") -> true";
}
return true;
}
pred = it->second;
}
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
<< ") -> false";
}
return false;
}
void DeadnessAnalysisImpl::Print() const {
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) {
tensor_ids.push_back(kv_pair.first);
}
std::sort(tensor_ids.begin(), tensor_ids.end());
for (TensorId tensor_id : tensor_ids) {
auto it = predicate_map_.find(tensor_id);
CHECK(it != predicate_map_.end()) << tensor_id.ToString();
VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
}
}
} // namespace
DeadnessAnalysis::~DeadnessAnalysis() {}
/*static*/ Status DeadnessAnalysis::Run(
const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
std::unique_ptr<DeadnessAnalysisImpl> analysis(
new DeadnessAnalysisImpl(&graph));
TF_RETURN_IF_ERROR(analysis->Populate());
if (VLOG_IS_ON(2)) {
analysis->Print();
}
*result = std::move(analysis);
return Status::OK();
}
} // namespace tensorflow

tensorflow/compiler/jit/deadness_analysis.h

@@ -0,0 +1,68 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
// This analyzes a TensorFlow graph to identify nodes which may have partially
// dead inputs (i.e. these nodes may have some dead inputs and some alive
// inputs).
//
// For example, the ADD node in the following graph
//
//    V0  PRED0    V1  PRED1
//     |    |       |    |
//     v    v       v    v
//       SWITCH       SWITCH
//          |            |
//          +---+    +---+
//              |    |
//              v    v
//               ADD
//
// can have its inputs independently dead or alive based on the runtime values
// of PRED0 and PRED1.
//
// It is tempting to call this a liveness analysis but I avoided that because
// "liveness" already has other connotations.
class DeadnessAnalysis {
public:
// Returns true if `node` may have some live inputs and some dead inputs.
//
// This is a conservatively correct routine -- if it returns false then `node`
// is guaranteed to not have inputs with mismatching liveness, but not the
// converse.
//
// REQUIRES: node is not a Merge operation.
virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0;
// Prints out the internal state of this instance. For debugging purposes
// only.
virtual void Print() const = 0;
virtual ~DeadnessAnalysis();
// Runs the deadness analysis over `graph` and returns either an error or a
// populated instance of DeadnessAnalysis in `result`.
static Status Run(const Graph& graph,
std::unique_ptr<DeadnessAnalysis>* result);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_

tensorflow/compiler/jit/deadness_analysis_test.cc

@@ -0,0 +1,443 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
Status AnalyzeDeadness(Graph* graph,
std::unique_ptr<DeadnessAnalysis>* result) {
FixupSourceAndSinkEdges(graph);
return DeadnessAnalysis::Run(*graph, result);
}
ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
Output predicate =
ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
}
Output CreateInductionVariable(const Scope& root, const string& prefix,
const string& frame_name, int32 init) {
Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init);
Output enter_initial_value = ops::internal::Enter(
root.WithOpName(prefix + "/enter"), initial_value, frame_name);
ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value});
Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
Output loop_cond_expr =
ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value);
Output loop_cond =
ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
Output iv_next =
ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by);
Output next_iteration =
ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next);
root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1);
root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
root.graph()->AddControlEdge(iv.output.node(), final_value.node());
return iv.output;
}
TEST(DeadnessAnalysisTest, BasicPositive) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw = CreateSwitch(root, "0");
Output add =
ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, BasicNegative) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
Output add = ops::Add(root.WithOpName("add"), a, b);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, AndIsCommutative) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
ops::Switch sw_1 = CreateSwitch(root, "1");
Output a0 =
ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
Output a1 =
ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);
Output b0 =
ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
Output b1 =
ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);
Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);
Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
}
TEST(DeadnessAnalysisTest, AndIsAssociative) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
ops::Switch sw_1 = CreateSwitch(root, "1");
ops::Switch sw_2 = CreateSwitch(root, "2");
Output a0 =
ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false);
Output b0 =
ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0);
Output add = ops::Add(root.WithOpName("add"), a1, b1);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, OrIsCommutative) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
ops::Switch sw_1 = CreateSwitch(root, "1");
ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});
Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);
Output halfdead0 =
ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
Output halfdead1 =
ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
}
TEST(DeadnessAnalysisTest, OrIsAssociative) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
ops::Switch sw_1 = CreateSwitch(root, "1");
ops::Switch sw_2 = CreateSwitch(root, "2");
ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false});
ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false});
ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output});
Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, AndOfOr) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
ops::Switch sw_1 = CreateSwitch(root, "1");
ops::Switch sw_2 = CreateSwitch(root, "2");
ops::Switch sw_3 = CreateSwitch(root, "3");
ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false});
Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output);
Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output);
Output add2 = ops::Add(root.WithOpName("add2"), add0, add1);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}
TEST(DeadnessAnalysisTest, OrOfAnd) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
ops::Switch sw_1 = CreateSwitch(root, "1");
ops::Switch sw_2 = CreateSwitch(root, "2");
ops::Switch sw_3 = CreateSwitch(root, "3");
Output add0 =
ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
Output add1 =
ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);
ops::Merge m0(root.WithOpName("m0"), {add0, add1});
ops::Merge m1(root.WithOpName("m1"), {add0, add1});
Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}
TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
// This demonstrates one of the weaknesses in the current approach -- since we
// only do some basic simplifications we can't see that "(A|B)&C" ==
// "(A&C)|(B&C)".
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
ops::Switch sw_1 = CreateSwitch(root, "1");
ops::Switch sw_2 = CreateSwitch(root, "2");
ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false);
Output add1 =
ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
Output add2 =
ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
ops::Merge m1(root.WithOpName("m1"), {add1, add2});
Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add3.node()));
}
TEST(DeadnessAnalysisTest, Ternary) {
Scope root = Scope::NewRootScope().ExitOnError();
Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
Output false_value =
ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);
ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
predicate);
ops::Switch predicated_false(root.WithOpName("predicated_false"), false_value,
predicate);
ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
predicated_false.output_false});
Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
Output add = ops::Add(root.WithOpName("add"), merge.output, addend);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, Recv) {
Scope root = Scope::NewRootScope().ExitOnError();
Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
"sender", 0, "receiver");
Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
"sender", 0, "receiver");
Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, HostRecv) {
Scope root = Scope::NewRootScope().ExitOnError();
Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
"tensor_a", "sender", 0, "receiver");
Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
"tensor_b", "sender", 0, "receiver");
Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, Loop) {
Scope root = Scope::NewRootScope().ExitOnError();
Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0);
Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0);
Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1);
Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
// NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have
// noticed that. Today we are pessimistic here because we assign an
// uninterpreted symbol to merges with backedges.
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
}
TEST(DeadnessAnalysisTest, ControlInputs) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw = CreateSwitch(root, "0");
Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
Output const0 = ops::Const(root.WithOpName("const0"), 1);
Output const1 = ops::Const(root.WithOpName("const1"), 2);
Output add = ops::Add(root.WithOpName("add"), const0, const1);
root.graph()->AddControlEdge(id0.node(), const0.node());
root.graph()->AddControlEdge(id1.node(), const1.node());
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, ControlTrigger) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw = CreateSwitch(root, "0");
Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0"));
ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1"));
Output const0 = ops::Const(root.WithOpName("const0"), 1);
Output const1 = ops::Const(root.WithOpName("const1"), 2);
Output add = ops::Add(root.WithOpName("add"), const0, const1);
root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node());
root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node());
root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node());
root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node());
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw = CreateSwitch(root, "0");
Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
Output constant = ops::Const(root.WithOpName("constant"), 5);
ops::Merge m0(root.WithOpName("m0"), {constant});
ops::Merge m1(root.WithOpName("m1"), {constant});
Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);
root.graph()->AddControlEdge(id0.node(), m0.output.node());
root.graph()->AddControlEdge(id1.node(), m1.output.node());
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
// Demonstrates why we need the must_be_true bit on SymbolPredicate.
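//
// The Recv output is used both as a data input to the LogicalAnd and as the
// Switch predicate.  The Recv itself maps to the symbol "recv is live",
// while sw.output_true maps to "recv is live & recv is true"; without the
// must_be_true bit both would collapse to the same symbol and the two inputs
// of the LogicalAnd would incorrectly appear to always agree.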
Scope root = Scope::NewRootScope().ExitOnError();
Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
0, "receiver");
Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
ops::Switch sw(root.WithOpName("switch"), value, recv);
Output logical_and =
ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
}
} // namespace
} // namespace tensorflow

tensorflow/compiler/jit/mark_for_compilation_pass.cc

@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
@@ -462,17 +464,27 @@ Status MarkForCompilationPass::Run(
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
const FunctionLibraryDefinition* fld = options.flib_def;
auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld](
const Node* node, const DeviceType& device_type) {
std::unique_ptr<DeadnessAnalysis> deadness;
{
XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 0);
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness));
}
auto is_compilable = [&](const Node* node, const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
return false;
}
// TODO(b/111570009): This bailout for ControlTrigger is probably not
// needed.
//
// Don't compile control trigger nodes. We won't preserve their deadness
// semantics correctly, so it's safest not to compile them.
if (node->IsControlTrigger()) return false;
if (node->IsControlTrigger()) {
return false;
}
// If this device requires a JIT, we must say yes.
if (registration->requires_compilation) return true;
@@ -485,6 +497,14 @@ Status MarkForCompilationPass::Run(
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
if (status.ok()) return compile;
// If inputs to `node` can have conflicting deadness (i.e. some are alive
// and some are dead) then don't compile it. XLA cannot represent the
// deadness semantics of these nodes correctly and auto-clustering these
// nodes can cause deadness to propagate to nodes that should be live.
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
return false;
}
// Check for fusable ops only if requested.
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
return false;

tensorflow/compiler/jit/xla_fusion_optimizer.cc

@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
@@ -146,6 +147,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
std::unique_ptr<DeadnessAnalysis> deadness;
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
@@ -185,6 +189,14 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
continue;
}
// If inputs to `node` can have conflicting deadness (i.e. some are alive
// and some are dead) then don't compile it. XLA cannot represent the
// deadness semantics of these nodes correctly and auto-clustering these
// nodes can cause deadness to propagate to nodes that should be live.
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
continue;
}
compilation_candidates.insert(node);
}

tensorflow/core/graph/algorithm.cc

@@ -25,7 +25,8 @@ namespace tensorflow {
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave,
const NodeComparator& stable_comparator) {
const NodeComparator& stable_comparator,
const EdgeFilter& edge_filter) {
// Stack of work to do.
struct Work {
Node* node;
@@ -52,7 +53,6 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter,
// Arrange to call leave(n) when all done with descendants.
if (leave) stack.push_back(Work{n, true});
gtl::iterator_range<NeighborIter> nodes = n->out_nodes();
auto add_work = [&visited, &stack](Node* out) {
if (!visited[out->id()]) {
// Note; we must not mark as visited until we actually process it.
@@ -62,16 +62,20 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter,
if (stable_comparator) {
std::vector<Node*> nodes_sorted;
for (Node* out : nodes) {
nodes_sorted.emplace_back(out);
for (const Edge* out_edge : n->out_edges()) {
if (!edge_filter || edge_filter(*out_edge)) {
nodes_sorted.emplace_back(out_edge->dst());
}
}
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
for (Node* out : nodes_sorted) {
add_work(out);
}
} else {
for (Node* out : nodes) {
add_work(out);
for (const Edge* out_edge : n->out_edges()) {
if (!edge_filter || edge_filter(*out_edge)) {
add_work(out_edge->dst());
}
}
}
}
@@ -118,8 +122,6 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
// Arrange to call leave(n) when all done with descendants.
if (leave) stack.push_back(Work{n, true});
gtl::iterator_range<NeighborIter> nodes = n->in_nodes();
auto add_work = [&visited, &stack](T out) {
if (!visited[out->id()]) {
// Note; we must not mark as visited until we actually process it.
@@ -129,16 +131,16 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
if (stable_comparator) {
std::vector<T> nodes_sorted;
for (T in : nodes) {
nodes_sorted.emplace_back(in);
for (const Edge* in_edge : n->in_edges()) {
nodes_sorted.emplace_back(in_edge->src());
}
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
for (T in : nodes_sorted) {
add_work(in);
}
} else {
for (T in : nodes) {
add_work(in);
for (const Edge* in_edge : n->in_edges()) {
add_work(in_edge->src());
}
}
}
@@ -161,14 +163,17 @@ void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
}
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
const NodeComparator& stable_comparator) {
const NodeComparator& stable_comparator,
const EdgeFilter& edge_filter) {
order->clear();
DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator);
DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator,
edge_filter);
}
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
const NodeComparator& stable_comparator) {
GetPostOrder(g, order, stable_comparator);
const NodeComparator& stable_comparator,
const EdgeFilter& edge_filter) {
GetPostOrder(g, order, stable_comparator, edge_filter);
std::reverse(order->begin(), order->end());
}

tensorflow/core/graph/algorithm.h

@@ -28,6 +28,8 @@ namespace tensorflow {
// Comparator for two nodes. This is used in order to get a stable ordering.
using NodeComparator = std::function<bool(const Node*, const Node*)>;
using EdgeFilter = std::function<bool(const Edge&)>;
// Compares two nodes based on their ids.
struct NodeComparatorID {
bool operator()(const Node* n1, const Node* n2) const {
@@ -47,9 +49,11 @@ struct NodeComparatorName {
// If leave is not empty, calls leave(n) after visiting all children of n.
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
// If edge_filter is set then ignores edges for which edge_filter returns false.
extern void DFS(const Graph& g, const std::function<void(Node*)>& enter,
const std::function<void(Node*)>& leave,
const NodeComparator& stable_comparator = {});
const NodeComparator& stable_comparator = {},
const EdgeFilter& edge_filter = {});
// Perform a reverse depth-first-search on g starting at the sink node.
// If enter is not empty, calls enter(n) before visiting any parents of n.
@@ -83,15 +87,21 @@ extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
//
// If edge_filter is set then ignores edges for which edge_filter returns false.
//
// REQUIRES: order is not NULL.
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
const NodeComparator& stable_comparator = {});
const NodeComparator& stable_comparator = {},
const EdgeFilter& edge_filter = {});
// Stores in *order the reverse post-order numbering of all nodes
// If stable_comparator is set, a stable ordering of visit is achieved by
// sorting a node's neighbors first before visiting them.
//
// If edge_filter is set then ignores edges for which edge_filter returns false.
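//
// For example, deadness analysis computes a reverse post order that skips
// loop back edges roughly like this:
//
//   std::vector<Node*> order;
//   GetReversePostOrder(graph, &order, /*stable_comparator=*/{},
//                       /*edge_filter=*/[](const Edge& edge) {
//                         return !edge.src()->IsNextIteration();
//                       });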
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
const NodeComparator& stable_comparator = {});
const NodeComparator& stable_comparator = {},
const EdgeFilter& edge_filter = {});
// Prune nodes in "g" that are not in some path from the source node
// to any node in 'nodes'. Returns true if changes were made to the graph.

tensorflow/core/graph/algorithm_test.cc

@@ -36,6 +36,11 @@ namespace {
REGISTER_OP("TestParams").Output("o: float");
REGISTER_OP("TestInput").Output("a: float").Output("b: float");
REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
REGISTER_OP("TestUnary").Input("a: float").Output("o: float");
REGISTER_OP("TestBinary")
.Input("a: float")
.Input("b: float")
.Output("o: float");
// Compares that the order of nodes in 'inputs' respects the
// pair orders described in 'ordered_pairs'.
@@ -148,5 +153,52 @@ TEST(AlgorithmTest, ReversePostOrderStable) {
EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error));
}
}
TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
string error;
Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0"));
Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1"));
Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2"));
Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3"));
Graph g(OpRegistry::Global());
TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1);
std::vector<Node*> post_order;
auto edge_filter = [&](const Edge& e) {
return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id());
};
std::vector<Node*> expected_post_order = {
g.sink_node(), g.FindNodeId(n3->id()), g.FindNodeId(n2->id()),
g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()};
std::vector<Node*> expected_reverse_post_order = expected_post_order;
std::reverse(expected_reverse_post_order.begin(),
expected_reverse_post_order.end());
GetPostOrder(g, &post_order, /*stable_comparator=*/{},
/*edge_filter=*/edge_filter);
ASSERT_EQ(expected_post_order.size(), post_order.size());
for (int i = 0; i < post_order.size(); i++) {
CHECK_EQ(post_order[i], expected_post_order[i])
<< post_order[i]->name() << " vs. " << expected_post_order[i]->name();
}
std::vector<Node*> reverse_post_order;
GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{},
/*edge_filter=*/edge_filter);
ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size());
for (int i = 0; i < reverse_post_order.size(); i++) {
CHECK_EQ(reverse_post_order[i], expected_reverse_post_order[i])
<< reverse_post_order[i]->name() << " vs. "
<< expected_reverse_post_order[i]->name();
}
}
} // namespace
} // namespace tensorflow