Some NFC cleanups to MarkForCompilationPass

1. Extract out all of the functionality into a MarkForCompilationPassImpl class
    that lets us thread through state easily.

 2. Extracted out some helper functions, mainly aimed at making the various loop
    bodies more readable.  There is still room for improvement here.

 3. Fix the race on the `fuel` counter (b/129277762).

 4. Remove the is_compilable_fn lambda in favor of a more specific
    MarkForCompilationPassImpl::DebugOptions struct (with the lambda the code
    looks more general than it actually is).

PiperOrigin-RevId: 240469235
This commit is contained in:
Sanjoy Das 2019-03-26 18:09:23 -07:00 committed by TensorFlower Gardener
parent 65d68abc1f
commit 3d2488f052
6 changed files with 650 additions and 521 deletions

View File

@ -533,6 +533,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",

View File

@ -394,11 +394,11 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
return true;
}
std::unordered_set<int32> GraphCycles::Successors(int32 node) {
std::unordered_set<int32> GraphCycles::Successors(int32 node) const {
return rep_->nodes_[node]->out;
}
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) {
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) const {
return rep_->nodes_[node]->in;
}

View File

@ -117,8 +117,8 @@ class GraphCycles {
// Expensive: should only be called from graphcycles_test.cc.
bool CheckInvariants() const;
std::unordered_set<int32> Successors(int32 node);
std::unordered_set<int32> Predecessors(int32 node);
std::unordered_set<int32> Successors(int32 node) const;
std::unordered_set<int32> Predecessors(int32 node) const;
// ----------------------------------------------------
struct Rep;

File diff suppressed because it is too large Load Diff

View File

@ -41,9 +41,7 @@ class MarkForCompilationPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
private:
Status RunImpl(const GraphOptimizationPassOptions& options,
const std::function<bool(const Node*, const DeviceType&)>&
is_compilable_fn = {});
Status RunForTest(const GraphOptimizationPassOptions& options);
friend class MarkForCompilationPassTestHelper;
};

View File

@ -49,7 +49,7 @@ namespace tensorflow {
opt_options.session_options = &session_options;
opt_options.flib_def = flib_def;
MarkForCompilationPass pass;
return pass.RunImpl(opt_options);
return pass.RunForTest(opt_options);
}
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(