Don't cluster nodes that have inputs with mismatching deadness
TensorFlow allows nodes to have some live inputs and some dead inputs. The executor does not execute these nodes but instead propagates a dead signal to all their outputs (i.e. these nodes are treated as fully dead).

This is a problem for auto-clustering because it means auto-clustering can kill nodes that used to be alive. For instance, say before clustering we have a graph like

  digraph {
    Alive0 -> P
    Alive1 -> Q
    Dead -> R

    P -> X
    Q -> X
    Q -> Y
    R -> Y
  }

and we cluster P, Q, R, X and Y into a single XLA cluster. Then after clustering both X and Y are dead because the cluster is a single node as far as the executor is concerned and said node won't get scheduled if any of its inputs are dead.

This CL introduces a static analysis pass that our auto-clustering code can use to ensure nodes that have inputs with mismatching deadness (like "Y" in the example graph) are not included in XLA clusters.

PiperOrigin-RevId: 205143316
parent a186bcdcb0, commit 6619dd5fdc
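To make the intended use concrete, below is a minimal sketch (not taken from this change) of how a clustering pass can consult the analysis. FilterClusterCandidates is a hypothetical helper name; DeadnessAnalysis::Run and HasInputsWithMismatchingDeadness are the interfaces added in the files that follow.

#include <memory>
#include <vector>

#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical helper (illustration only): collect the nodes that are safe to
// consider for clustering as far as deadness is concerned.
Status FilterClusterCandidates(const Graph& graph,
                               std::vector<Node*>* candidates) {
  std::unique_ptr<DeadnessAnalysis> deadness;
  TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
  for (Node* node : graph.op_nodes()) {
    // Merge nodes are excluded up front (the analysis requires non-Merge
    // nodes); every other node is rejected if its inputs can disagree on
    // liveness, like "Y" in the example graph above.
    if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
      continue;  // Clustering this node could kill a node that is live today.
    }
    candidates->push_back(node);
  }
  return Status::OK();
}

}  // namespace tensorflow

The real wiring in this CL follows the same pattern in mark_for_compilation_pass.cc and xla_fusion_optimizer.cc below, which skip Merge nodes and any node the analysis flags.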
@ -304,11 +304,13 @@ cc_library(
|
||||
name = "compilation_passes",
|
||||
srcs = [
|
||||
"build_xla_launch_ops_pass.cc",
|
||||
"deadness_analysis.cc",
|
||||
"encapsulate_subgraphs_pass.cc",
|
||||
"mark_for_compilation_pass.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"build_xla_launch_ops_pass.h",
|
||||
"deadness_analysis.h",
|
||||
"encapsulate_subgraphs_pass.h",
|
||||
"mark_for_compilation_pass.h",
|
||||
],
|
||||
@ -325,6 +327,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -377,6 +380,7 @@ tf_cc_test(
|
||||
name = "compilation_passes_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"deadness_analysis_test.cc",
|
||||
"encapsulate_subgraphs_pass_test.cc",
|
||||
"mark_for_compilation_pass_test.cc",
|
||||
],
|
||||
@ -387,6 +391,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/compiler/jit/kernels:xla_launch_op",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
@ -458,6 +463,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":common",
|
||||
":compilation_passes",
|
||||
":union_find",
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
|
tensorflow/compiler/jit/deadness_analysis.cc (new file, 546 lines)
@ -0,0 +1,546 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
|
||||
// ALGORITHM OVERVIEW
|
||||
//
|
||||
// We map every output produced by each node in the TensorFlow graph (including
|
||||
// control dependence) into an instance of the Predicate class. Instances of
|
||||
// Predicate denote logical formulas and mapping a node `n` to a predicate
|
||||
// `pred` implies that `n` is executed whenever `pred` is true. Then we can
|
||||
// deduce mismatching liveness in the inputs to a node by comparing the predicates
|
||||
// those inputs are mapped to.
|
||||
//
|
||||
// Loops are handled pessimistically -- we map Merge nodes with backedges to
|
||||
// uninterpreted symbols (the same kind we use to represent Switch and _Recv).
|
||||
// Predicate equality has to hold over all possible assignments to these
|
||||
// uninterpreted symbols.
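//
// For example (illustrative): if a Switch node S has predicate input P and all
// of its inputs are unconditionally live, then S.output_false is mapped to a
// predicate containing ~"P" while S.output_true is mapped to one containing
// "P".  An Add that consumes both outputs therefore sees two different
// predicates and is reported as having inputs with mismatching deadness.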
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Represents a logical predicate, used as described in the algorithm overview
|
||||
// above.
|
||||
class Predicate {
|
||||
public:
|
||||
enum class Kind { kAnd, kOr, kNot, kSymbol };
|
||||
|
||||
virtual string ToString() const = 0;
|
||||
virtual bool operator==(const Predicate& other) const = 0;
|
||||
virtual bool operator!=(const Predicate& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
int64 hash() const { return hash_; }
|
||||
|
||||
virtual Kind kind() const = 0;
|
||||
virtual ~Predicate() {}
|
||||
|
||||
protected:
|
||||
explicit Predicate(int64 hash) : hash_(hash) {}
|
||||
|
||||
private:
|
||||
const int64 hash_;
|
||||
};
|
||||
|
||||
int64 HashPredicateSequence(Predicate::Kind kind,
|
||||
gtl::ArraySlice<Predicate*> preds) {
|
||||
int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind);
|
||||
for (Predicate* pred : preds) {
|
||||
hash = Hash64Combine(hash, pred->hash());
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
bool PredicateSequenceEqual(gtl::ArraySlice<Predicate*> lhs,
|
||||
gtl::ArraySlice<Predicate*> rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int64 i = 0; i < lhs.size(); i++) {
|
||||
if (*lhs[i] != *rhs[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Represents a logical conjunction of a set of predicates.
|
||||
class AndPredicate : public Predicate {
|
||||
public:
|
||||
explicit AndPredicate(std::vector<Predicate*> operands)
|
||||
: Predicate(HashPredicateSequence(Kind::kAnd, operands)),
|
||||
operands_(std::move(operands)) {}
|
||||
|
||||
string ToString() const override {
|
||||
if (operands().empty()) {
|
||||
return "#true";
|
||||
}
|
||||
|
||||
std::vector<string> operands_str;
|
||||
std::transform(operands().begin(), operands().end(),
|
||||
std::back_inserter(operands_str),
|
||||
[](Predicate* pred) { return pred->ToString(); });
|
||||
|
||||
return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
|
||||
}
|
||||
|
||||
bool operator==(const Predicate& other) const override {
|
||||
return other.kind() == Kind::kAnd &&
|
||||
PredicateSequenceEqual(
|
||||
dynamic_cast<const AndPredicate&>(other).operands(), operands());
|
||||
}
|
||||
|
||||
Kind kind() const override { return Kind::kAnd; }
|
||||
|
||||
const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
|
||||
return operands_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Predicate*> operands_;
|
||||
};
|
||||
|
||||
// Represents a logical disjunction of a set of predicates.
|
||||
class OrPredicate : public Predicate {
|
||||
public:
|
||||
explicit OrPredicate(std::vector<Predicate*> operands)
|
||||
: Predicate(HashPredicateSequence(Kind::kOr, operands)),
|
||||
operands_(std::move(operands)) {}
|
||||
|
||||
string ToString() const override {
|
||||
if (operands().empty()) {
|
||||
return "#false";
|
||||
}
|
||||
|
||||
std::vector<string> operands_str;
|
||||
std::transform(operands().begin(), operands().end(),
|
||||
std::back_inserter(operands_str),
|
||||
[](Predicate* pred) { return pred->ToString(); });
|
||||
|
||||
return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
|
||||
}
|
||||
|
||||
bool operator==(const Predicate& other) const override {
|
||||
return other.kind() == Kind::kOr &&
|
||||
PredicateSequenceEqual(
|
||||
dynamic_cast<const OrPredicate&>(other).operands(), operands());
|
||||
}
|
||||
|
||||
Kind kind() const override { return Kind::kOr; }
|
||||
const tensorflow::gtl::ArraySlice<Predicate*> operands() const {
|
||||
return operands_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Predicate*> operands_;
|
||||
};
|
||||
|
||||
// Represents a logical negation of a set of predicates.
|
||||
class NotPredicate : public Predicate {
|
||||
public:
|
||||
explicit NotPredicate(Predicate* operand)
|
||||
: Predicate(HashPredicateSequence(Kind::kNot, {operand})),
|
||||
operand_(operand) {}
|
||||
|
||||
string ToString() const override {
|
||||
return strings::StrCat("~", operand()->ToString());
|
||||
}
|
||||
|
||||
bool operator==(const Predicate& other) const override {
|
||||
return other.kind() == Kind::kNot &&
|
||||
*dynamic_cast<const NotPredicate&>(other).operand() == *operand();
|
||||
}
|
||||
|
||||
Kind kind() const override { return Kind::kNot; }
|
||||
Predicate* operand() const { return operand_; }
|
||||
|
||||
private:
|
||||
Predicate* operand_;
|
||||
};
|
||||
|
||||
// Represents an uninterpreted symbol in a logical predicate.
|
||||
//
|
||||
// Two predicates are equivalent iff they are equivalent for all assignments to
|
||||
// the symbols contained in them.
|
||||
class SymbolPredicate : public Predicate {
|
||||
public:
|
||||
explicit SymbolPredicate(TensorId tensor_id, bool must_be_true)
|
||||
: Predicate(Hash(tensor_id, must_be_true)),
|
||||
tensor_id_(std::move(tensor_id)),
|
||||
must_be_true_(must_be_true) {}
|
||||
|
||||
string ToString() const override { return tensor_id_.ToString(); }
|
||||
bool operator==(const Predicate& other) const override {
|
||||
return other.kind() == Kind::kSymbol &&
|
||||
must_be_true() ==
|
||||
dynamic_cast<const SymbolPredicate&>(other).must_be_true() &&
|
||||
dynamic_cast<const SymbolPredicate&>(other).tensor_id() ==
|
||||
tensor_id();
|
||||
}
|
||||
|
||||
Kind kind() const override { return Kind::kSymbol; }
|
||||
|
||||
// If `must_be_true()` is true this SymbolPredicate represents the proposition
|
||||
// "tensor_id() is live and evaluates to true".
|
||||
//
|
||||
// If `must_be_true()` is false then this SymbolPredicate represents the
|
||||
// proposition "tensor_id() is live (and may evalutate to any value)"
|
||||
TensorId tensor_id() const { return tensor_id_; }
|
||||
bool must_be_true() const { return must_be_true_; }
|
||||
|
||||
private:
|
||||
TensorId tensor_id_;
|
||||
bool must_be_true_;
|
||||
|
||||
static int64 Hash(const TensorId tensor_id, bool must_be_true) {
|
||||
return Hash64Combine(
|
||||
::tensorflow::hash<bool>()(must_be_true),
|
||||
Hash64Combine(::tensorflow::hash<Predicate::Kind>()(Kind::kSymbol),
|
||||
TensorId::Hasher{}(tensor_id)));
|
||||
}
|
||||
};
|
||||
|
||||
// Creates and owns Predicate instances. Simplifies predicates as it creates
|
||||
// them.
|
||||
class PredicateFactory {
|
||||
public:
|
||||
Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
|
||||
return MakeAndOrImpl(operands, /*is_and=*/true);
|
||||
}
|
||||
Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
|
||||
return MakeAndOrImpl(operands, /*is_and=*/false);
|
||||
}
|
||||
|
||||
Predicate* MakeNotPredicate(Predicate* pred) {
|
||||
return Make<NotPredicate>(pred);
|
||||
}
|
||||
|
||||
Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) {
|
||||
return Make<SymbolPredicate>(tensor_id, must_be_true);
|
||||
}
|
||||
|
||||
Predicate* MakeTrue() { return MakeAndPredicate({}); }
|
||||
Predicate* MakeFalse() { return MakeOrPredicate({}); }
|
||||
|
||||
private:
|
||||
template <typename PredicateT, typename... Args>
|
||||
Predicate* Make(Args... args) {
|
||||
std::unique_ptr<PredicateT> pred(
|
||||
new PredicateT(std::forward<Args>(args)...));
|
||||
predicate_storage_.emplace_back(std::move(pred));
|
||||
return predicate_storage_.back().get();
|
||||
}
|
||||
|
||||
Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);
|
||||
|
||||
struct PredicatePtrHash {
|
||||
size_t operator()(const Predicate* pred) const { return pred->hash(); }
|
||||
};
|
||||
|
||||
struct PredicatePtrEq {
|
||||
size_t operator()(const Predicate* a, const Predicate* b) const {
|
||||
return *a == *b;
|
||||
}
|
||||
};
|
||||
|
||||
using PredicateSet =
|
||||
gtl::FlatSet<Predicate*, PredicatePtrHash, PredicatePtrEq>;
|
||||
|
||||
std::vector<std::unique_ptr<Predicate>> predicate_storage_;
|
||||
};
|
||||
|
||||
// Common code to create AndPredicate or OrPredicate instances.
|
||||
Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
|
||||
bool is_and) {
|
||||
Predicate::Kind pred_kind =
|
||||
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
|
||||
PredicateSet simplified_ops_set;
|
||||
std::vector<Predicate*> simplified_ops;
|
||||
for (Predicate* op : operands) {
|
||||
// Simplify A&A => A and A|A => A.
|
||||
if (!simplified_ops_set.insert(op).second) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (op->kind() == pred_kind) {
|
||||
// "Inline" the operands of an inner And/Or into the parent And/Or.
|
||||
gtl::ArraySlice<Predicate*> operands =
|
||||
is_and ? dynamic_cast<AndPredicate*>(op)->operands()
|
||||
: dynamic_cast<OrPredicate*>(op)->operands();
|
||||
for (Predicate* subop : operands) {
|
||||
if (simplified_ops_set.insert(subop).second) {
|
||||
simplified_ops.push_back(subop);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
simplified_ops.push_back(op);
|
||||
}
|
||||
}
|
||||
|
||||
if (simplified_ops.size() == 1) {
|
||||
return simplified_ops[0];
|
||||
}
|
||||
|
||||
// Simplify "A&~A=>False" and "A|~A=>True".
|
||||
PredicateSet negated_ops;
|
||||
for (Predicate* op : simplified_ops) {
|
||||
if (op->kind() == Predicate::Kind::kNot) {
|
||||
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
|
||||
}
|
||||
}
|
||||
|
||||
for (Predicate* op : simplified_ops) {
|
||||
if (negated_ops.count(op)) {
|
||||
return is_and ? MakeFalse() : MakeTrue();
|
||||
}
|
||||
}
|
||||
|
||||
std::stable_sort(
|
||||
simplified_ops.begin(), simplified_ops.end(),
|
||||
[](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
|
||||
|
||||
return is_and ? Make<AndPredicate>(std::move(simplified_ops))
|
||||
: Make<OrPredicate>(std::move(simplified_ops));
|
||||
}
|
||||
|
||||
class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
||||
public:
|
||||
explicit DeadnessAnalysisImpl(const Graph* graph)
|
||||
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
|
||||
|
||||
Status Populate();
|
||||
bool HasInputsWithMismatchingDeadness(const Node& node) override;
|
||||
void Print() const override;
|
||||
|
||||
private:
|
||||
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
|
||||
|
||||
std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind);
|
||||
void SetPred(Node* n, int output_idx, Predicate* pred) {
|
||||
CHECK(
|
||||
predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second);
|
||||
}
|
||||
void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred) {
|
||||
for (int output_idx : output_idxs) {
|
||||
SetPred(n, output_idx, pred);
|
||||
}
|
||||
}
|
||||
|
||||
Status HandleSwitch(Node* n);
|
||||
Status HandleMerge(Node* n);
|
||||
Status HandleRecv(Node* n);
|
||||
Status HandleGeneric(Node* n);
|
||||
|
||||
const Graph& graph_;
|
||||
gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
|
||||
PredicateFactory predicate_factory_;
|
||||
bool vlog_;
|
||||
};
|
||||
|
||||
TensorId InputEdgeToTensorId(const Edge* e) {
|
||||
return TensorId(e->src()->name(), e->src_output());
|
||||
}
|
||||
|
||||
std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
|
||||
Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) {
|
||||
std::vector<Predicate*> incoming_preds;
|
||||
for (const Edge* in_edge : n->in_edges()) {
|
||||
bool should_process =
|
||||
edge_kind == EdgeKind::kDataAndControl ||
|
||||
(in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
|
||||
(!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);
|
||||
|
||||
if (should_process) {
|
||||
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
|
||||
CHECK(it != predicate_map_.end());
|
||||
incoming_preds.push_back(it->second);
|
||||
}
|
||||
}
|
||||
return incoming_preds;
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::HandleSwitch(Node* n) {
|
||||
std::vector<Predicate*> input_preds =
|
||||
GetIncomingPreds(n, EdgeKind::kDataAndControl);
|
||||
const Edge* pred_edge;
|
||||
TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
|
||||
Predicate* true_switch = predicate_factory_.MakeSymbolPredicate(
|
||||
TensorId(pred_edge->src()->name(), pred_edge->src_output()),
|
||||
/*must_be_true=*/true);
|
||||
Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
|
||||
|
||||
// Output 0 is alive iff all inputs are alive and the condition is false.
|
||||
input_preds.push_back(false_switch);
|
||||
SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds));
|
||||
input_preds.pop_back();
|
||||
|
||||
// Output 1 is alive iff all inputs are alive and the condition is true.
|
||||
input_preds.push_back(true_switch);
|
||||
SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds));
|
||||
input_preds.pop_back();
|
||||
|
||||
// Control is alive iff any inputs are alive.
|
||||
SetPred(n, Graph::kControlSlot,
|
||||
predicate_factory_.MakeAndPredicate(input_preds));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::HandleMerge(Node* n) {
|
||||
// Merge ignores deadness of its control inputs. A merge that isn't the
|
||||
// target of a backedge is alive iff any of its data inputs are. We treat
|
||||
// the liveness of a merge that is the target of a backedge symbolically.
|
||||
|
||||
bool has_backedge = std::any_of(
|
||||
n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) {
|
||||
return !e->IsControlEdge() && e->src()->IsNextIteration();
|
||||
});
|
||||
|
||||
Predicate* input_data_pred =
|
||||
has_backedge ? predicate_factory_.MakeSymbolPredicate(
|
||||
TensorId(n->name(), 0), /*must_be_true=*/false)
|
||||
: predicate_factory_.MakeOrPredicate(
|
||||
GetIncomingPreds(n, EdgeKind::kDataOnly));
|
||||
|
||||
SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::HandleRecv(Node* n) {
|
||||
// In addition to being alive or dead based on the inputs, a _Recv can also
|
||||
// acquire a dead signal from a _Send.
|
||||
std::vector<Predicate*> input_preds =
|
||||
GetIncomingPreds(n, EdgeKind::kDataAndControl);
|
||||
input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
|
||||
TensorId(n->name(), 0), /*must_be_true=*/false));
|
||||
SetPred(n, {0, Graph::kControlSlot},
|
||||
predicate_factory_.MakeAndPredicate(input_preds));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::HandleGeneric(Node* n) {
|
||||
// Generally nodes are alive iff all their inputs are alive.
|
||||
Predicate* pred = predicate_factory_.MakeAndPredicate(
|
||||
GetIncomingPreds(n, EdgeKind::kDataAndControl));
|
||||
for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
|
||||
SetPred(n, output_idx, pred);
|
||||
}
|
||||
SetPred(n, Graph::kControlSlot, pred);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::Populate() {
|
||||
std::vector<Node*> rpo;
|
||||
GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{},
|
||||
/*edge_filter=*/[](const Edge& edge) {
|
||||
return !edge.src()->IsNextIteration();
|
||||
});
|
||||
|
||||
// This is an abstract interpretation over the deadness propagation semantics of
|
||||
// the graph executor.
|
||||
for (Node* n : rpo) {
|
||||
if (n->IsSwitch()) {
|
||||
TF_RETURN_IF_ERROR(HandleSwitch(n));
|
||||
} else if (n->IsMerge()) {
|
||||
TF_RETURN_IF_ERROR(HandleMerge(n));
|
||||
} else if (n->IsControlTrigger()) {
|
||||
SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue());
|
||||
} else if (n->IsRecv() || n->IsHostRecv()) {
|
||||
TF_RETURN_IF_ERROR(HandleRecv(n));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(HandleGeneric(n));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
|
||||
CHECK(!node.IsMerge());
|
||||
|
||||
if (vlog_) {
|
||||
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
|
||||
}
|
||||
|
||||
Predicate* pred = nullptr;
|
||||
for (const Edge* edge : node.in_edges()) {
|
||||
auto it = predicate_map_.find(InputEdgeToTensorId(edge));
|
||||
CHECK(it != predicate_map_.end());
|
||||
if (vlog_) {
|
||||
VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": "
|
||||
<< it->second->ToString();
|
||||
}
|
||||
|
||||
// Today we just compare the predicates for equality (with some
|
||||
// canonicalization/simplification happening before) but we could be more
|
||||
// sophisticated here if need be.
|
||||
if (pred != nullptr && *pred != *it->second) {
|
||||
if (vlog_) {
|
||||
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
|
||||
<< ") -> true";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
pred = it->second;
|
||||
}
|
||||
|
||||
if (vlog_) {
|
||||
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
|
||||
<< ") -> false";
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void DeadnessAnalysisImpl::Print() const {
|
||||
std::vector<TensorId> tensor_ids;
|
||||
for (const auto& kv_pair : predicate_map_) {
|
||||
tensor_ids.push_back(kv_pair.first);
|
||||
}
|
||||
|
||||
std::sort(tensor_ids.begin(), tensor_ids.end());
|
||||
|
||||
for (TensorId tensor_id : tensor_ids) {
|
||||
auto it = predicate_map_.find(tensor_id);
|
||||
CHECK(it != predicate_map_.end()) << tensor_id.ToString();
|
||||
VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DeadnessAnalysis::~DeadnessAnalysis() {}
|
||||
|
||||
/*static*/ Status DeadnessAnalysis::Run(
|
||||
const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
|
||||
std::unique_ptr<DeadnessAnalysisImpl> analysis(
|
||||
new DeadnessAnalysisImpl(&graph));
|
||||
TF_RETURN_IF_ERROR(analysis->Populate());
|
||||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
analysis->Print();
|
||||
}
|
||||
|
||||
*result = std::move(analysis);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/compiler/jit/deadness_analysis.h (new file, 68 lines)
@ -0,0 +1,68 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
|
||||
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This analyzes a TensorFlow graph to identify nodes which may have partially
|
||||
// dead inputs (i.e. these nodes may have some dead inputs and some alive
|
||||
// inputs).
|
||||
//
|
||||
// For example, the ADD node in the following graph
|
||||
//
|
||||
// V0 PRED0 V1 PRED1
|
||||
// | | | |
|
||||
// v v v v
|
||||
// SWITCH SWITCH
|
||||
// | |
|
||||
// +---+ + ---+
|
||||
// | |
|
||||
// v v
|
||||
// ADD
|
||||
//
|
||||
// can have its inputs independently dead or alive based on the runtime values
|
||||
// of PRED0 and PRED1.
|
||||
//
|
||||
// It is tempting to call this a liveness analysis but I avoided that because
|
||||
// "liveness" already has other connotations.
|
||||
class DeadnessAnalysis {
|
||||
public:
|
||||
// Returns true if `node` may have some live inputs and some dead inputs.
|
||||
//
|
||||
// This is a conservatively correct routine -- if it returns false then `node`
|
||||
// is guaranteed to not have inputs with mismatching liveness, but not the
|
||||
// converse.
|
||||
//
|
||||
// REQUIRES: node is not a Merge operation.
|
||||
virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0;
|
||||
|
||||
// Prints out the internal state of this instance. For debugging purposes
|
||||
// only.
|
||||
virtual void Print() const = 0;
|
||||
virtual ~DeadnessAnalysis();
|
||||
|
||||
// Runs the deadness analysis over `graph` and returns an error or a populated
|
||||
// instance of DeadnessAnalysis in `result`.
|
||||
static Status Run(const Graph& graph,
|
||||
std::unique_ptr<DeadnessAnalysis>* result);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
|
tensorflow/compiler/jit/deadness_analysis_test.cc (new file, 443 lines)
@ -0,0 +1,443 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status AnalyzeDeadness(Graph* graph,
|
||||
std::unique_ptr<DeadnessAnalysis>* result) {
|
||||
FixupSourceAndSinkEdges(graph);
|
||||
return DeadnessAnalysis::Run(*graph, result);
|
||||
}
|
||||
|
||||
ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
|
||||
Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
|
||||
Output predicate =
|
||||
ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
|
||||
return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
|
||||
}
|
||||
|
||||
Output CreateInductionVariable(const Scope& root, const string& prefix,
|
||||
const string& frame_name, int32 init) {
|
||||
Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init);
|
||||
Output enter_initial_value = ops::internal::Enter(
|
||||
root.WithOpName(prefix + "/enter"), initial_value, frame_name);
|
||||
|
||||
ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value});
|
||||
Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
|
||||
Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
|
||||
Output loop_cond_expr =
|
||||
ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value);
|
||||
Output loop_cond =
|
||||
ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
|
||||
ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
|
||||
ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
|
||||
Output iv_next =
|
||||
ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by);
|
||||
Output next_iteration =
|
||||
ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next);
|
||||
|
||||
root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1);
|
||||
root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
|
||||
root.graph()->AddControlEdge(iv.output.node(), final_value.node());
|
||||
|
||||
return iv.output;
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, BasicPositive) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
ops::Switch sw = CreateSwitch(root, "0");
|
||||
Output add =
|
||||
ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, BasicNegative) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
|
||||
Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
|
||||
Output add = ops::Add(root.WithOpName("add"), a, b);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndIsCommutative) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
ops::Switch sw_0 = CreateSwitch(root, "0");
|
||||
ops::Switch sw_1 = CreateSwitch(root, "1");
|
||||
|
||||
Output a0 =
|
||||
ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
|
||||
Output a1 =
|
||||
ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);
|
||||
|
||||
Output b0 =
|
||||
ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
|
||||
Output b1 =
|
||||
ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);
|
||||
|
||||
Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
|
||||
Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);
|
||||
|
||||
Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
|
||||
Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndIsAssociative) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
ops::Switch sw_0 = CreateSwitch(root, "0");
|
||||
ops::Switch sw_1 = CreateSwitch(root, "1");
|
||||
ops::Switch sw_2 = CreateSwitch(root, "2");
|
||||
|
||||
Output a0 =
|
||||
ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
|
||||
Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false);
|
||||
|
||||
Output b0 =
|
||||
ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
|
||||
Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0);
|
||||
|
||||
Output add = ops::Add(root.WithOpName("add"), a1, b1);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, OrIsCommutative) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
ops::Switch sw_0 = CreateSwitch(root, "0");
|
||||
ops::Switch sw_1 = CreateSwitch(root, "1");
|
||||
|
||||
ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
|
||||
ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
|
||||
ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
|
||||
ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});
|
||||
|
||||
Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
|
||||
Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);
|
||||
|
||||
Output halfdead0 =
|
||||
ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
|
||||
Output halfdead1 =
|
||||
ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, OrIsAssociative) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
ops::Switch sw_0 = CreateSwitch(root, "0");
|
||||
ops::Switch sw_1 = CreateSwitch(root, "1");
|
||||
ops::Switch sw_2 = CreateSwitch(root, "2");
|
||||
|
||||
ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
|
||||
ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false});
|
||||
ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false});
|
||||
ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output});
|
||||
|
||||
Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndOfOr) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
ops::Switch sw_0 = CreateSwitch(root, "0");
|
||||
ops::Switch sw_1 = CreateSwitch(root, "1");
|
||||
ops::Switch sw_2 = CreateSwitch(root, "2");
|
||||
ops::Switch sw_3 = CreateSwitch(root, "3");
|
||||
|
||||
ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
|
||||
ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false});
|
||||
|
||||
Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output);
|
||||
Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output);
|
||||
|
||||
Output add2 = ops::Add(root.WithOpName("add2"), add0, add1);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, OrOfAnd) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
ops::Switch sw_0 = CreateSwitch(root, "0");
|
||||
ops::Switch sw_1 = CreateSwitch(root, "1");
|
||||
ops::Switch sw_2 = CreateSwitch(root, "2");
|
||||
ops::Switch sw_3 = CreateSwitch(root, "3");
|
||||
|
||||
Output add0 =
|
||||
ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
|
||||
Output add1 =
|
||||
ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);
|
||||
|
||||
ops::Merge m0(root.WithOpName("m0"), {add0, add1});
|
||||
ops::Merge m1(root.WithOpName("m1"), {add0, add1});
|
||||
|
||||
Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
|
||||
// This demonstrates one of the weaknesses in the current approach -- since we
|
||||
// only do some basic simplifications we can't see that "(A|B)&C" ==
|
||||
// "(A&C)|(B&C)".
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
ops::Switch sw_0 = CreateSwitch(root, "0");
|
||||
ops::Switch sw_1 = CreateSwitch(root, "1");
|
||||
ops::Switch sw_2 = CreateSwitch(root, "2");
|
||||
|
||||
ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
|
||||
Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false);
|
||||
|
||||
Output add1 =
|
||||
ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
|
||||
Output add2 =
|
||||
ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
|
||||
ops::Merge m1(root.WithOpName("m1"), {add1, add2});
|
||||
|
||||
Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add3.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, Ternary) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
|
||||
Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
|
||||
Output false_value =
|
||||
ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);
|
||||
|
||||
ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
|
||||
predicate);
|
||||
|
||||
ops::Switch predicated_false(root.WithOpName("predicated_false"), false_value,
|
||||
predicate);
|
||||
ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
|
||||
predicated_false.output_false});
|
||||
Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
|
||||
Output add = ops::Add(root.WithOpName("add"), merge.output, addend);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, Recv) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
|
||||
"sender", 0, "receiver");
|
||||
Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
|
||||
"sender", 0, "receiver");
|
||||
Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, HostRecv) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
|
||||
"tensor_a", "sender", 0, "receiver");
|
||||
Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
|
||||
"tensor_b", "sender", 0, "receiver");
|
||||
Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, Loop) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0);
|
||||
Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0);
|
||||
Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1);
|
||||
Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
|
||||
Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
// NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have
|
||||
// noticed that. Today we are pessimistic here because we assign an
|
||||
// uninterpreted symbol to merges with backedges.
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, ControlInputs) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
ops::Switch sw = CreateSwitch(root, "0");
|
||||
|
||||
Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
|
||||
Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
|
||||
|
||||
Output const0 = ops::Const(root.WithOpName("const0"), 1);
|
||||
Output const1 = ops::Const(root.WithOpName("const1"), 2);
|
||||
|
||||
Output add = ops::Add(root.WithOpName("add"), const0, const1);
|
||||
|
||||
root.graph()->AddControlEdge(id0.node(), const0.node());
|
||||
root.graph()->AddControlEdge(id1.node(), const1.node());
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, ControlTrigger) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
ops::Switch sw = CreateSwitch(root, "0");
|
||||
|
||||
Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
|
||||
Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
|
||||
|
||||
ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0"));
|
||||
ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1"));
|
||||
|
||||
Output const0 = ops::Const(root.WithOpName("const0"), 1);
|
||||
Output const1 = ops::Const(root.WithOpName("const1"), 2);
|
||||
|
||||
Output add = ops::Add(root.WithOpName("add"), const0, const1);
|
||||
|
||||
root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node());
|
||||
root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node());
|
||||
|
||||
root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node());
|
||||
root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node());
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
ops::Switch sw = CreateSwitch(root, "0");
|
||||
|
||||
Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
|
||||
Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
|
||||
|
||||
Output constant = ops::Const(root.WithOpName("constant"), 5);
|
||||
ops::Merge m0(root.WithOpName("m0"), {constant});
|
||||
ops::Merge m1(root.WithOpName("m1"), {constant});
|
||||
Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);
|
||||
|
||||
root.graph()->AddControlEdge(id0.node(), m0.output.node());
|
||||
root.graph()->AddControlEdge(id1.node(), m1.output.node());
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
|
||||
// Demonstrates why we need the must_be_true bit on SymbolPredicate.
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
|
||||
0, "receiver");
|
||||
Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
|
||||
ops::Switch sw(root.WithOpName("switch"), value, recv);
|
||||
Output logical_and =
|
||||
ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
@ -462,17 +464,27 @@ Status MarkForCompilationPass::Run(
|
||||
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
|
||||
const FunctionLibraryDefinition* fld = options.flib_def;
|
||||
|
||||
auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld](
|
||||
const Node* node, const DeviceType& device_type) {
|
||||
std::unique_ptr<DeadnessAnalysis> deadness;
|
||||
{
|
||||
XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 0);
|
||||
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness));
|
||||
}
|
||||
|
||||
auto is_compilable = [&](const Node* node, const DeviceType& device_type) {
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
|
||||
®istration)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO(b/111570009): This bailout for ControlTrigger is probably not
|
||||
// needed.
|
||||
//
|
||||
// Don't compile control trigger nodes. We won't preserve their deadness
|
||||
// semantics correctly, so it's safest not to compile them.
|
||||
if (node->IsControlTrigger()) return false;
|
||||
if (node->IsControlTrigger()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If this device requires a JIT, we must say yes.
|
||||
if (registration->requires_compilation) return true;
|
||||
@ -485,6 +497,14 @@ Status MarkForCompilationPass::Run(
|
||||
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
|
||||
if (status.ok()) return compile;
|
||||
|
||||
// If inputs to `node` can have conflicting deadness (i.e. some are alive
|
||||
// and some are dead) then don't compile it. XLA cannot represent the
|
||||
// deadness semantics of these nodes correctly and auto-clustering these
|
||||
// nodes can cause deadness to propagate to nodes that should be live.
|
||||
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for fusable ops only if requested.
|
||||
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
|
||||
return false;
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
@ -146,6 +147,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> deadness;
|
||||
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
|
||||
|
||||
// Collect nodes that can be fused via XLA, while ignoring those that
|
||||
// explicitly ask for XLA: (*) nodes that are marked to be compiled
|
||||
// explicitly. (*) nodes assigned to XLA device.
|
||||
@ -185,6 +189,14 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
|
||||
continue;
|
||||
}
|
||||
|
||||
// If inputs to `node` can have conflicting deadness (i.e. some are alive
|
||||
// and some are dead) then don't compile it. XLA cannot represent the
|
||||
// deadness semantics of these nodes correctly and auto-clustering these
|
||||
// nodes can cause deadness propagate to nodes that should be live.
|
||||
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
compilation_candidates.insert(node);
|
||||
}
|
||||
|
||||
|
@ -25,7 +25,8 @@ namespace tensorflow {
|
||||
|
||||
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
|
||||
const std::function<void(Node*)>& leave,
|
||||
const NodeComparator& stable_comparator) {
|
||||
const NodeComparator& stable_comparator,
|
||||
const EdgeFilter& edge_filter) {
|
||||
// Stack of work to do.
|
||||
struct Work {
|
||||
Node* node;
|
||||
@ -52,7 +53,6 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter,
|
||||
// Arrange to call leave(n) when all done with descendants.
|
||||
if (leave) stack.push_back(Work{n, true});
|
||||
|
||||
gtl::iterator_range<NeighborIter> nodes = n->out_nodes();
|
||||
auto add_work = [&visited, &stack](Node* out) {
|
||||
if (!visited[out->id()]) {
|
||||
// Note; we must not mark as visited until we actually process it.
|
||||
@ -62,16 +62,20 @@ void DFS(const Graph& g, const std::function<void(Node*)>& enter,
|
||||
|
||||
if (stable_comparator) {
|
||||
std::vector<Node*> nodes_sorted;
|
||||
for (Node* out : nodes) {
|
||||
nodes_sorted.emplace_back(out);
|
||||
for (const Edge* out_edge : n->out_edges()) {
|
||||
if (!edge_filter || edge_filter(*out_edge)) {
|
||||
nodes_sorted.emplace_back(out_edge->dst());
|
||||
}
|
||||
}
|
||||
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
|
||||
for (Node* out : nodes_sorted) {
|
||||
add_work(out);
|
||||
}
|
||||
} else {
|
||||
for (Node* out : nodes) {
|
||||
add_work(out);
|
||||
for (const Edge* out_edge : n->out_edges()) {
|
||||
if (!edge_filter || edge_filter(*out_edge)) {
|
||||
add_work(out_edge->dst());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -118,8 +122,6 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
|
||||
// Arrange to call leave(n) when all done with descendants.
|
||||
if (leave) stack.push_back(Work{n, true});
|
||||
|
||||
gtl::iterator_range<NeighborIter> nodes = n->in_nodes();
|
||||
|
||||
auto add_work = [&visited, &stack](T out) {
|
||||
if (!visited[out->id()]) {
|
||||
// Note; we must not mark as visited until we actually process it.
|
||||
@ -129,16 +131,16 @@ void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
|
||||
|
||||
if (stable_comparator) {
|
||||
std::vector<T> nodes_sorted;
|
||||
for (T in : nodes) {
|
||||
nodes_sorted.emplace_back(in);
|
||||
for (const Edge* in_edge : n->in_edges()) {
|
||||
nodes_sorted.emplace_back(in_edge->src());
|
||||
}
|
||||
std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
|
||||
for (T in : nodes_sorted) {
|
||||
add_work(in);
|
||||
}
|
||||
} else {
|
||||
for (T in : nodes) {
|
||||
add_work(in);
|
||||
for (const Edge* in_edge : n->in_edges()) {
|
||||
add_work(in_edge->src());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -161,14 +163,17 @@ void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
|
||||
}
|
||||
|
||||
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
|
||||
const NodeComparator& stable_comparator) {
|
||||
const NodeComparator& stable_comparator,
|
||||
const EdgeFilter& edge_filter) {
|
||||
order->clear();
|
||||
DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator);
|
||||
DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator,
|
||||
edge_filter);
|
||||
}
|
||||
|
||||
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
|
||||
const NodeComparator& stable_comparator) {
|
||||
GetPostOrder(g, order, stable_comparator);
|
||||
const NodeComparator& stable_comparator,
|
||||
const EdgeFilter& edge_filter) {
|
||||
GetPostOrder(g, order, stable_comparator, edge_filter);
|
||||
std::reverse(order->begin(), order->end());
|
||||
}
|
||||
|
||||
|
@ -28,6 +28,8 @@ namespace tensorflow {
|
||||
// Comparator for two nodes. This is used in order to get a stable ordering.
|
||||
using NodeComparator = std::function<bool(const Node*, const Node*)>;
|
||||
|
||||
using EdgeFilter = std::function<bool(const Edge&)>;
|
||||
|
||||
// Compares two node based on their ids.
|
||||
struct NodeComparatorID {
|
||||
bool operator()(const Node* n1, const Node* n2) const {
|
||||
@ -47,9 +49,11 @@ struct NodeComparatorName {
|
||||
// If leave is not empty, calls leave(n) after visiting all children of n.
|
||||
// If stable_comparator is set, a stable ordering of visit is achieved by
|
||||
// sorting a node's neighbors first before visiting them.
|
||||
// If edge_filter is set then ignores edges for which edge_filter returns false.
|
||||
extern void DFS(const Graph& g, const std::function<void(Node*)>& enter,
|
||||
const std::function<void(Node*)>& leave,
|
||||
const NodeComparator& stable_comparator = {});
|
||||
const NodeComparator& stable_comparator = {},
|
||||
const EdgeFilter& edge_filter = {});
|
||||
|
||||
// Perform a reverse depth-first-search on g starting at the sink node.
|
||||
// If enter is not empty, calls enter(n) before visiting any parents of n.
|
||||
@ -83,15 +87,21 @@ extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
|
||||
// If stable_comparator is set, a stable ordering of visit is achieved by
|
||||
// sorting a node's neighbors first before visiting them.
|
||||
//
|
||||
// If edge_filter is set then ignores edges for which edge_filter returns false.
|
||||
//
|
||||
// REQUIRES: order is not NULL.
|
||||
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
|
||||
const NodeComparator& stable_comparator = {});
|
||||
const NodeComparator& stable_comparator = {},
|
||||
const EdgeFilter& edge_filter = {});
|
||||
|
||||
// Stores in *order the reverse post-order numbering of all nodes
|
||||
// If stable_comparator is set, a stable ordering of visit is achieved by
|
||||
// sorting a node's neighbors first before visiting them.
|
||||
//
|
||||
// If edge_filter is set then ignores edges for which edge_filter returns false.
|
||||
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
|
||||
const NodeComparator& stable_comparator = {});
|
||||
const NodeComparator& stable_comparator = {},
|
||||
const EdgeFilter& edge_filter = {});
|
||||
|
||||
// Prune nodes in "g" that are not in some path from the source node
|
||||
// to any node in 'nodes'. Returns true if changes were made to the graph.
|
||||
|
@ -36,6 +36,11 @@ namespace {
|
||||
REGISTER_OP("TestParams").Output("o: float");
|
||||
REGISTER_OP("TestInput").Output("a: float").Output("b: float");
|
||||
REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
|
||||
REGISTER_OP("TestUnary").Input("a: float").Output("o: float");
|
||||
REGISTER_OP("TestBinary")
|
||||
.Input("a: float")
|
||||
.Input("b: float")
|
||||
.Output("o: float");
|
||||
|
||||
// Compares that the order of nodes in 'inputs' respects the
|
||||
// pair orders described in 'ordered_pairs'.
|
||||
@ -148,5 +153,52 @@ TEST(AlgorithmTest, ReversePostOrderStable) {
|
||||
EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
string error;
|
||||
Node* n0 = ops::SourceOp("TestParams", b.opts().WithName("n0"));
|
||||
Node* n1 = ops::UnaryOp("TestUnary", n0, b.opts().WithName("n1"));
|
||||
Node* n2 = ops::UnaryOp("TestUnary", n1, b.opts().WithName("n2"));
|
||||
Node* n3 = ops::BinaryOp("TestBinary", n2, n0, b.opts().WithName("n3"));
|
||||
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g));
|
||||
|
||||
g.AddEdge(g.FindNodeId(n3->id()), 0, g.FindNodeId(n1->id()), 1);
|
||||
|
||||
std::vector<Node*> post_order;
|
||||
auto edge_filter = [&](const Edge& e) {
|
||||
return !(e.src()->id() == n3->id() && e.dst()->id() == n1->id());
|
||||
};
|
||||
|
||||
std::vector<Node*> expected_post_order = {
|
||||
g.sink_node(), g.FindNodeId(n3->id()), g.FindNodeId(n2->id()),
|
||||
g.FindNodeId(n1->id()), g.FindNodeId(n0->id()), g.source_node()};
|
||||
|
||||
std::vector<Node*> expected_reverse_post_order = expected_post_order;
|
||||
std::reverse(expected_reverse_post_order.begin(),
|
||||
expected_reverse_post_order.end());
|
||||
|
||||
GetPostOrder(g, &post_order, /*stable_comparator=*/{},
|
||||
/*edge_filter=*/edge_filter);
|
||||
|
||||
ASSERT_EQ(expected_post_order.size(), post_order.size());
|
||||
for (int i = 0; i < post_order.size(); i++) {
|
||||
CHECK_EQ(post_order[i], expected_post_order[i])
|
||||
<< post_order[i]->name() << " vs. " << expected_post_order[i]->name();
|
||||
}
|
||||
|
||||
std::vector<Node*> reverse_post_order;
|
||||
GetReversePostOrder(g, &reverse_post_order, /*stable_comparator=*/{},
|
||||
/*edge_filter=*/edge_filter);
|
||||
|
||||
ASSERT_EQ(expected_reverse_post_order.size(), reverse_post_order.size());
|
||||
for (int i = 0; i < reverse_post_order.size(); i++) {
|
||||
CHECK_EQ(reverse_post_order[i], expected_reverse_post_order[i])
|
||||
<< reverse_post_order[i]->name() << " vs. "
|
||||
<< expected_reverse_post_order[i]->name();
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|