Some NFC cleanups to MarkForCompilationPass
1. Extract out all of the functionality into a MarkForCompilationPassImpl class
that lets us thread through state easily.
2. Extracted out some helper functions, mainly aimed at making the various loop
bodies more readable. There is still room for improvement here.
3. Fix the race on the `fuel` counter (b/129277762).
4. Remove the is_compilable_fn lambda in favor of a more specific
MarkForCompilationPassImpl::DebugOptions struct (with the lambda the code
looks more general than it actually is).
PiperOrigin-RevId: 240469235
This commit is contained in:
parent
65d68abc1f
commit
3d2488f052
@ -533,6 +533,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
|
||||
"//tensorflow/compiler/tf2xla/cc:xla_ops",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
|
||||
@ -394,11 +394,11 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unordered_set<int32> GraphCycles::Successors(int32 node) {
|
||||
std::unordered_set<int32> GraphCycles::Successors(int32 node) const {
|
||||
return rep_->nodes_[node]->out;
|
||||
}
|
||||
|
||||
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) {
|
||||
std::unordered_set<int32> GraphCycles::Predecessors(int32 node) const {
|
||||
return rep_->nodes_[node]->in;
|
||||
}
|
||||
|
||||
|
||||
@ -117,8 +117,8 @@ class GraphCycles {
|
||||
// Expensive: should only be called from graphcycles_test.cc.
|
||||
bool CheckInvariants() const;
|
||||
|
||||
std::unordered_set<int32> Successors(int32 node);
|
||||
std::unordered_set<int32> Predecessors(int32 node);
|
||||
std::unordered_set<int32> Successors(int32 node) const;
|
||||
std::unordered_set<int32> Predecessors(int32 node) const;
|
||||
|
||||
// ----------------------------------------------------
|
||||
struct Rep;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -41,9 +41,7 @@ class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
|
||||
private:
|
||||
Status RunImpl(const GraphOptimizationPassOptions& options,
|
||||
const std::function<bool(const Node*, const DeviceType&)>&
|
||||
is_compilable_fn = {});
|
||||
Status RunForTest(const GraphOptimizationPassOptions& options);
|
||||
|
||||
friend class MarkForCompilationPassTestHelper;
|
||||
};
|
||||
|
||||
@ -49,7 +49,7 @@ namespace tensorflow {
|
||||
opt_options.session_options = &session_options;
|
||||
opt_options.flib_def = flib_def;
|
||||
MarkForCompilationPass pass;
|
||||
return pass.RunImpl(opt_options);
|
||||
return pass.RunForTest(opt_options);
|
||||
}
|
||||
|
||||
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user