Remove excessive attributes lookup in control flow lowering

PiperOrigin-RevId: 272926707
This commit is contained in:
Eugene Zhulenev 2019-10-04 11:54:43 -07:00 committed by TensorFlower Gardener
parent 20ebd8a796
commit 567af03528
2 changed files with 31 additions and 40 deletions

View File

@ -20,9 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/lower_while_op.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@ -122,6 +120,16 @@ Status LowerFunctionalOpsPass::Run(
(options.session_options->config.experimental().executor_type() ==
"SINGLE_THREADED_EXECUTOR");
// Returns true if `node` will be used for XLA compilation.
const auto used_by_xla = [](Node* node) -> bool {
return MarkedForTpuCompilation(node) || MarkedForXlaCompilation(node);
};
// Returns true if control flow `node` should be lowered to Switch/Merge.
const auto lower_control_flow = [&](Node* node) -> bool {
return LowerUsingSwitchMergeIsOn(node) && !used_by_xla(node);
};
// Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
// attr set and inline all function calls into the graph.
// We start at `i` = 2 to skip the source and sink nodes.
@ -132,31 +140,34 @@ Status LowerFunctionalOpsPass::Run(
for (int i = 2; i < g->num_node_ids(); ++i) {
Node* n = g->FindNodeId(i);
if (n == nullptr) continue; // deleted node
if (MarkedForTpuCompilation(n)) continue;
if (MarkedForXlaCompilation(n)) continue;
// Always lower function calls produces by lowering If/While nodes.
if (IsFunctionCall(*flib_def, *n) &&
// Always lower function calls produced by lowering If/While nodes.
if (IsFunctionCall(*flib_def, *n) && !used_by_xla(n) &&
(lower_function_calls || LowerAsMultiDeviceFunctionIsOn(n))) {
TF_RETURN_IF_ERROR(RewriteFunctionCallNode(n, g, *flib_def,
keep_lowered_nodes_fetchable));
continue;
}
if (!functional_control_flow && LowerUsingSwitchMergeIsOn(n)) {
if (n->IsIfNode()) {
TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));
} else if (n->type_string() == "Case") {
TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));
} else if (n->IsWhileNode()) {
TF_RETURN_IF_ERROR(
RewriteWhileNode(n, g, keep_lowered_nodes_fetchable));
} else {
return errors::Internal(
"Node ", FormatNodeForError(*n), " of type ", n->type_string(),
" has '", LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr,
"' attr set but it does not support lowering.\n");
}
// If we are allowed to used function control flow, we do not need to check
// for If/While/Case nodes in the graph.
if (functional_control_flow) continue;
if (n->IsIfNode() && lower_control_flow(n)) {
TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));
} else if (n->type_string() == "Case" && lower_control_flow(n)) {
TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));
} else if (n->IsWhileNode() && lower_control_flow(n)) {
TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, keep_lowered_nodes_fetchable));
} else {
DCHECK(!lower_control_flow(n))
<< "Node " << FormatNodeForError(*n) << " of type "
<< n->type_string() << " has '"
<< LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr
<< "' attr set but it does not support lowering.\n";
}
}

View File

@ -326,25 +326,5 @@ TEST(LowerIfWhileTest, WhileInCond) {
}
}
TEST(LowerIfWhileTest, RaisesWhenLoweringUnhandledOpType) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Scope root = Scope::NewRootScope().ExitOnError();
Node* const_node;
Tensor const_val(DT_INT32, TensorShape({}));
const_val.scalar<int32>()() = 1;
TF_ASSERT_OK(NodeBuilder("const", "Const")
.Attr("value", const_val)
.Attr("dtype", const_val.dtype())
.Attr(kLowerUsingSwitchMergeAttr, true)
.Finalize(root.graph(), &const_node));
TF_ASSERT_OK(root.DoShapeInference(const_node));
TF_ASSERT_OK(root.ToGraph(graph.get()));
Status s = Rewrite(&graph);
ASSERT_EQ(s.code(), error::INTERNAL);
AssertHasSubstr(s.error_message(), "does not support lowering");
}
} // namespace
} // namespace tensorflow