Remove excessive attributes lookup in control flow lowering
PiperOrigin-RevId: 272926707
This commit is contained in:
parent
20ebd8a796
commit
567af03528
@ -20,9 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
#include "tensorflow/core/common_runtime/lower_if_op.h"
|
||||||
#include "tensorflow/core/common_runtime/lower_while_op.h"
|
#include "tensorflow/core/common_runtime/lower_while_op.h"
|
||||||
#include "tensorflow/core/framework/node_def_builder.h"
|
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
|
||||||
#include "tensorflow/core/public/session_options.h"
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -122,6 +120,16 @@ Status LowerFunctionalOpsPass::Run(
|
|||||||
(options.session_options->config.experimental().executor_type() ==
|
(options.session_options->config.experimental().executor_type() ==
|
||||||
"SINGLE_THREADED_EXECUTOR");
|
"SINGLE_THREADED_EXECUTOR");
|
||||||
|
|
||||||
|
// Returns true if `node` will be used for XLA compilation.
|
||||||
|
const auto used_by_xla = [](Node* node) -> bool {
|
||||||
|
return MarkedForTpuCompilation(node) || MarkedForXlaCompilation(node);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns true if control flow `node` should be lowered to Switch/Merge.
|
||||||
|
const auto lower_control_flow = [&](Node* node) -> bool {
|
||||||
|
return LowerUsingSwitchMergeIsOn(node) && !used_by_xla(node);
|
||||||
|
};
|
||||||
|
|
||||||
// Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
|
// Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
|
||||||
// attr set and inline all function calls into the graph.
|
// attr set and inline all function calls into the graph.
|
||||||
// We start at `i` = 2 to skip the source and sink nodes.
|
// We start at `i` = 2 to skip the source and sink nodes.
|
||||||
@ -132,31 +140,34 @@ Status LowerFunctionalOpsPass::Run(
|
|||||||
for (int i = 2; i < g->num_node_ids(); ++i) {
|
for (int i = 2; i < g->num_node_ids(); ++i) {
|
||||||
Node* n = g->FindNodeId(i);
|
Node* n = g->FindNodeId(i);
|
||||||
if (n == nullptr) continue; // deleted node
|
if (n == nullptr) continue; // deleted node
|
||||||
if (MarkedForTpuCompilation(n)) continue;
|
|
||||||
if (MarkedForXlaCompilation(n)) continue;
|
|
||||||
|
|
||||||
// Always lower function calls produces by lowering If/While nodes.
|
// Always lower function calls produced by lowering If/While nodes.
|
||||||
if (IsFunctionCall(*flib_def, *n) &&
|
if (IsFunctionCall(*flib_def, *n) && !used_by_xla(n) &&
|
||||||
(lower_function_calls || LowerAsMultiDeviceFunctionIsOn(n))) {
|
(lower_function_calls || LowerAsMultiDeviceFunctionIsOn(n))) {
|
||||||
TF_RETURN_IF_ERROR(RewriteFunctionCallNode(n, g, *flib_def,
|
TF_RETURN_IF_ERROR(RewriteFunctionCallNode(n, g, *flib_def,
|
||||||
keep_lowered_nodes_fetchable));
|
keep_lowered_nodes_fetchable));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!functional_control_flow && LowerUsingSwitchMergeIsOn(n)) {
|
// If we are allowed to used function control flow, we do not need to check
|
||||||
if (n->IsIfNode()) {
|
// for If/While/Case nodes in the graph.
|
||||||
|
if (functional_control_flow) continue;
|
||||||
|
|
||||||
|
if (n->IsIfNode() && lower_control_flow(n)) {
|
||||||
TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));
|
TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));
|
||||||
} else if (n->type_string() == "Case") {
|
|
||||||
|
} else if (n->type_string() == "Case" && lower_control_flow(n)) {
|
||||||
TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));
|
TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));
|
||||||
} else if (n->IsWhileNode()) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
} else if (n->IsWhileNode() && lower_control_flow(n)) {
|
||||||
RewriteWhileNode(n, g, keep_lowered_nodes_fetchable));
|
TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, keep_lowered_nodes_fetchable));
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
return errors::Internal(
|
DCHECK(!lower_control_flow(n))
|
||||||
"Node ", FormatNodeForError(*n), " of type ", n->type_string(),
|
<< "Node " << FormatNodeForError(*n) << " of type "
|
||||||
" has '", LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr,
|
<< n->type_string() << " has '"
|
||||||
"' attr set but it does not support lowering.\n");
|
<< LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr
|
||||||
}
|
<< "' attr set but it does not support lowering.\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,25 +326,5 @@ TEST(LowerIfWhileTest, WhileInCond) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LowerIfWhileTest, RaisesWhenLoweringUnhandledOpType) {
|
|
||||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
|
||||||
|
|
||||||
Scope root = Scope::NewRootScope().ExitOnError();
|
|
||||||
Node* const_node;
|
|
||||||
Tensor const_val(DT_INT32, TensorShape({}));
|
|
||||||
const_val.scalar<int32>()() = 1;
|
|
||||||
TF_ASSERT_OK(NodeBuilder("const", "Const")
|
|
||||||
.Attr("value", const_val)
|
|
||||||
.Attr("dtype", const_val.dtype())
|
|
||||||
.Attr(kLowerUsingSwitchMergeAttr, true)
|
|
||||||
.Finalize(root.graph(), &const_node));
|
|
||||||
TF_ASSERT_OK(root.DoShapeInference(const_node));
|
|
||||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
|
||||||
|
|
||||||
Status s = Rewrite(&graph);
|
|
||||||
ASSERT_EQ(s.code(), error::INTERNAL);
|
|
||||||
AssertHasSubstr(s.error_message(), "does not support lowering");
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user