Remove excessive attributes lookup in control flow lowering
PiperOrigin-RevId: 272926707
This commit is contained in:
parent
20ebd8a796
commit
567af03528
@ -20,9 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||
#include "tensorflow/core/common_runtime/lower_while_op.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -122,6 +120,16 @@ Status LowerFunctionalOpsPass::Run(
|
||||
(options.session_options->config.experimental().executor_type() ==
|
||||
"SINGLE_THREADED_EXECUTOR");
|
||||
|
||||
// Returns true if `node` will be used for XLA compilation.
|
||||
const auto used_by_xla = [](Node* node) -> bool {
|
||||
return MarkedForTpuCompilation(node) || MarkedForXlaCompilation(node);
|
||||
};
|
||||
|
||||
// Returns true if control flow `node` should be lowered to Switch/Merge.
|
||||
const auto lower_control_flow = [&](Node* node) -> bool {
|
||||
return LowerUsingSwitchMergeIsOn(node) && !used_by_xla(node);
|
||||
};
|
||||
|
||||
// Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
|
||||
// attr set and inline all function calls into the graph.
|
||||
// We start at `i` = 2 to skip the source and sink nodes.
|
||||
@ -132,31 +140,34 @@ Status LowerFunctionalOpsPass::Run(
|
||||
for (int i = 2; i < g->num_node_ids(); ++i) {
|
||||
Node* n = g->FindNodeId(i);
|
||||
if (n == nullptr) continue; // deleted node
|
||||
if (MarkedForTpuCompilation(n)) continue;
|
||||
if (MarkedForXlaCompilation(n)) continue;
|
||||
|
||||
// Always lower function calls produces by lowering If/While nodes.
|
||||
if (IsFunctionCall(*flib_def, *n) &&
|
||||
// Always lower function calls produced by lowering If/While nodes.
|
||||
if (IsFunctionCall(*flib_def, *n) && !used_by_xla(n) &&
|
||||
(lower_function_calls || LowerAsMultiDeviceFunctionIsOn(n))) {
|
||||
TF_RETURN_IF_ERROR(RewriteFunctionCallNode(n, g, *flib_def,
|
||||
keep_lowered_nodes_fetchable));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!functional_control_flow && LowerUsingSwitchMergeIsOn(n)) {
|
||||
if (n->IsIfNode()) {
|
||||
TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));
|
||||
} else if (n->type_string() == "Case") {
|
||||
TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));
|
||||
} else if (n->IsWhileNode()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RewriteWhileNode(n, g, keep_lowered_nodes_fetchable));
|
||||
} else {
|
||||
return errors::Internal(
|
||||
"Node ", FormatNodeForError(*n), " of type ", n->type_string(),
|
||||
" has '", LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr,
|
||||
"' attr set but it does not support lowering.\n");
|
||||
}
|
||||
// If we are allowed to used function control flow, we do not need to check
|
||||
// for If/While/Case nodes in the graph.
|
||||
if (functional_control_flow) continue;
|
||||
|
||||
if (n->IsIfNode() && lower_control_flow(n)) {
|
||||
TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));
|
||||
|
||||
} else if (n->type_string() == "Case" && lower_control_flow(n)) {
|
||||
TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));
|
||||
|
||||
} else if (n->IsWhileNode() && lower_control_flow(n)) {
|
||||
TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, keep_lowered_nodes_fetchable));
|
||||
|
||||
} else {
|
||||
DCHECK(!lower_control_flow(n))
|
||||
<< "Node " << FormatNodeForError(*n) << " of type "
|
||||
<< n->type_string() << " has '"
|
||||
<< LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr
|
||||
<< "' attr set but it does not support lowering.\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -326,25 +326,5 @@ TEST(LowerIfWhileTest, WhileInCond) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LowerIfWhileTest, RaisesWhenLoweringUnhandledOpType) {
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Node* const_node;
|
||||
Tensor const_val(DT_INT32, TensorShape({}));
|
||||
const_val.scalar<int32>()() = 1;
|
||||
TF_ASSERT_OK(NodeBuilder("const", "Const")
|
||||
.Attr("value", const_val)
|
||||
.Attr("dtype", const_val.dtype())
|
||||
.Attr(kLowerUsingSwitchMergeAttr, true)
|
||||
.Finalize(root.graph(), &const_node));
|
||||
TF_ASSERT_OK(root.DoShapeInference(const_node));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
Status s = Rewrite(&graph);
|
||||
ASSERT_EQ(s.code(), error::INTERNAL);
|
||||
AssertHasSubstr(s.error_message(), "does not support lowering");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user