Make cond_v2 If op lowering work in a defun + eager.
Prior to this change, the lowering pass assumed that the If op functions would be available in the If op's graph. If the If op is defined in a defun and then called via eager execution, the functions will be in the eager context, but not in the defun's graph. This change makes the lowering pass correctly use the function library passed in by the caller via GraphOptimizationPassOptions. PiperOrigin-RevId: 215271990
This commit is contained in:
parent
1630584951
commit
c86f594135
@ -38,11 +38,12 @@ class CondBuilder {
|
||||
public:
|
||||
enum Branch { kElseBranch = 0, kThenBranch = 1 };
|
||||
|
||||
// Create a CondBuilder to create the lowering of If op. that has then and
|
||||
// Create a CondBuilder to create the lowered form of `if_op` with then and
|
||||
// else functions named `then_fn_name` and `else_fn_name` respectively in the
|
||||
// given graph.
|
||||
// `graph`. The functions should be available in `flib`.
|
||||
CondBuilder(Node* if_op, const string& then_fn_name,
|
||||
const string& else_fn_name, Graph* graph);
|
||||
const string& else_fn_name, const FunctionLibraryDefinition& flib,
|
||||
Graph* graph);
|
||||
|
||||
// Constructs the basic conditional control flow using switch and merge nodes.
|
||||
Status CreatePivotNodes();
|
||||
@ -89,6 +90,7 @@ class CondBuilder {
|
||||
Node* then_call_node_;
|
||||
Node* else_call_node_;
|
||||
Graph* graph_;
|
||||
const FunctionLibraryDefinition& flib_;
|
||||
string name_;
|
||||
|
||||
NodeBuilder then_call_builder_;
|
||||
@ -96,9 +98,11 @@ class CondBuilder {
|
||||
};
|
||||
|
||||
CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
|
||||
const string& else_fn_name, Graph* graph)
|
||||
const string& else_fn_name,
|
||||
const FunctionLibraryDefinition& flib, Graph* graph)
|
||||
: if_op_(if_op),
|
||||
graph_(graph),
|
||||
flib_(flib),
|
||||
name_(if_op->name()),
|
||||
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
|
||||
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
|
||||
@ -193,15 +197,15 @@ Status CondBuilder::AddOutputs() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InlineCallInGraph(Node* n, Graph* g) {
|
||||
const auto& lib = g->flib_def();
|
||||
const FunctionDef* fdef = lib.Find(n->type_string());
|
||||
Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
|
||||
Graph* g) {
|
||||
const FunctionDef* fdef = flib.Find(n->type_string());
|
||||
CHECK(fdef != nullptr);
|
||||
FunctionBody* fbody;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
|
||||
[&lib](const string& op, const OpDef** sig) {
|
||||
return lib.LookUpOpDef(op, sig);
|
||||
FunctionDefToBodyHelper(*fdef, n->attrs(), &flib,
|
||||
[&flib](const string& op, const OpDef** sig) {
|
||||
return flib.LookUpOpDef(op, sig);
|
||||
},
|
||||
&fbody));
|
||||
// TODO(jpienaar): Improve this interface to make the need to delete it
|
||||
@ -219,8 +223,8 @@ Status CondBuilder::BuildLoweredIfOutput() {
|
||||
}
|
||||
|
||||
Status CondBuilder::InlineCallNodes() {
|
||||
TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, graph_));
|
||||
TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, graph_));
|
||||
TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, flib_, graph_));
|
||||
TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, flib_, graph_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -240,6 +244,12 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
|
||||
return errors::Internal("Lowering If op requires a graph to be available.");
|
||||
}
|
||||
|
||||
FunctionLibraryDefinition* flib = options.flib_def;
|
||||
if (flib == nullptr) {
|
||||
return errors::Internal(
|
||||
"Lowering If op requires a FunctionLibraryDefinition to be available.");
|
||||
}
|
||||
|
||||
// Match all the nodes that need to be rewritten.
|
||||
gtl::InlinedVector<Node*, 2> matches;
|
||||
for (Node* n : g->op_nodes()) {
|
||||
@ -251,12 +261,14 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
|
||||
}
|
||||
}
|
||||
for (Node* n : matches) {
|
||||
TF_RETURN_IF_ERROR(RewriteNode(n, g));
|
||||
TF_RETURN_IF_ERROR(RewriteNode(n, *flib, g));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
|
||||
Status LowerIfOpPass::RewriteNode(Node* n,
|
||||
const FunctionLibraryDefinition& flib,
|
||||
Graph* g) {
|
||||
const AttrValue* then_attr = n->attrs().Find("then_branch");
|
||||
if (then_attr == nullptr) {
|
||||
return errors::InvalidArgument("Then branch function missing");
|
||||
@ -266,7 +278,8 @@ Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
|
||||
return errors::InvalidArgument("Else branch function missing");
|
||||
}
|
||||
|
||||
CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), g);
|
||||
CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib,
|
||||
g);
|
||||
TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
|
||||
TF_RETURN_IF_ERROR(cb.AddInputs());
|
||||
TF_RETURN_IF_ERROR(cb.AddOutputs());
|
||||
|
@ -29,8 +29,9 @@ class LowerIfOpPass : public GraphOptimizationPass {
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
|
||||
private:
|
||||
// Rewrite the given If node `n` in graph `g` to use the switch-merge form.
|
||||
Status RewriteNode(Node* n, Graph* g);
|
||||
// Rewrite the given If node `n` in graph `g` to use the switch-merge
|
||||
// form. `flib` should contain the branch functions referenced by `n`.
|
||||
Status RewriteNode(Node* n, const FunctionLibraryDefinition& flib, Graph* g);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -36,9 +36,7 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status Rewrite(std::unique_ptr<Graph>* graph) {
|
||||
FunctionDefLibrary flib;
|
||||
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
|
||||
|
||||
FunctionLibraryDefinition flib_def((*graph)->flib_def());
|
||||
GraphOptimizationPassOptions opt_options;
|
||||
opt_options.graph = graph;
|
||||
opt_options.flib_def = &flib_def;
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function as eager_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -3414,6 +3415,27 @@ class EagerTest(test.TestCase):
|
||||
self.assertAllEqual(r.numpy(), 10)
|
||||
self.assertFalse(isinstance(r, list))
|
||||
|
||||
def testCondInDefun(self):
|
||||
if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
|
||||
return unittest.skip("b/113346829 (gpu failure)")
|
||||
|
||||
with context.eager_mode():
|
||||
|
||||
@eager_function.defun
|
||||
def foo(pred):
|
||||
# TODO(b/111124878): this only needs to output one element.
|
||||
fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
|
||||
fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
|
||||
return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)
|
||||
|
||||
r = foo(True)
|
||||
self.assertAllEqual(r[0].numpy(), 10)
|
||||
self.assertNotIsInstance(r, list)
|
||||
|
||||
r = foo(False)
|
||||
self.assertAllEqual(r[0].numpy(), 20)
|
||||
self.assertFalse(isinstance(r, list))
|
||||
|
||||
def testWhileLoop(self):
|
||||
with context.eager_mode():
|
||||
tensor = constant_op.constant([1, 2, 3, 4, 5])
|
||||
|
Loading…
Reference in New Issue
Block a user