Make ScopedAllocatorOptimizer compatible with Const input.
In a previous change, we aborted ScopedAllocatorOptimizer when one of the inputs
was a Const op.
This change instead enables Const op to work with ScopedAllocatorOptimizer by
introducing an Identity op after Const. Thus we change:
Const -> CollectiveReduce
to
Const -> Identity -> CollectiveReduce
The Identity becomes the real input to CollectiveReduce, and it will use the
pre-allocated buffer slice for its output tensor when it invokes `set_output`,
following the logic in tensorflow/core/framework/op_kernel.cc (line 903 at
revision 6b65afa420).
This is similar to the approach in cl/259138773.
PiperOrigin-RevId: 325541732
Change-Id: I6487685089520b73387197a31fef5780217a3a4b
This commit is contained in:
parent
59affc4d61
commit
769155a21e
tensorflow/core/grappler/optimizers
@ -993,6 +993,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
|
@ -218,8 +218,9 @@ Status MaybeRewriteInput(ScopedAllocatorOptimizer* sa_opti,
|
||||
NodeDef* input, const string& edge_name,
|
||||
int output_index, NodeDef* op, NodeDef** new_input,
|
||||
int* new_output_index, bool* rewrite) {
|
||||
*rewrite = IsExit(*input) || (sa_opti->repeated_outputs().find(edge_name) !=
|
||||
sa_opti->repeated_outputs().end());
|
||||
*rewrite = IsConstant(*input) || IsExit(*input) ||
|
||||
(sa_opti->repeated_outputs().find(edge_name) !=
|
||||
sa_opti->repeated_outputs().end());
|
||||
if (!(*rewrite)) {
|
||||
*new_input = input;
|
||||
*new_output_index = output_index;
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -241,8 +242,9 @@ class ScopedAllocatorOptimizerTest : public ::testing::Test {
|
||||
// Constructs the following graph.
|
||||
//
|
||||
// c1 and c2 are Const ops. a1 and a2 are Abs ops.
|
||||
// We expect the optimizer to fail, because Const ops do not allocate their
|
||||
// output on every Compute, and hence are not compatible with ScopedAllocator.
|
||||
// We expect the optimizer to succeed and insert Identity between ci and ai.
|
||||
// This will ensure that we will still be able to use ScopedAllocator with Const
|
||||
// inputs.
|
||||
/*
|
||||
c1 c2
|
||||
| |
|
||||
@ -559,7 +561,8 @@ TEST_F(ScopedAllocatorOptimizerTest, ControlEdgeRewire) {
|
||||
EXPECT_EQ(NumControlInputs(&node_map, "ctl4"), 1);
|
||||
}
|
||||
|
||||
// Test that the optimization fails when any input is a Const op.
|
||||
// Test that the optimization succeeds when any input is a Const op, and that it
|
||||
// inserts Identity op between Const and Abs.
|
||||
TEST_F(ScopedAllocatorOptimizerTest, ConstInput) {
|
||||
GrapplerItem item;
|
||||
BuildConstGraph(&item.graph, false);
|
||||
@ -572,10 +575,26 @@ TEST_F(ScopedAllocatorOptimizerTest, ConstInput) {
|
||||
ons.insert("Abs");
|
||||
|
||||
GraphDef optimized_graph;
|
||||
auto status = sao.Optimize(nullptr /*cluster*/, item, &optimized_graph);
|
||||
EXPECT_EQ(status.code(), tensorflow::error::ABORTED);
|
||||
EXPECT_TRUE(str_util::StrContains(status.error_message(),
|
||||
"does not use AllocatorAttributes"));
|
||||
TF_ASSERT_OK(sao.Optimize(nullptr /*cluster*/, item, &optimized_graph));
|
||||
|
||||
// Examine the resulting graphdef.
|
||||
const NodeDef* sa_node = nullptr;
|
||||
for (const NodeDef& node : optimized_graph.node()) {
|
||||
if (node.op() == "_ScopedAllocator") {
|
||||
sa_node = &node;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ASSERT_NE(sa_node, nullptr);
|
||||
int num_identity_ops = 0;
|
||||
NodeMap node_map(&optimized_graph);
|
||||
for (NodeDef* sa_output : node_map.GetOutputs(sa_node->name())) {
|
||||
EXPECT_FALSE(IsConstant(*sa_output));
|
||||
if (IsIdentity(*sa_output)) {
|
||||
++num_identity_ops;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(num_identity_ops, 2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user